Source code for manuscript.detectors._east

import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
import onnxruntime as ort

from manuscript.api.base import BaseModel

from ...data import Block, Line, Page, Word
from ...utils import read_image, organize_page
from .lanms import locality_aware_nms
from .utils import (
    decode_quads_from_maps,
    expand_boxes,
)

# Optional imports for training (not needed for inference)
try:
    import torch
    from torch.utils.data import ConcatDataset
    from .dataset import EASTDataset
    from .east import EASTModel
    from .train_utils import _run_training

    _TORCH_AVAILABLE = True
except ImportError:
    torch = None
    ConcatDataset = None
    EASTDataset = None
    EASTModel = None
    _run_training = None
    _TORCH_AVAILABLE = False
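
# A minimal sketch (assumption, not library API) of how callers can probe the
# flag above before invoking the torch-only entrypoints; the install hint
# mirrors the ImportError raised by EAST.train/EAST.export:
#
#     from manuscript.detectors._east import _TORCH_AVAILABLE
#     if not _TORCH_AVAILABLE:
#         print("Training/export disabled: pip install manuscript-ocr[dev]")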


class EAST(BaseModel):
    """
    Initialize EAST text detector with ONNX Runtime.

    Parameters
    ----------
    weights : str or Path, optional
        Path or identifier for ONNX model weights. Supports:

        - Local file path: ``"path/to/model.onnx"``
        - HTTP/HTTPS URL: ``"https://example.com/model.onnx"``
        - GitHub release: ``"github://owner/repo/tag/file.onnx"``
        - Google Drive: ``"gdrive:FILE_ID"``
        - Preset name: ``"east_50_g1"``
        - ``None``: auto-downloads the default preset (east_50_g1)
    device : str, optional
        Compute device: ``"cuda"``, ``"coreml"``, or ``"cpu"``. If None,
        automatically selects CPU. For GPU/CoreML acceleration:

        - CUDA (NVIDIA): ``pip install onnxruntime-gpu``
        - CoreML (Apple Silicon M1/M2/M3): ``pip install onnxruntime-silicon``

        Default is ``None`` (CPU).
    target_size : int, optional
        Input image size for inference. Images are resized to
        ``(target_size, target_size)``. Default is 1280.
    expand_ratio_w : float, optional
        Horizontal expansion factor applied to detected boxes after NMS.
        Default is 1.4.
    expand_ratio_h : float, optional
        Vertical expansion factor applied to detected boxes after NMS.
        Default is 1.5.
    expand_power : float, optional
        Power for non-linear box expansion. Controls how expansion scales
        with box size.

        - 1.0 = linear (small and large boxes expand equally)
        - <1.0 = small boxes expand more (e.g., 0.5, recommended for
          character-level detection)
        - >1.0 = large boxes expand more

        Default is 0.6.
    score_thresh : float, optional
        Confidence threshold for selecting candidate detections before NMS.
        Default is 0.6.
    iou_threshold : float, optional
        IoU threshold for the locality-aware NMS merging phase.
        Default is 0.05.
    iou_threshold_standard : float, optional
        IoU threshold for standard NMS after locality-aware merging.
        If None, uses the same value as ``iou_threshold``. Default is 0.05.
    score_geo_scale : float, optional
        Scale factor for decoding geometry/score maps. Default is 0.25.
    quantization : int, optional
        Quantization resolution for point coordinates during decoding.
        Default is 2.
    axis_aligned_output : bool, optional
        If True, outputs axis-aligned rectangles instead of the original
        quads. Default is True.
    remove_area_anomalies : bool, optional
        If True, removes quads with extremely large area relative to the
        distribution. Default is False.
    anomaly_sigma_threshold : float, optional
        Sigma threshold for anomalous-area filtering. Default is 5.0.
    anomaly_min_box_count : int, optional
        Minimum number of boxes required before anomaly filtering.
        Default is 30.
    use_tta : bool, optional
        Enable Test-Time Augmentation (TTA). When enabled, inference is run
        on both the original and a horizontally flipped image, and the
        results are merged. This can improve detection of partially visible
        or edge text. Default is False.
    tta_iou_thresh : float, optional
        IoU threshold for merging boxes from the original and flipped images
        during TTA. Boxes with IoU > threshold are considered matches and
        merged. Default is 0.1.

    Notes
    -----
    The class provides two main public methods:

    - ``predict`` — run inference on a single image and return detections.
    - ``train`` — high-level training entrypoint to train an EAST model on
      custom datasets.

    The detector uses ONNX Runtime for fast inference on CPU and GPU.
    For GPU acceleration, install: ``pip install onnxruntime-gpu``
    """

    default_weights_name = "east_50_g1"
    pretrained_registry = {
        "east_50_g1": "github://konstantinkozhin/manuscript-ocr/v0.1.0/east_50_g1.onnx",
    }
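    # Construction sketch (hedged): "east_50_g1" is the preset registered
    # above; the local path and device strings are placeholders.
    #
    #     det = EAST()                             # auto-downloads east_50_g1
    #     det = EAST("path/to/model.onnx")         # local ONNX weights
    #     det = EAST("east_50_g1", device="cuda")  # preset + CUDA provider
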
    def __init__(
        self,
        weights: Optional[Union[str, Path]] = None,
        device: Optional[str] = None,
        *,
        target_size: int = 1280,
        expand_ratio_w: float = 1.4,
        expand_ratio_h: float = 1.5,
        expand_power: float = 0.6,
        score_thresh: float = 0.6,
        iou_threshold: float = 0.05,
        iou_threshold_standard: Optional[float] = 0.05,
        score_geo_scale: float = 0.25,
        quantization: int = 2,
        axis_aligned_output: bool = True,
        remove_area_anomalies: bool = False,
        anomaly_sigma_threshold: float = 5.0,
        anomaly_min_box_count: int = 30,
        use_tta: bool = False,
        tta_iou_thresh: float = 0.1,
        **kwargs,
    ):
        super().__init__(weights=weights, device=device, **kwargs)
        self.onnx_session = None
        self.target_size = target_size
        self.expand_ratio_w = expand_ratio_w
        self.expand_ratio_h = expand_ratio_h
        self.expand_power = expand_power
        self.score_thresh = score_thresh
        self.iou_threshold = iou_threshold
        self.iou_threshold_standard = iou_threshold_standard
        self.score_geo_scale = score_geo_scale
        self.quantization = quantization
        self.axis_aligned_output = axis_aligned_output
        self.remove_area_anomalies = remove_area_anomalies
        self.anomaly_sigma_threshold = anomaly_sigma_threshold
        self.anomaly_min_box_count = anomaly_min_box_count
        # TTA parameters
        self.use_tta = use_tta
        self.tta_iou_thresh = tta_iou_thresh

    def _initialize_session(self):
        if self.onnx_session is not None:
            return
        providers = self.runtime_providers()
        self.onnx_session = ort.InferenceSession(
            self.weights,
            providers=providers,
        )
        self._log_device_info(self.onnx_session)

    def _scale_boxes_to_original(
        self, boxes: np.ndarray, orig_size: Tuple[int, int]
    ) -> np.ndarray:
        if len(boxes) == 0:
            return boxes
        orig_h, orig_w = orig_size
        scale_x = orig_w / self.target_size
        scale_y = orig_h / self.target_size
        scaled = boxes.copy()
        scaled[:, 0:8:2] *= scale_x
        scaled[:, 1:8:2] *= scale_y
        return scaled

    def _convert_to_axis_aligned(self, quads: np.ndarray) -> np.ndarray:
        if len(quads) == 0:
            return quads
        aligned = quads.copy()
        coords = aligned[:, :8].reshape(-1, 4, 2)
        x_min = coords[:, :, 0].min(axis=1)
        x_max = coords[:, :, 0].max(axis=1)
        y_min = coords[:, :, 1].min(axis=1)
        y_max = coords[:, :, 1].max(axis=1)
        rects = np.stack(
            [x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max],
            axis=1,
        )
        aligned[:, :8] = rects.reshape(-1, 8)
        return aligned

    @staticmethod
    def _polygon_area_batch(polys: np.ndarray) -> np.ndarray:
        if polys.size == 0:
            return np.zeros((0,), dtype=np.float32)
        x = polys[:, :, 0]
        y = polys[:, :, 1]
        # Shoelace formula, vectorized over the batch of polygons.
        return 0.5 * np.abs(
            np.sum(x * np.roll(y, -1, axis=1) - y * np.roll(x, -1, axis=1), axis=1)
        )

    def _is_quad_inside(self, inner: np.ndarray, outer: np.ndarray) -> bool:
        contour = outer.reshape(-1, 1, 2).astype(np.float32)
        for point in inner.astype(np.float32):
            if (
                cv2.pointPolygonTest(contour, (float(point[0]), float(point[1])), False)
                < 0
            ):
                return False
        return True

    def _remove_fully_contained_boxes(self, quads: np.ndarray) -> np.ndarray:
        if len(quads) <= 1:
            return quads
        coords = quads[:, :8].reshape(-1, 4, 2)
        areas = self._polygon_area_batch(coords)
        keep = np.ones(len(quads), dtype=bool)
        order = np.argsort(areas)
        for idx in order:
            if not keep[idx]:
                continue
            inner = coords[idx]
            inner_area = areas[idx]
            for jdx in range(len(quads)):
                if idx == jdx or not keep[jdx]:
                    continue
                if areas[jdx] + 1e-6 < inner_area:
                    continue
                if self._is_quad_inside(inner, coords[jdx]):
                    keep[idx] = False
                    break
        return quads[keep]

    def _remove_area_anomalies(self, quads: np.ndarray) -> np.ndarray:
        if (
            not self.remove_area_anomalies
            or len(quads) == 0
            or len(quads) <= self.anomaly_min_box_count
        ):
            return quads
        coords = quads[:, :8].reshape(-1, 4, 2)
        areas = self._polygon_area_batch(coords).astype(np.float32)
        mean = float(np.mean(areas))
        std = float(np.std(areas))
        if std == 0.0:
            return quads
        threshold = mean + self.anomaly_sigma_threshold * std
        keep = areas <= threshold
        if not np.any(keep):
            return quads
        return quads[keep]

    @staticmethod
    def _box_iou(
        box1: Tuple[int, int, int, int], box2: Tuple[int, int, int, int]
    ) -> float:
        """
        Calculate IoU between two axis-aligned boxes.

        Parameters
        ----------
        box1, box2 : tuple of (x0, y0, x1, y1)
            Bounding box coordinates.

        Returns
        -------
        float
            Intersection over Union value in [0, 1].
        """
        x0 = max(box1[0], box2[0])
        y0 = max(box1[1], box2[1])
        x1 = min(box1[2], box2[2])
        y1 = min(box1[3], box2[3])
        if x1 <= x0 or y1 <= y0:
            return 0.0
        inter = (x1 - x0) * (y1 - y0)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        return inter / float(area1 + area2 - inter)

    def _merge_tta_boxes(
        self,
        boxes_orig: List[Tuple[Tuple[int, int, int, int], float]],
        boxes_flipped: List[Tuple[Tuple[int, int, int, int], float]],
    ) -> List[Tuple[Tuple[int, int, int, int], float]]:
        """
        Merge detection boxes from original and horizontally flipped images.

        For each box in the original image, find the matching box in the
        flipped image (based on the IoU threshold). Matched boxes are merged
        by extending x-coordinates and keeping y-coordinates from the
        original, with averaged scores. This keeps vertical positioning
        stable while potentially extending horizontal coverage when both
        views detect overlapping regions.

        Parameters
        ----------
        boxes_orig : list of ((x0, y0, x1, y1), score)
            Boxes from original image detection.
        boxes_flipped : list of ((x0, y0, x1, y1), score)
            Boxes from flipped image detection (already transformed back to
            the original coordinate space).

        Returns
        -------
        list of ((x0, y0, x1, y1), score)
            Merged boxes: only boxes that have matches in both views.
        """
        merged = []
        for box1, score1 in boxes_orig:
            for box2, score2 in boxes_flipped:
                iou = self._box_iou(box1, box2)
                if iou > self.tta_iou_thresh:
                    # Merge: extend x-coordinates, keep y from original
                    merged_box = (
                        min(box1[0], box2[0]),  # x0: take minimum
                        box1[1],  # y0: keep from original
                        max(box1[2], box2[2]),  # x1: take maximum
                        box1[3],  # y1: keep from original
                    )
                    avg_score = (score1 + score2) / 2
                    merged.append((merged_box, avg_score))
                    break  # Found match, move to next original box
        return merged

    def _run_inference_on_image(
        self, img: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run ONNX inference on a single image.

        Parameters
        ----------
        img : np.ndarray
            RGB image with shape (H, W, 3).

        Returns
        -------
        tuple of (final_quads_nms, score_map, geo_map)
            Detected quads after NMS and raw maps.
        """
        resized = cv2.resize(img, (self.target_size, self.target_size))
        img_norm = (resized.astype(np.float32) / 255.0 - 0.5) / 0.5
        img_input = img_norm.transpose(2, 0, 1)[np.newaxis, :, :, :]

        input_name = self.onnx_session.get_inputs()[0].name
        output_names = [out.name for out in self.onnx_session.get_outputs()]
        outputs = self.onnx_session.run(output_names, {input_name: img_input})

        score_map = outputs[0].squeeze(0).squeeze(0)
        geo_map = outputs[1].squeeze(0)

        final_quads = decode_quads_from_maps(
            score_map=score_map,
            geo_map=geo_map.transpose(1, 2, 0),
            score_thresh=self.score_thresh,
            scale=1.0 / self.score_geo_scale,
            quantization=self.quantization,
        )
        final_quads_nms = locality_aware_nms(
            final_quads,
            iou_threshold=self.iou_threshold,
            iou_threshold_standard=self.iou_threshold_standard,
        )
        return final_quads_nms, score_map, geo_map

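    # Worked example for _box_iou / _merge_tta_boxes (illustrative numbers,
    # not from the library). With tta_iou_thresh = 0.1:
    #
    #     box_orig    = (10, 20, 100, 40), score 0.90
    #     box_flipped = ( 5, 22,  90, 41), score 0.80
    #     intersection = (90 - 10) * (40 - 22) = 1440
    #     union        = 1800 + 1615 - 1440   = 1975
    #     IoU ~= 0.73 > 0.1 -> merged to (5, 20, 100, 40) with score 0.85
    #     (x extended to the union, y kept from the original box)
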
    def predict(
        self,
        img_or_path: Union[str, Path, np.ndarray],
        return_maps: bool = False,
        sort_reading_order: bool = True,
        split_into_columns: bool = True,
        max_columns: int = 10,
    ) -> Dict[str, Any]:
        """
        Run EAST inference and return detection results.

        Parameters
        ----------
        img_or_path : str or pathlib.Path or numpy.ndarray
            Path to an image file or an RGB image provided as a NumPy array
            with shape ``(H, W, 3)`` in ``uint8`` format.
        return_maps : bool, optional
            If True, returns the raw model score and geometry maps under the
            keys ``"score_map"`` and ``"geo_map"``. Default is False.
        sort_reading_order : bool, optional
            If True, sorts detected words in natural reading order
            (left-to-right, top-to-bottom) and groups them into text lines.
            Default is True.
        split_into_columns : bool, optional
            If True and ``sort_reading_order=True``, segments the page into
            columns (separate Blocks). If False, treats the entire page as a
            single column. Only used when ``sort_reading_order=True``.
            Default is True.
        max_columns : int, optional
            Maximum number of columns to detect when
            ``split_into_columns=True``. Higher values allow more columns to
            be detected. Only used when ``sort_reading_order=True`` and
            ``split_into_columns=True``. Default is 10.

        Returns
        -------
        dict
            Dictionary with the following keys:

            - ``"page"`` : Page
                Parsed detection result as a Page object containing Block(s)
                with Line(s) of Word objects. Each Word has polygon
                coordinates and confidence scores. Words and Lines have
                reading order indices.
            - ``"score_map"`` : numpy.ndarray or None
                Raw score map produced by the network if ``return_maps=True``.
            - ``"geo_map"`` : numpy.ndarray or None
                Raw geometry map if ``return_maps=True``.

        Notes
        -----
        The method performs: (1) image loading, (2) resizing and
        normalization, (3) model inference, (4) quad decoding, (5) NMS,
        (6) box expansion, (7) scaling coordinates back to original size,
        (8) optional reading-order sorting into lines.

        **Test-Time Augmentation (TTA):** When ``use_tta=True`` is set during
        initialization, the method runs inference on both the original and a
        horizontally flipped image, then merges the results. Boxes from both
        views are matched by IoU, and only detections confirmed in both
        views are kept (with their original polygons and scores). This can
        improve precision for text near image edges or partially visible
        text.

        For visualization, use the external ``visualize_page`` utility:

        >>> from manuscript.utils import visualize_page
        >>> result = model.predict(img_path)
        >>> vis_img = visualize_page(img, result["page"])

        Examples
        --------
        Perform inference and get structured output:

        >>> from manuscript.detectors import EAST
        >>> model = EAST()
        >>> img_path = r"example/ocr_example_image.jpg"
        >>> result = model.predict(img_path)
        >>> page = result["page"]
        >>> # Access first line's first word
        >>> first_word = page.blocks[0].lines[0].words[0]
        >>> print(f"Confidence: {first_word.detection_confidence}")

        Visualize results separately:

        >>> from manuscript.utils import visualize_page, read_image
        >>> result = model.predict(img_path)
        >>> img = read_image(img_path)
        >>> vis_img = visualize_page(img, result["page"])
        >>> vis_img.show()
        """
        if self.onnx_session is None:
            self._initialize_session()

        img = read_image(img_or_path)
        orig_h, orig_w = img.shape[:2]

        # Run inference on the original image
        final_quads_nms, score_map, geo_map = self._run_inference_on_image(img)

        # TTA: run on the horizontally flipped image and merge results
        if self.use_tta:
            img_flipped = np.fliplr(img).copy()
            final_quads_flipped, _, _ = self._run_inference_on_image(img_flipped)

            # Scale both sets of quads to the original size first
            scaled_orig = self._scale_boxes_to_original(
                final_quads_nms, (orig_h, orig_w)
            )
            scaled_flipped = self._scale_boxes_to_original(
                final_quads_flipped, (orig_h, orig_w)
            )

            # Convert to axis-aligned boxes for IoU matching while keeping
            # the original quads for non-axis-aligned output.
            boxes_orig = []
            quads_orig_list = []
            for quad in scaled_orig:
                pts = quad[:8].reshape(4, 2)
                x0, y0 = pts.min(axis=0)
                x1, y1 = pts.max(axis=0)
                score = float(np.clip(quad[8], 0.0, 1.0))
                boxes_orig.append(((int(x0), int(y0), int(x1), int(y1)), score))
                quads_orig_list.append(quad.copy())

            boxes_flipped = []
            for quad in scaled_flipped:
                pts = quad[:8].reshape(4, 2)
                x0, y0 = pts.min(axis=0)
                x1, y1 = pts.max(axis=0)
                # Mirror x-coordinates back into the original view
                x0_mirrored = orig_w - x1
                x1_mirrored = orig_w - x0
                score = float(np.clip(quad[8], 0.0, 1.0))
                boxes_flipped.append(
                    ((int(x0_mirrored), int(y0), int(x1_mirrored), int(y1)), score)
                )

            # Find which original boxes have matches in the flipped view
            matched_orig_indices = []
            for idx, (box1, _score1) in enumerate(boxes_orig):
                for box2, _score2 in boxes_flipped:
                    if self._box_iou(box1, box2) > self.tta_iou_thresh:
                        matched_orig_indices.append(idx)
                        break  # Found match, move to next original box

            # Keep only matched quads (preserves original 4-point polygons)
            if matched_orig_indices:
                matched_quads = np.stack(
                    [quads_orig_list[i] for i in matched_orig_indices], axis=0
                )
            else:
                matched_quads = np.empty((0, 9), dtype=np.float32)

            # Expand boxes (already in original-image coordinates)
            expanded = expand_boxes(
                matched_quads,
                expand_w=self.expand_ratio_w,
                expand_h=self.expand_ratio_h,
                expand_power=self.expand_power,
            )
        else:
            # No TTA: standard processing
            expanded = expand_boxes(
                final_quads_nms,
                expand_w=self.expand_ratio_w,
                expand_h=self.expand_ratio_h,
                expand_power=self.expand_power,
            )
            expanded = self._scale_boxes_to_original(expanded, (orig_h, orig_w))

        processed_quads = self._remove_fully_contained_boxes(expanded)
        processed_quads = self._remove_area_anomalies(processed_quads)

        output_quads = (
            self._convert_to_axis_aligned(processed_quads)
            if self.axis_aligned_output
            else processed_quads
        )

        words: List[Word] = []
        for quad in output_quads:
            pts = quad[:8].reshape(4, 2)
            score = float(np.clip(quad[8], 0.0, 1.0))
            words.append(Word(polygon=pts.tolist(), detection_confidence=score))

        if sort_reading_order and len(words) > 0:
            initial_page = Page(
                blocks=[Block(lines=[Line(words=words, order=0)], order=0)]
            )
            page = organize_page(
                initial_page, max_splits=max_columns, use_columns=split_into_columns
            )
        else:
            for idx, w in enumerate(words):
                w.order = idx
            page = Page(blocks=[Block(lines=[Line(words=words, order=0)], order=0)])

        return {
            "page": page,
            "score_map": score_map if return_maps else None,
            "geo_map": geo_map if return_maps else None,
        }

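    # Hedged sketch of inspecting the raw maps returned by predict() above
    # (matplotlib is an assumed, optional dependency; key names follow the
    # return contract documented in predict):
    #
    #     out = model.predict("page.jpg", return_maps=True)
    #     import matplotlib.pyplot as plt
    #     plt.imshow(out["score_map"])   # (H, W) text-confidence map
    #     plt.show()                     # out["geo_map"] is (8, H, W)
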
    @staticmethod
    def train(
        train_images: Union[str, Path, Sequence[Union[str, Path]]],
        train_anns: Union[str, Path, Sequence[Union[str, Path]]],
        val_images: Union[str, Path, Sequence[Union[str, Path]]],
        val_anns: Union[str, Path, Sequence[Union[str, Path]]],
        *,
        experiment_root: str = "./experiments",
        model_name: str = "resnet_quad",
        backbone_name: str = "resnet50",
        pretrained_backbone: bool = True,
        freeze_first: bool = True,
        target_size: int = 1024,
        score_geo_scale: Optional[float] = None,
        epochs: int = 500,
        batch_size: int = 3,
        accumulation_steps: int = 1,
        lr: float = 1e-3,
        grad_clip: float = 5.0,
        early_stop: int = 100,
        use_sam: bool = True,
        sam_type: str = "asam",
        use_lookahead: bool = True,
        use_ema: bool = False,
        use_multiscale: bool = True,
        use_ohem: bool = True,
        ohem_ratio: float = 0.5,
        use_focal_geo: bool = True,
        focal_gamma: float = 2.0,
        resume_from: Optional[Union[str, Path]] = None,
        val_interval: int = 1,
        num_workers: int = 0,
        device: Optional["torch.device"] = None,
    ) -> "torch.nn.Module":
        """
        Train EAST model on custom datasets.

        Parameters
        ----------
        train_images : str, Path or sequence of paths
            Path(s) to training image folders.
        train_anns : str, Path or sequence of paths
            Path(s) to COCO-format JSON annotation files corresponding to
            ``train_images``.
        val_images : str, Path or sequence of paths
            Path(s) to validation image folders.
        val_anns : str, Path or sequence of paths
            Path(s) to COCO-format JSON annotation files corresponding to
            ``val_images``.
        experiment_root : str, optional
            Base directory where experiment folders will be created.
            Default is ``"./experiments"``.
        model_name : str, optional
            Folder name inside ``experiment_root`` for logs and checkpoints.
            Default is ``"resnet_quad"``.
        backbone_name : {"resnet50", "resnet101"}, optional
            Backbone architecture to use. Options:

            - ``"resnet50"`` — ResNet-50 (faster, fewer parameters)
            - ``"resnet101"`` — ResNet-101 (slower, more capacity)

            Default is ``"resnet50"``.
        pretrained_backbone : bool, optional
            Use ImageNet-pretrained backbone weights. Default ``True``.
        freeze_first : bool, optional
            Freeze the lowest layers of the backbone. Default ``True``.
        target_size : int, optional
            Resize the shorter side of images to this size. Default ``1024``.
        score_geo_scale : float, optional
            Multiplier to recover original coordinates from score/geo maps.
            If None, automatically taken from the model. Default ``None``.
        epochs : int, optional
            Number of training epochs. Default ``500``.
        batch_size : int, optional
            Batch size per GPU. Default ``3``.
        accumulation_steps : int, optional
            Number of gradient accumulation steps. The effective batch size
            will be ``batch_size * accumulation_steps``. Use this to train
            with larger effective batch sizes when GPU memory is limited.
            For example:

            - ``batch_size=2, accumulation_steps=4`` → effective batch size = 8
            - ``batch_size=1, accumulation_steps=8`` → effective batch size = 8

            Default is ``1`` (no accumulation).
        lr : float, optional
            Learning rate. Default ``1e-3``.
        grad_clip : float, optional
            Gradient clipping value (L2 norm). Default ``5.0``.
        early_stop : int, optional
            Patience (epochs without improvement) for early stopping.
            Default ``100``.
        use_sam : bool, optional
            Enable SAM optimizer. Default ``True``.
        sam_type : {"sam", "asam"}, optional
            Variant of SAM to use. Default ``"asam"``.
        use_lookahead : bool, optional
            Wrap optimizer with Lookahead. Default ``True``.
        use_ema : bool, optional
            Maintain an EMA version of model weights. Default ``False``.
        use_multiscale : bool, optional
            Random multi-scale training. Default ``True``.
        use_ohem : bool, optional
            Online Hard Example Mining. Default ``True``.
        ohem_ratio : float, optional
            Ratio of hard negatives for OHEM. Default ``0.5``.
        use_focal_geo : bool, optional
            Apply focal loss to geometry channels. Default ``True``.
        focal_gamma : float, optional
            Gamma for focal geometry loss. Default ``2.0``.
        resume_from : str or Path, optional
            Resume training from a previous experiment: a) an experiment
            directory, b) ``.../checkpoints/``, or c) a direct path to
            ``last_state.pt``. Default ``None``.
        val_interval : int, optional
            Run validation every N epochs. Default ``1``.
        num_workers : int, optional
            Number of worker processes for data loading. Set to 0 for
            single-process loading (safer on Windows). Default ``0``.
        device : torch.device, optional
            CUDA or CPU device. Auto-selects if None.

        Returns
        -------
        torch.nn.Module
            Best model weights (EMA if enabled, otherwise base model).

        Examples
        --------
        Train on two datasets with validation:

        >>> from manuscript.detectors import EAST
        >>>
        >>> train_images = [
        ...     "/data/archive/train_images",
        ...     "/data/ddi/train_images"
        ... ]
        >>> train_anns = [
        ...     "/data/archive/train.json",
        ...     "/data/ddi/train.json"
        ... ]
        >>> val_images = [
        ...     "/data/archive/test_images",
        ...     "/data/ddi/test_images"
        ... ]
        >>> val_anns = [
        ...     "/data/archive/test.json",
        ...     "/data/ddi/test.json"
        ... ]
        >>>
        >>> best_model = EAST.train(
        ...     train_images=train_images,
        ...     train_anns=train_anns,
        ...     val_images=val_images,
        ...     val_anns=val_anns,
        ...     backbone_name="resnet50",
        ...     target_size=256,
        ...     epochs=20,
        ...     batch_size=4,
        ...     use_sam=False,
        ...     freeze_first=False,
        ...     val_interval=3,
        ... )
        >>> print("Best checkpoint loaded:", best_model)
        """
        if not _TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is required for training. "
                "Install with: pip install manuscript-ocr[dev]"
            )

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = EASTModel(
            backbone_name=backbone_name,
            pretrained_backbone=pretrained_backbone,
            freeze_first=freeze_first,
        ).to(device)

        if score_geo_scale is None:
            score_geo_scale = model.score_scale

        def make_dataset(imgs, anns, name: Optional[str] = None):
            return EASTDataset(
                images_folder=imgs,
                coco_annotation_file=anns,
                target_size=target_size,
                score_geo_scale=score_geo_scale,
                dataset_name=name,
            )

        def _dataset_base_name(
            img_path: Union[str, Path], ann_path: Union[str, Path]
        ) -> str:
            ann = Path(os.fspath(ann_path))
            parts: List[str] = []
            if ann.parent.name:
                parts.append(ann.parent.name)
            if ann.stem:
                parts.append(ann.stem)
            if not parts:
                img = Path(os.fspath(img_path))
                if img.parent.name:
                    parts.append(img.parent.name)
                stem = img.stem or img.name
                if stem:
                    parts.append(stem)
            return "/".join(parts)

        def _unique_dataset_name(
            img_path: Union[str, Path],
            ann_path: Union[str, Path],
            counts: Dict[str, int],
            idx: int,
            kind: str,
        ) -> str:
            base = _dataset_base_name(img_path, ann_path)
            if not base:
                base = f"{kind}_{idx}"
            count = counts.get(base, 0)
            counts[base] = count + 1
            if count == 0:
                return base
            return f"{base}_{count + 1}"

        train_imgs_list = (
            train_images if isinstance(train_images, (list, tuple)) else [train_images]
        )
        train_anns_list = (
            train_anns if isinstance(train_anns, (list, tuple)) else [train_anns]
        )
        val_imgs_list = (
            val_images if isinstance(val_images, (list, tuple)) else [val_images]
        )
        val_anns_list = val_anns if isinstance(val_anns, (list, tuple)) else [val_anns]

        assert len(train_imgs_list) == len(
            train_anns_list
        ), "train_images and train_anns must have the same length"
        assert len(val_imgs_list) == len(
            val_anns_list
        ), "val_images and val_anns must have the same length"

        train_datasets = []
        train_name_counts: Dict[str, int] = {}
        for idx, (imgs, anns) in enumerate(
            zip(train_imgs_list, train_anns_list), start=1
        ):
            dataset_name = _unique_dataset_name(
                imgs, anns, train_name_counts, idx=idx, kind="train"
            )
            train_datasets.append(make_dataset(imgs, anns, name=dataset_name))

        val_datasets = []
        val_name_counts: Dict[str, int] = {}
        for idx, (imgs, anns) in enumerate(zip(val_imgs_list, val_anns_list), start=1):
            dataset_name = _unique_dataset_name(
                imgs, anns, val_name_counts, idx=idx, kind="val"
            )
            val_datasets.append(make_dataset(imgs, anns, name=dataset_name))

        train_ds = ConcatDataset(train_datasets)
        val_ds = ConcatDataset(val_datasets)
        val_dataset_names = [ds.dataset_name for ds in val_datasets]

        def _resolve_path(path: Union[str, Path]) -> Path:
            p = Path(path)
            if p.is_absolute():
                return p
            project_root = Path(__file__).resolve().parents[4]
            candidate = (project_root / p).resolve()
            if candidate.exists():
                return candidate
            return (Path.cwd() / p).resolve()

        def _is_experiment_checkpoint(file_path: Path) -> bool:
            parent = file_path.parent
            return parent.name == "checkpoints" or (
                parent / "training_config.json"
            ).exists()

        def _resolve_resume_target(
            target: Union[str, Path],
            default_experiment_dir: str,
        ) -> Tuple[str, Optional[Path]]:
            resolved = _resolve_path(target)
            if not resolved.exists():
                raise FileNotFoundError(
                    f"resume_from target does not exist: {resolved}"
                )
            if resolved.is_file():
                resume_state = resolved
                if _is_experiment_checkpoint(resolved):
                    checkpoints_dir = resolved.parent
                    if checkpoints_dir.name == "checkpoints":
                        experiment_dir = checkpoints_dir.parent
                    else:
                        experiment_dir = checkpoints_dir
                    return os.path.abspath(os.fspath(experiment_dir)), resume_state
                else:
                    return default_experiment_dir, resume_state
            experiment_dir = resolved
            checkpoints_dir = (
                resolved if resolved.name == "checkpoints" else resolved / "checkpoints"
            )
            default_state = checkpoints_dir / "last_state.pt"
            resume_state = default_state if default_state.exists() else None
            return os.path.abspath(os.fspath(experiment_dir)), resume_state

        default_experiment_dir = os.path.abspath(
            os.path.join(experiment_root, model_name)
        )
        resume_state_path: Optional[Path] = None
        if resume_from is None:
            experiment_dir = default_experiment_dir
            resume_flag = False
        else:
            experiment_dir, resume_state_path = _resolve_resume_target(
                resume_from, default_experiment_dir
            )
            resume_flag = True

        best_model = _run_training(
            experiment_dir=experiment_dir,
            model=model,
            train_dataset=train_ds,
            val_dataset=val_ds,
            device=device,
            num_epochs=epochs,
            batch_size=batch_size,
            accumulation_steps=accumulation_steps,
            lr=lr,
            grad_clip=grad_clip,
            early_stop=early_stop,
            use_sam=use_sam,
            sam_type=sam_type,
            use_lookahead=use_lookahead,
            use_ema=use_ema,
            use_multiscale=use_multiscale,
            use_ohem=use_ohem,
            ohem_ratio=ohem_ratio,
            use_focal_geo=use_focal_geo,
            focal_gamma=focal_gamma,
            val_interval=val_interval,
            num_workers=num_workers,
            backbone_name=backbone_name,
            target_size=target_size,
            pretrained_backbone=pretrained_backbone,
            val_datasets=val_datasets,
            val_dataset_names=val_dataset_names,
            resume=resume_flag,
            resume_state_path=(
                os.fspath(resume_state_path) if resume_state_path else None
            ),
        )
        return best_model

    @staticmethod
    def export(
        weights_path: Union[str, Path],
        output_path: Union[str, Path],
        backbone_name: Optional[str] = None,
        input_size: int = 1280,
        opset_version: int = 14,
        simplify: bool = True,
    ) -> None:
        """
        Export EAST PyTorch model to ONNX format.

        This method converts a trained EAST model from PyTorch to ONNX
        format, which can be used for faster inference with ONNX Runtime.
        The exported model can be loaded with ``EAST("model.onnx")``.

        Parameters
        ----------
        weights_path : str or Path
            Path to the PyTorch model weights file (.pth).
        output_path : str or Path
            Path where the ONNX model will be saved (.onnx).
        backbone_name : {"resnet50", "resnet101"}, optional
            Backbone architecture of the model. If None, it is automatically
            detected from the checkpoint. Must match the architecture used
            during training. Default is None (auto-detect).
        input_size : int, optional
            Input image size (height and width). The model will accept
            images of shape ``(batch, 3, input_size, input_size)``.
            Default is 1280.
        opset_version : int, optional
            ONNX opset version to use for export. Default is 14.
        simplify : bool, optional
            If True, applies ONNX graph simplification using
            onnx-simplifier to optimize the model. Requires the
            ``onnx-simplifier`` package. Default is True.

        Returns
        -------
        None
            The ONNX model is saved to ``output_path``.

        Raises
        ------
        ImportError
            If required packages (torch, onnx) are not installed.
        FileNotFoundError
            If ``weights_path`` does not exist.
        ValueError
            If ``backbone_name`` does not match the checkpoint architecture.

        Notes
        -----
        The exported ONNX model has two outputs:

        - ``score_map``: Text confidence map with shape ``(batch, 1, H, W)``
        - ``geo_map``: Geometry map with shape ``(batch, 8, H, W)``

        The model supports dynamic batch size and image dimensions through
        the dynamic-axes configuration.

        **Automatic Backbone Detection:** The method automatically detects
        the backbone architecture from the checkpoint by counting the
        bottleneck blocks in ``layer3``. This prevents mismatches between
        checkpoint and architecture that could lead to incorrect exports.

        Examples
        --------
        Export with automatic backbone detection:

        >>> from manuscript.detectors import EAST
        >>> EAST.export(
        ...     weights_path="east_resnet50.pth",
        ...     output_path="east_model.onnx"
        ... )
        Auto-detected backbone: resnet50
        Exporting to ONNX (opset 14)...
        [OK] ONNX model saved to: east_model.onnx

        Export with explicit backbone:

        >>> EAST.export(
        ...     weights_path="custom_weights.pth",
        ...     output_path="custom_model.onnx",
        ...     backbone_name="resnet101",
        ...     input_size=1024,
        ...     simplify=False
        ... )

        Use the exported model for inference:

        >>> detector = EAST("east_model.onnx", device="cuda")
        >>> result = detector.predict("image.jpg")

        See Also
        --------
        EAST.__init__ : Initialize the EAST detector with ONNX weights.
        """
        if not _TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is required for exporting models. "
                "Install with: pip install manuscript-ocr[dev]"
            )

        class EASTWrapper(torch.nn.Module):
            def __init__(self, east_model):
                super().__init__()
                self.east = east_model

            def forward(self, x):
                output = self.east(x)
                return output["score"], output["geometry"]

        weights_path = Path(weights_path)
        if not weights_path.exists():
            raise FileNotFoundError(f"Weights file not found: {weights_path}")

        print(f"Loading checkpoint from {weights_path}...")
        checkpoint = torch.load(str(weights_path), map_location="cpu")

        # Auto-detect backbone if not specified
        if backbone_name is None:
            print("Auto-detecting backbone architecture...")
            # Detect by counting layer3 parameters:
            # ResNet50: layer3 has 6 bottleneck blocks (layer3.0 - layer3.5)
            # ResNet101: layer3 has 23 bottleneck blocks (layer3.0 - layer3.22)
            layer3_keys = [
                k for k in checkpoint.keys() if "backbone.extractor.layer3" in k
            ]
            # layer3.10 exists only in resnet101
            has_layer3_10 = any("layer3.10" in k for k in layer3_keys)
            detected_backbone = "resnet101" if has_layer3_10 else "resnet50"
            print(f"  Detected {len(layer3_keys)} layer3 parameters")
            print(f"  Auto-detected backbone: {detected_backbone}")
            backbone_name = detected_backbone
        else:
            print(f"Using specified backbone: {backbone_name}")
            # Verify the requested backbone matches the checkpoint
            print("Verifying backbone matches checkpoint...")
            layer3_keys = [
                k for k in checkpoint.keys() if "backbone.extractor.layer3" in k
            ]
            has_layer3_10 = any("layer3.10" in k for k in layer3_keys)
            expected_backbone = "resnet101" if has_layer3_10 else "resnet50"
            if expected_backbone != backbone_name:
                raise ValueError(
                    f"Backbone mismatch! Checkpoint is {expected_backbone}, "
                    f"but you specified {backbone_name}. "
                    f"Either use backbone_name='{expected_backbone}' or set "
                    f"backbone_name=None for auto-detection."
                )
            print(f"  [OK] Backbone matches checkpoint ({backbone_name})")

        print("\nLoading PyTorch model...")
        east_model = EASTModel(
            backbone_name=backbone_name,
            pretrained_backbone=False,
            pretrained_model_path=str(weights_path),
        )
        east_model.eval()

        model = EASTWrapper(east_model)
        model.eval()

        dummy_input = torch.randn(1, 3, input_size, input_size)
        print(f"Model architecture: {model.__class__.__name__}")
        print(f"Input shape: {dummy_input.shape}")

        with torch.no_grad():
            score_map, geo_map = model(dummy_input)
        print("Output shapes:")
        print(f"  - score_map: {score_map.shape}")
        print(f"  - geo_map: {geo_map.shape}")

        print(f"\nExporting to ONNX (opset {opset_version})...")
        torch.onnx.export(
            model,
            dummy_input,
            str(output_path),
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["score_map", "geo_map"],
            dynamic_axes={
                "input": {0: "batch_size", 2: "height", 3: "width"},
                "score_map": {0: "batch_size", 2: "height", 3: "width"},
                "geo_map": {0: "batch_size", 2: "height", 3: "width"},
            },
            verbose=False,
        )
        print(f"[OK] ONNX model saved to: {output_path}")

        import onnx

        print("\nVerifying ONNX model...")
        onnx_model = onnx.load(str(output_path))
        onnx.checker.check_model(onnx_model)
        print("[OK] ONNX model is valid")

        if simplify:
            # onnx-simplifier is only required when simplify=True
            import onnxsim

            print("\nSimplifying ONNX model...")
            model_simplified, check = onnxsim.simplify(onnx_model)
            if check:
                onnx.save(model_simplified, str(output_path))
                print("[OK] ONNX model simplified")
            else:
                print("[WARNING] Simplification failed, using original model")

        # Test ONNX inference and compare with PyTorch
        try:
            print("\nTesting ONNX inference...")
            session = ort.InferenceSession(
                str(output_path), providers=["CPUExecutionProvider"]
            )
            ort_inputs = {"input": dummy_input.numpy()}
            ort_outputs = session.run(None, ort_inputs)
            print("[OK] ONNX inference works!")
            print(f"  ONNX score_map shape: {ort_outputs[0].shape}")
            print(f"  ONNX geo_map shape: {ort_outputs[1].shape}")

            # Compare with PyTorch
            print("\nComparing ONNX vs PyTorch outputs...")
            torch_score = score_map.numpy()
            torch_geo = geo_map.numpy()
            onnx_score = ort_outputs[0]
            onnx_geo = ort_outputs[1]

            score_max_diff = abs(torch_score - onnx_score).max()
            geo_max_diff = abs(torch_geo - onnx_geo).max()
            print(f"  Max difference in score_map: {score_max_diff:.6f}")
            print(f"  Max difference in geo_map: {geo_max_diff:.6f}")

            if score_max_diff < 1e-4 and geo_max_diff < 1e-4:
                print("  [OK] Outputs match!")
            elif score_max_diff < 1e-3 and geo_max_diff < 1e-3:
                print("  [WARNING] Small differences detected (acceptable)")
            else:
                print("  [ERROR] Outputs differ significantly!")
                print("  This may indicate a problem with the export.")
        except Exception as e:
            print(f"[WARNING] ONNX inference test failed: {e}")

        file_size_mb = Path(output_path).stat().st_size / (1024 * 1024)
        print(f"\n[OK] Export complete! Model size: {file_size_mb:.1f} MB")
        print("\n=== Summary ===")
        print(f"Backbone: {backbone_name}")
        print(f"Input shape: [batch_size, 3, {input_size}, {input_size}]")
        print("Output shapes:")
        print("  - score_map: [batch_size, 1, H, W]")
        print("  - geo_map: [batch_size, 8, H, W]")