Распознаватели
Модели распознавания текста.
- class manuscript.recognizers.TRBA(weights=None, config=None, charset=None, device=None, force_download=False, rotate_threshold=1.5, region_preparer='bbox', region_preparer_options=None, min_text_size=5, **kwargs)[исходный код]
Базовые классы:
BaseRecognizerИнициализация модели распознавания текста TRBA с использованием ONNX Runtime.
- Параметры:
weights (str or Path, optional) – Путь или идентификатор весов ONNX-модели. Поддерживаются: - локальный путь к файлу:
"path/to/model.onnx"- HTTP/HTTPS URL:"https://example.com/model.onnx"- релиз GitHub:"github://owner/repo/tag/file.onnx"- Google Drive:"gdrive:FILE_ID"- имя пресета:"trba_lite_g1"или"trba_base_g1"(из pretrained_registry) -None: автоматическая загрузка пресета по умолчанию (trba_lite_g1)config (str or Path, optional) – Путь или идентификатор JSON-конфигурации модели. Поддерживает те же схемы, что и
weights. ЕслиNone, конфигурация определяется по расположению весов или используется конфигурация по умолчанию для пресетов.charset (str or Path, optional) – Путь или идентификатор файла набора символов. Если
None, набор символов ищется рядом с весами или используется вариант по умолчанию из пакета.device ({"cuda", "coreml", "cpu"}, optional) – Устройство вычислений. Если
None, автоматически выбирается CPU. Для ускорения GPU/CoreML: - CUDA (NVIDIA):pip install onnxruntime-gpu- CoreML (Apple Silicon M1/M2/M3):pip install onnxruntime-siliconПо умолчаниюNone(CPU).rotate_threshold (float or None, optional) – Aspect-ratio threshold for rotating vertical text-span crops before recognition. If
height > width * rotate_threshold, crop is rotated 90 degrees clockwise. Set to0orNoneto disable. Default is1.5.region_preparer ({"bbox", "polygon_mask", "quad_warp"} or callable, optional) – Strategy used to convert
Pagepolygons into recognition crops."bbox"extracts axis-aligned bounding boxes for arbitrary polygons."polygon_mask"masks pixels outside the polygon inside a tight crop and also supports arbitrary polygons."quad_warp"rectifies only 4-point polygons with a perspective transform before recognition. A custom callable may also be provided and should return a list of prepared text regions. Default is"bbox".region_preparer_options (dict or None, optional) – Optional configuration for built-in region preparers. Defaults to
None. Typical options arepadfor"bbox"and"polygon_mask", oroutput_size=(width, height)for"quad_warp". Non-quad polygons passed to"quad_warp"fall back to bbox crops by default.min_text_size (int, optional) – Minimum crop width/height in pixels to run recognition for a text span. Text spans below this threshold are skipped. Default is
5.**kwargs – Дополнительные параметры конфигурации (зарезервированы для будущего использования).
force_download (bool)
- Исключение:
FileNotFoundError – Если указанные файлы не существуют.
ValueError – Если формат весов некорректен.
Заметки
Класс предоставляет три основных публичных метода:
predict- run recognition over text spans in aPageobject.train- high-level training entrypoint to train a TRBA model on custom datasets.export- static method to export PyTorch model to ONNX format.
Модель использует ONNX Runtime для быстрого инференса на CPU и GPU. Для ускорения на GPU установите:
pip install onnxruntime-gpuПримеры
Создание распознавателя с пресетом по умолчанию (автозагрузка):
>>> from manuscript.recognizers import TRBA >>> recognizer = TRBA()
Загрузка из локального ONNX-файла:
>>> recognizer = TRBA(weights="path/to/model.onnx")
Загрузка из релиза GitHub:
>>> recognizer = TRBA( ... weights="github://owner/repo/v1.0/model.onnx", ... config="github://owner/repo/v1.0/config.json" ... )
Принудительное выполнение на CPU:
>>> recognizer = TRBA(weights="model.onnx", device="cpu")
Методы
__call__(*args, **kwargs)Call self as a function.
export(weights_path, config_path, ...[, ...])Export TRBA PyTorch model to ONNX format.
predict(page[, image, batch_size, ...])Recognize text for text spans in a
Pageand return updatedPage.runtime_providers()Get ONNX Runtime execution providers based on device.
train(train_csvs, train_roots[, val_csvs, ...])Train TRBA text recognition model on custom datasets.
- __init__(weights=None, config=None, charset=None, device=None, force_download=False, rotate_threshold=1.5, region_preparer='bbox', region_preparer_options=None, min_text_size=5, **kwargs)[исходный код]
- charset_registry = {'trba_base_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_base_g1.txt', 'trba_lite_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g1.txt', 'trba_lite_g2': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g2.txt'}
- config_registry = {'trba_base_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_base_g1.json', 'trba_lite_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g1.json', 'trba_lite_g2': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g2.json'}
- static export(weights_path, config_path, charset_path, output_path, opset_version=14, simplify=True)[исходный код]
Export TRBA PyTorch model to ONNX format.
Этот метод преобразует обученную модель TRBA из PyTorch в формат ONNX, который может использоваться для более быстрого инференса с ONNX Runtime. Экспортированная модель может быть загружена с помощью
TRBA(weights="model.onnx").- Параметры:
weights_path (str or Path) – Путь к файлу весов модели PyTorch (.pth).
config_path (str or Path) – Путь к JSON-файлу конфигурации модели. Используется для определения архитектуры модели (img_h, img_w, max_len, hidden_size и т. д.).
charset_path (str or Path) – Путь к файлу набора символов (charset.txt). Используется для определения количества классов модели.
output_path (str or Path) – Путь, по которому будет сохранена ONNX-модель (.onnx).
opset_version (int, optional) – Версия opset ONNX для экспорта. По умолчанию 14.
simplify (bool, optional) – Если True, применяется упрощение графа ONNX с помощью onnx-simplifier для оптимизации модели. Требуется пакет
onnx-simplifier. По умолчанию True.
- Результат:
ONNX-модель сохраняется в
output_path.- Тип результата:
None
- Исключение:
ImportError – Если необходимые пакеты (torch, onnx) не установлены.
FileNotFoundError – Если
weights_pathилиconfig_pathне существуют.
Заметки
Экспортированная ONNX-модель имеет один выход:
logits: предсказания символов с формой(batch, max_length+1, num_classes)
Модель использует жадное декодирование (argmax) и поддерживает динамический размер батча. Длина последовательности зафиксирована как
max_length + 1из конфигурации (аналогично режиму инференса PyTorch для совместимости).Экспортируемая архитектура: - CNN-бэкбон (SE-ResNet-31 или SE-ResNet-31-Lite) - двунаправленный LSTM-энкодер - декодер внимания (жадное декодирование)
Примечание: экспортируется только декодер внимания. CTC-голова используется только во время обучения и не включена в ONNX-модель.
Примеры
Экспорт модели TRBA в ONNX:
>>> from manuscript.recognizers import TRBA >>> TRBA.export( ... weights_path="experiments/best_model/best_acc_weights.pth", ... config_path="experiments/best_model/config.json", ... charset_path="configs/charset.txt", ... output_path="trba_model.onnx" ... ) Loading TRBA model... === TRBA ONNX Export === Max decoding length: 40 Input size: 64x256 [OK] ONNX model saved to: trba_model.onnx
Экспорт с пользовательским opset:
>>> TRBA.export( ... weights_path="model.pth", ... config_path="config.json", ... charset_path="charset.txt", ... output_path="model.onnx", ... opset_version=16, ... simplify=False ... )
Использование экспортированной модели для инференса:
>>> from manuscript.detectors import EAST >>> recognizer = TRBA(weights="trba_model.onnx") >>> detector = EAST() >>> det = detector.predict("page.jpg") >>> result = recognizer.predict(det["page"], image="page.jpg")
См. также
TRBA.__init__Инициализация распознавателя TRBA с ONNX-моделью.
- predict(page, image=None, batch_size=32, debug_save_dir=None)[исходный код]
Recognize text for text spans in a
Pageand return updatedPage.- Параметры:
page (Page) – Page object with detected text-span polygons.
image (str, Path, numpy.ndarray, or PIL.Image, optional) – Source page image used to extract text regions. If
None, recognition is skipped and a deep copy ofpageis returned.batch_size (int, optional) – Number of prepared text regions to process simultaneously.
debug_save_dir (str or Path, optional) – If provided, saves the prepared recognition crops to this directory as
*.pngfiles together withindex.json. Crops are saved afterregion_preparerand auto-rotation, i.e. in the same orientation that goes into recognizer inference.
- Результат:
New Page object with recognized
textandrecognition_confidencefilled for processed text spans.- Тип результата:
- pretrained_registry: Dict[str, str] = {'trba_base_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_base_g1.onnx', 'trba_lite_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g1.onnx', 'trba_lite_g2': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g2.onnx'}
- static train(train_csvs, train_roots, val_csvs=None, val_roots=None, *, exp_dir=None, charset_path=None, encoding='utf-8', img_h=64, img_w=256, max_len=25, hidden_size=256, num_encoder_layers=3, cnn_in_channels=3, cnn_out_channels=512, cnn_backbone='seresnet31', ctc_weight=0.3, ctc_weight_decay_epochs=50, ctc_weight_min=0.0, max_grad_norm=5.0, batch_size=32, epochs=20, lr=0.001, optimizer='AdamW', scheduler='OneCycleLR', weight_decay=0.0, momentum=0.9, val_interval=1, val_size=3000, train_proportions=None, num_workers=0, seed=42, resume_from=None, save_interval=None, device='cuda', freeze_cnn='none', freeze_enc_rnn='none', freeze_attention='none', pretrain_weights='default', **extra_config)[исходный код]
Train TRBA text recognition model on custom datasets.
- Параметры:
train_csvs (str, Path or sequence of paths) – Путь или пути к CSV-файлам обучения. Каждый CSV должен содержать столбцы:
image_path(относительноtrain_roots) иtext(эталонная транскрипция).train_roots (str, Path or sequence of paths) – Корневая директория или директории, содержащие обучающие изображения. Должны иметь ту же длину, что и
train_csvs.val_csvs (str, Path, sequence of paths, or None, optional) – Путь или пути к CSV-файлам валидации с тем же форматом, что и
train_csvs. ЕслиNone, валидация не выполняется. По умолчаниюNone.val_roots (str, Path, sequence of paths, or None, optional) – Корневая директория или директории для изображений валидации. Должны соответствовать длине
val_csvs, если они указаны. По умолчаниюNone.exp_dir (str or Path, optional) – Директория эксперимента, в которую будут сохраняться чекпойнты и логи. Если
None, автоматически создаётся на основе временной метки. По умолчаниюNone.charset_path (str or Path, optional) – Путь к файлу набора символов. Если
None, используется набор символов по умолчанию из пакета. По умолчаниюNone.encoding (str, optional) – Кодировка текста для чтения CSV-файлов. По умолчанию
"utf-8".img_h (int, optional) – Целевая высота входных изображений (в пикселях). По умолчанию 64.
img_w (int, optional) – Целевая ширина входных изображений (в пикселях). По умолчанию 256.
max_len (int, optional) – Максимальная длина последовательности для распознавания текста. По умолчанию 25.
hidden_size (int, optional) – Размер скрытого слоя RNN-энкодера/декодера. По умолчанию 256.
num_encoder_layers (int, optional) – Количество двунаправленных LSTM-слоёв в энкодере. По умолчанию 2.
cnn_in_channels (int, optional) – Количество входных каналов для CNN-бэкбона (3 для RGB, 1 для оттенков серого). По умолчанию 3.
cnn_out_channels (int, optional) – Количество выходных каналов CNN-бэкбона. По умолчанию 512.
cnn_backbone ({"seresnet31", "seresnet31-lite"}, optional) – Вариант CNN-бэкбона.
"seresnet31"использует стандартный SE-ResNet-31, а"seresnet31-lite"включает облегчённую depthwise-версию. По умолчанию"seresnet31".ctc_weight (float, optional) – Начальный вес CTC-функции потерь во время обучения (CTC всегда используется для стабильности):
loss = attn_loss * (1 - ctc_weight) + ctc_loss * ctc_weight. Вес CTC уменьшается по эпохам. По умолчанию 0.3.ctc_weight_decay_epochs (int, optional) – Количество эпох, за которые вес CTC уменьшается до минимума. По умолчанию 50.
ctc_weight_min (float, optional) – Минимальное значение веса CTC после уменьшения. По умолчанию 0.0.
max_grad_norm (float, optional) – Максимальная норма градиента для клиппинга (предотвращает взрыв градиентов/NaN). По умолчанию 5.0.
batch_size (int, optional) – Размер батча обучения. По умолчанию 32.
epochs (int, optional) – Количество эпох обучения. По умолчанию 20.
lr (float, optional) – Скорость обучения. По умолчанию 1e-3.
optimizer ({"Adam", "SGD", "AdamW"}, optional) – Тип оптимизатора. По умолчанию
"AdamW".scheduler ({"ReduceLROnPlateau", "CosineAnnealingLR", "OneCycleLR", "None"}, optional) –
Тип планировщика скорости обучения:
"OneCycleLR"- one-cycle policy with cosine annealing (default, recommended)"ReduceLROnPlateau"- reduce LR on validation loss plateau"CosineAnnealingLR"- cosine annealing over epochs"None"orNone- constant learning rate
По умолчанию
"OneCycleLR".weight_decay (float, optional) – Коэффициент L2-регуляризации (weight decay). По умолчанию 0.0.
momentum (float, optional) – Параметр momentum для оптимизатора SGD. По умолчанию 0.9.
val_interval (int, optional) – Выполнять валидацию каждые N эпох. По умолчанию 1.
val_size (int, optional) – Максимальное количество примеров для валидации. По умолчанию 3000.
train_proportions (sequence of float, optional) – Доли выборки для нескольких обучающих датасетов. Должны суммироваться до 1.0 и соответствовать длине
train_csvs. ЕслиNone, датасеты объединяются с равными весами. По умолчаниюNone.num_workers (int, optional) – Количество рабочих процессов для загрузки данных. По умолчанию 0.
seed (int, optional) – Случайное зерно (seed) для воспроизводимости результатов. По умолчанию 42.
resume_from (str or Path, optional) – Путь к файлу чекпойнта для возобновления обучения. По умолчанию
None.save_interval (int, optional) – Сохранять чекпойнт каждые N эпох. Если
None, сохраняется только лучшая модель. По умолчаниюNone.device ({"cuda", "cpu"}, optional) – Устройство для обучения. По умолчанию
"cuda".freeze_cnn ({"none", "all", "first", "last"}, optional) – Политика заморозки CNN. По умолчанию
"none".freeze_enc_rnn ({"none", "all", "first", "last"}, optional) – Политика заморозки RNN-энкодера. По умолчанию
"none".freeze_attention ({"none", "all"}, optional) – Политика заморозки модуля внимания (attention). По умолчанию
"none".pretrain_weights (str, Path, bool, or None, optional) –
Предобученные веса для инициализации:
"default"orTrue- use release weightsNoneorFalse- train from scratchstr/Path - path or URL to custom weights file
По умолчанию
"default".**extra_config (dict, optional) – Дополнительные параметры конфигурации, передаваемые в настройки обучения.
- Результат:
Путь к лучшему чекпойнту модели, сохранённому в ходе обучения.
- Тип результата:
Примеры
Обучение на одном датасете с валидацией:
>>> from manuscript.recognizers import TRBA >>> >>> best_model = TRBA.train( ... train_csvs="data/train.csv", ... train_roots="data/train_images", ... val_csvs="data/val.csv", ... val_roots="data/val_images", ... exp_dir="./experiments/trba_exp1", ... epochs=50, ... batch_size=64, ... img_h=64, ... img_w=256, ... ) >>> print(f"Best model saved at: {best_model}")
Обучение на нескольких датасетах с пользовательскими пропорциями:
>>> train_csvs = ["data/dataset1/train.csv", "data/dataset2/train.csv"] >>> train_roots = ["data/dataset1/images", "data/dataset2/images"] >>> train_proportions = [0.7, 0.3] # 70% from dataset1, 30% from dataset2 >>> >>> best_model = TRBA.train( ... train_csvs=train_csvs, ... train_roots=train_roots, ... train_proportions=train_proportions, ... val_csvs="data/val.csv", ... val_roots="data/val_images", ... epochs=100, ... lr=5e-4, ... optimizer="AdamW", ... weight_decay=1e-4, ... )
Возобновление обучения из чекпойнта:
>>> best_model = TRBA.train( ... train_csvs="data/train.csv", ... train_roots="data/train_images", ... resume_from="experiments/trba_exp1/checkpoints/last.pth", ... epochs=100, ... )
Дообучение с предобученных весов с замороженным CNN:
>>> best_model = TRBA.train( ... train_csvs="data/finetune.csv", ... train_roots="data/finetune_images", ... pretrain_weights="default", ... freeze_cnn="all", ... epochs=20, ... lr=1e-4, ... )
Обучение с использованием CTC для стабильности (всегда включено):
>>> best_model = TRBA.train( ... train_csvs="data/train.csv", ... train_roots="data/train_images", ... optimizer="AdamW", ... scheduler="OneCycleLR", ... lr=1e-3, ... ctc_weight=0.3, ... ctc_weight_decay_epochs=50, ... max_grad_norm=5.0, ... epochs=100, ... )