Распознаватели
Модели распознавания текста.
- class manuscript.recognizers.TRBA(weights=None, config=None, charset=None, device=None, **kwargs)[исходный код]
Базовые классы:
BaseModelИнициализация модели распознавания текста TRBA с использованием ONNX Runtime.
- Параметры:
weights (str or Path, optional) – Путь или идентификатор весов ONNX-модели. Поддерживаются: - локальный путь к файлу:
"path/to/model.onnx"- HTTP/HTTPS URL:"https://example.com/model.onnx"- релиз GitHub:"github://owner/repo/tag/file.onnx"- Google Drive:"gdrive:FILE_ID"- имя пресета:"trba_lite_g1"или"trba_base_g1"(из pretrained_registry) -None: автоматическая загрузка пресета по умолчанию (trba_lite_g1)config (str or Path, optional) – Путь или идентификатор JSON-конфигурации модели. Поддерживает те же схемы, что и
weights. ЕслиNone, конфигурация определяется по расположению весов или используется конфигурация по умолчанию для пресетов.charset (str or Path, optional) – Путь или идентификатор файла набора символов. Если
None, набор символов ищется рядом с весами или используется вариант по умолчанию из пакета.device ({"cuda", "coreml", "cpu"}, optional) – Устройство вычислений. Если
None, автоматически выбирается CPU. Для ускорения GPU/CoreML: - CUDA (NVIDIA):pip install onnxruntime-gpu- CoreML (Apple Silicon M1/M2/M3):pip install onnxruntime-siliconПо умолчаниюNone(CPU).**kwargs – Дополнительные параметры конфигурации (зарезервированы для будущего использования).
- Исключение:
FileNotFoundError – Если указанные файлы не существуют.
ValueError – Если формат весов некорректен.
Заметки
Класс предоставляет три основных публичных метода:
predict— запуск распознавания текста на обрезанных изображениях слов.train— высокоуровневая точка входа для обучения модели TRBA на пользовательских датасетах.export— статический метод экспорта модели PyTorch в формат ONNX.
Модель использует ONNX Runtime для быстрого инференса на CPU и GPU. Для ускорения на GPU установите:
pip install onnxruntime-gpuПримеры
Создание распознавателя с пресетом по умолчанию (автозагрузка):
>>> from manuscript.recognizers import TRBA >>> recognizer = TRBA()
Загрузка из локального ONNX-файла:
>>> recognizer = TRBA(weights="path/to/model.onnx")
Загрузка из релиза GitHub:
>>> recognizer = TRBA( ... weights="github://owner/repo/v1.0/model.onnx", ... config="github://owner/repo/v1.0/config.json" ... )
Принудительное выполнение на CPU:
>>> recognizer = TRBA(weights="model.onnx", device="cpu")
Методы
__call__(*args, **kwargs)Call self as a function.
export(weights_path, config_path, ...[, ...])Export TRBA PyTorch model to ONNX format.
predict(images[, batch_size])Запуск распознавания текста на одном или нескольких изображениях слов.
runtime_providers()Get ONNX Runtime execution providers based on device.
train(train_csvs, train_roots[, val_csvs, ...])Train TRBA text recognition model on custom datasets.
-
pretrained_registry:
Dict[str,str] = {'trba_base_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_base_g1.onnx', 'trba_lite_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g1.onnx'}
- config_registry = {'trba_base_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_base_g1.json', 'trba_lite_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g1.json'}
- charset_registry = {'trba_base_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_base_g1.txt', 'trba_lite_g1': 'github://konstantinkozhin/manuscript-ocr/v0.1.0/trba_lite_g1.txt'}
- __init__(weights=None, config=None, charset=None, device=None, **kwargs)[исходный код]
- predict(images, batch_size=32)[исходный код]
Запуск распознавания текста на одном или нескольких изображениях слов.
- Параметры:
images (str, Path, numpy.ndarray, PIL.Image, or list thereof) –
Single image or list of images to recognize. Each image can be:
Path to image file (str or Path)
RGB numpy array with shape
(H, W, 3)inuint8PIL Image object
batch_size (int, optional) – Количество изображений, обрабатываемых одновременно. Большие батчи работают быстрее, но требуют больше памяти. По умолчанию 32.
- Результат:
Результаты распознавания в виде списка словарей, каждый из которых содержит: -
"text": str — распознанный текст -"confidence": float — уверенность распознавания в диапазоне [0, 1]. Если входом является одно изображение, возвращается список из одного элемента.- Тип результата:
Примеры
Распознавание одного изображения:
>>> from manuscript.recognizers import TRBA >>> recognizer = TRBA() >>> results = recognizer.predict("word_image.jpg") >>> print(f"Text: '{results[0]['text']}' (confidence: {results[0]['confidence']:.3f})")
Обработка массивов numpy:
>>> import cv2 >>> img = cv2.imread("word.jpg") >>> img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) >>> results = recognizer.predict(img_rgb) >>> print(results[0]["text"])
- static train(train_csvs, train_roots, val_csvs=None, val_roots=None, *, exp_dir=None, charset_path=None, encoding='utf-8', img_h=64, img_w=256, max_len=25, hidden_size=256, num_encoder_layers=3, cnn_in_channels=3, cnn_out_channels=512, cnn_backbone='seresnet31', ctc_weight=0.3, ctc_weight_decay_epochs=50, ctc_weight_min=0.0, max_grad_norm=5.0, batch_size=32, epochs=20, lr=0.001, optimizer='AdamW', scheduler='OneCycleLR', weight_decay=0.0, momentum=0.9, val_interval=1, val_size=3000, train_proportions=None, num_workers=0, seed=42, resume_from=None, save_interval=None, device='cuda', freeze_cnn='none', freeze_enc_rnn='none', freeze_attention='none', pretrain_weights='default', **extra_config)[исходный код]
Train TRBA text recognition model on custom datasets.
- Параметры:
train_csvs (str, Path or sequence of paths) – Путь или пути к CSV-файлам обучения. Каждый CSV должен содержать столбцы:
image_path(относительноtrain_roots) иtext(эталонная транскрипция).train_roots (str, Path or sequence of paths) – Корневая директория или директории, содержащие обучающие изображения. Должны иметь ту же длину, что и
train_csvs.val_csvs (str, Path, sequence of paths, or None, optional) – Путь или пути к CSV-файлам валидации с тем же форматом, что и
train_csvs. ЕслиNone, валидация не выполняется. По умолчаниюNone.val_roots (str, Path, sequence of paths, or None, optional) – Корневая директория или директории для изображений валидации. Должны соответствовать длине
val_csvs, если они указаны. По умолчаниюNone.exp_dir (str or Path, optional) – Директория эксперимента, в которую будут сохраняться чекпойнты и логи. Если
None, автоматически создаётся на основе временной метки. По умолчаниюNone.charset_path (str or Path, optional) – Путь к файлу набора символов. Если
None, используется набор символов по умолчанию из пакета. По умолчаниюNone.encoding (str, optional) – Кодировка текста для чтения CSV-файлов. По умолчанию
"utf-8".img_h (int, optional) – Целевая высота входных изображений (в пикселях). По умолчанию 64.
img_w (int, optional) – Целевая ширина входных изображений (в пикселях). По умолчанию 256.
max_len (int, optional) – Максимальная длина последовательности для распознавания текста. По умолчанию 25.
hidden_size (int, optional) – Размер скрытого слоя RNN-энкодера/декодера. По умолчанию 256.
num_encoder_layers (int, optional) – Количество двунаправленных LSTM-слоёв в энкодере. По умолчанию 2.
cnn_in_channels (int, optional) – Количество входных каналов для CNN-бэкбона (3 для RGB, 1 для оттенков серого). По умолчанию 3.
cnn_out_channels (int, optional) – Количество выходных каналов CNN-бэкбона. По умолчанию 512.
cnn_backbone ({"seresnet31", "seresnet31-lite"}, optional) – Вариант CNN-бэкбона.
"seresnet31"использует стандартный SE-ResNet-31, а"seresnet31-lite"включает облегчённую depthwise-версию. По умолчанию"seresnet31".ctc_weight (float, optional) – Начальный вес CTC-функции потерь во время обучения (CTC всегда используется для стабильности):
loss = attn_loss * (1 - ctc_weight) + ctc_loss * ctc_weight. Вес CTC уменьшается по эпохам. По умолчанию 0.3.ctc_weight_decay_epochs (int, optional) – Количество эпох, за которые вес CTC уменьшается до минимума. По умолчанию 50.
ctc_weight_min (float, optional) – Минимальное значение веса CTC после уменьшения. По умолчанию 0.0.
max_grad_norm (float, optional) – Максимальная норма градиента для клиппинга (предотвращает взрыв градиентов/NaN). По умолчанию 5.0.
batch_size (int, optional) – Размер батча обучения. По умолчанию 32.
epochs (int, optional) – Количество эпох обучения. По умолчанию 20.
lr (float, optional) – Скорость обучения. По умолчанию 1e-3.
optimizer ({"Adam", "SGD", "AdamW"}, optional) – Тип оптимизатора. По умолчанию
"AdamW".scheduler ({"ReduceLROnPlateau", "CosineAnnealingLR", "OneCycleLR", "None"}, optional) – Тип планировщика скорости обучения: -
"OneCycleLR"— one-cycle политика с косинусным затуханием (по умолчанию, рекомендуется) -"ReduceLROnPlateau"— уменьшение LR при плато валидационной ошибки -"CosineAnnealingLR"— косинусное затухание по эпохам -"None"илиNone— постоянная скорость обучения По умолчанию"OneCycleLR".weight_decay (float, optional) – Коэффициент L2-регуляризации (weight decay). По умолчанию 0.0.
momentum (float, optional) – Параметр momentum для оптимизатора SGD. По умолчанию 0.9.
val_interval (int, optional) – Выполнять валидацию каждые N эпох. По умолчанию 1.
val_size (int, optional) – Максимальное количество примеров для валидации. По умолчанию 3000.
train_proportions (sequence of float, optional) – Доли выборки для нескольких обучающих датасетов. Должны суммироваться до 1.0 и соответствовать длине
train_csvs. ЕслиNone, датасеты объединяются с равными весами. По умолчаниюNone.num_workers (int, optional) – Количество рабочих процессов для загрузки данных. По умолчанию 0.
seed (int, optional) – Случайное зерно (seed) для воспроизводимости результатов. По умолчанию 42.
resume_from (str or Path, optional) – Путь к файлу чекпойнта для возобновления обучения. По умолчанию
None.save_interval (int, optional) – Сохранять чекпойнт каждые N эпох. Если
None, сохраняется только лучшая модель. По умолчаниюNone.device ({"cuda", "cpu"}, optional) – Устройство для обучения. По умолчанию
"cuda".freeze_cnn ({"none", "all", "first", "last"}, optional) – Политика заморозки CNN. По умолчанию
"none".freeze_enc_rnn ({"none", "all", "first", "last"}, optional) – Политика заморозки RNN-энкодера. По умолчанию
"none".freeze_attention ({"none", "all"}, optional) – Политика заморозки модуля внимания (attention). По умолчанию
"none".pretrain_weights (str, Path, bool, or None, optional) – Предобученные веса для инициализации: -
"default"илиTrue— использовать релизные веса -NoneилиFalse— обучение с нуля - str/Path — путь или URL к пользовательскому файлу весов По умолчанию"default".**extra_config (dict, optional) – Дополнительные параметры конфигурации, передаваемые в настройки обучения.
- Результат:
Путь к лучшему чекпойнту модели, сохранённому в ходе обучения.
- Тип результата:
Примеры
Обучение на одном датасете с валидацией:
>>> from manuscript.recognizers import TRBA >>> >>> best_model = TRBA.train( ... train_csvs="data/train.csv", ... train_roots="data/train_images", ... val_csvs="data/val.csv", ... val_roots="data/val_images", ... exp_dir="./experiments/trba_exp1", ... epochs=50, ... batch_size=64, ... img_h=64, ... img_w=256, ... ) >>> print(f"Best model saved at: {best_model}")
Обучение на нескольких датасетах с пользовательскими пропорциями:
>>> train_csvs = ["data/dataset1/train.csv", "data/dataset2/train.csv"] >>> train_roots = ["data/dataset1/images", "data/dataset2/images"] >>> train_proportions = [0.7, 0.3] # 70% from dataset1, 30% from dataset2 >>> >>> best_model = TRBA.train( ... train_csvs=train_csvs, ... train_roots=train_roots, ... train_proportions=train_proportions, ... val_csvs="data/val.csv", ... val_roots="data/val_images", ... epochs=100, ... lr=5e-4, ... optimizer="AdamW", ... weight_decay=1e-4, ... )
Возобновление обучения из чекпойнта:
>>> best_model = TRBA.train( ... train_csvs="data/train.csv", ... train_roots="data/train_images", ... resume_from="experiments/trba_exp1/checkpoints/last.pth", ... epochs=100, ... )
Дообучение с предобученных весов с замороженным CNN:
>>> best_model = TRBA.train( ... train_csvs="data/finetune.csv", ... train_roots="data/finetune_images", ... pretrain_weights="default", ... freeze_cnn="all", ... epochs=20, ... lr=1e-4, ... )
Обучение с использованием CTC для стабильности (всегда включено):
>>> best_model = TRBA.train( ... train_csvs="data/train.csv", ... train_roots="data/train_images", ... optimizer="AdamW", ... scheduler="OneCycleLR", ... lr=1e-3, ... ctc_weight=0.3, ... ctc_weight_decay_epochs=50, ... max_grad_norm=5.0, ... epochs=100, ... )
- static export(weights_path, config_path, charset_path, output_path, opset_version=14, simplify=True)[исходный код]
Export TRBA PyTorch model to ONNX format.
Этот метод преобразует обученную модель TRBA из PyTorch в формат ONNX, который может использоваться для более быстрого инференса с ONNX Runtime. Экспортированная модель может быть загружена с помощью
TRBA(weights="model.onnx").- Параметры:
weights_path (str or Path) – Путь к файлу весов модели PyTorch (.pth).
config_path (str or Path) – Путь к JSON-файлу конфигурации модели. Используется для определения архитектуры модели (img_h, img_w, max_len, hidden_size и т. д.).
charset_path (str or Path) – Путь к файлу набора символов (charset.txt). Используется для определения количества классов модели.
output_path (str or Path) – Путь, по которому будет сохранена ONNX-модель (.onnx).
opset_version (int, optional) – Версия opset ONNX для экспорта. По умолчанию 14.
simplify (bool, optional) – Если True, применяется упрощение графа ONNX с помощью onnx-simplifier для оптимизации модели. Требуется пакет
onnx-simplifier. По умолчанию True.
- Результат:
ONNX-модель сохраняется в
output_path.- Тип результата:
None
- Исключение:
ImportError – Если необходимые пакеты (torch, onnx) не установлены.
FileNotFoundError – Если
weights_pathилиconfig_pathне существуют.
Заметки
Экспортированная ONNX-модель имеет один выход:
logits: предсказания символов с формой(batch, max_length+1, num_classes)
Модель использует жадное декодирование (argmax) и поддерживает динамический размер батча. Длина последовательности зафиксирована как
max_length + 1из конфигурации (аналогично режиму инференса PyTorch для совместимости).Экспортируемая архитектура: - CNN-бэкбон (SE-ResNet-31 или SE-ResNet-31-Lite) - двунаправленный LSTM-энкодер - декодер внимания (жадное декодирование)
Примечание: экспортируется только декодер внимания. CTC-голова используется только во время обучения и не включена в ONNX-модель.
Примеры
Экспорт модели TRBA в ONNX:
>>> from manuscript.recognizers import TRBA >>> TRBA.export( ... weights_path="experiments/best_model/best_acc_weights.pth", ... config_path="experiments/best_model/config.json", ... charset_path="configs/charset.txt", ... output_path="trba_model.onnx" ... ) Loading TRBA model... === TRBA ONNX Export === Max decoding length: 40 Input size: 64x256 [OK] ONNX model saved to: trba_model.onnx
Экспорт с пользовательским opset:
>>> TRBA.export( ... weights_path="model.pth", ... config_path="config.json", ... charset_path="charset.txt", ... output_path="model.onnx", ... opset_version=16, ... simplify=False ... )
Использование экспортированной модели для инференса:
>>> recognizer = TRBA(weights="trba_model.onnx") >>> result = recognizer.predict("word_image.jpg")
См. также
TRBA.__init__Инициализация распознавателя TRBA с ONNX-моделью.