import json
import os
import re
from pathlib import Path
from typing import Optional, Union
import numpy as np
import onnxruntime as ort
from manuscript.api.base import BaseModel
from manuscript.data import Page
[docs]
class CharLM(BaseModel):
"""
Character-level language model corrector using ONNX Runtime.
CharLM uses a Transformer-based masked language model to correct
OCR errors at the character level. It analyzes character confidence
and applies corrections based on learned substitution patterns.
Parameters
----------
weights : str or Path, optional
Path or identifier for ONNX model weights. Supports:
- Local file path: ``"path/to/model.onnx"``
- HTTP/HTTPS URL: ``"https://example.com/model.onnx"``
- GitHub release: ``"github://owner/repo/tag/file.onnx"``
- Google Drive: ``"gdrive:FILE_ID"``
- Preset name: ``"prereform_charlm_g1"`` or ``"modern_charlm_g1"`` (from pretrained_registry)
- ``None``: auto-downloads default preset (prereform_charlm_g1)
vocab : str or Path, optional
Path to vocabulary JSON file. If None, inferred from weights location.
lexicon : str, Path, or set, optional
Word list for dictionary-based validation. Supports:
- Local file path: ``"path/to/words.txt"``
- Preset name: ``"prereform_words"`` or ``"modern_words"`` (from lexicon_registry)
- Python set: ``{"word1", "word2", ...}``
- ``None``: auto-downloads default lexicon for model preset
(prereform_words for prereform_charlm_g1, modern_words for modern_charlm_g1)
device : {"cuda", "cpu"}, optional
Compute device. Default is auto-detected.
mask_threshold : float, optional
Confidence threshold below which characters are considered for correction.
Default is 0.05.
apply_threshold : float, optional
Minimum model confidence required to apply a correction. Default is 0.95.
max_edits : int, optional
Maximum number of edits per word. Default is 2.
min_word_len : int, optional
Minimum word length to attempt correction. Default is 4.
**kwargs
Additional configuration options.
Examples
--------
>>> from manuscript.correctors import CharLM
>>> corrector = CharLM()
>>> corrected_page = corrector.predict(page)
"""
default_weights_name = "prereform_charlm_g1"
pretrained_registry = {
"prereform_charlm_g1": "https://github.com/konstantinkozhin/manuscript-ocr/releases/download/v0.1.0/prereform_charlm_g1.onnx",
"modern_charlm_g1": "https://github.com/konstantinkozhin/manuscript-ocr/releases/download/v0.1.0/modern_charlm_g1.onnx",
}
vocab_registry = {
"prereform_charlm_g1": "https://github.com/konstantinkozhin/manuscript-ocr/releases/download/v0.1.0/prereform_charlm_g1.json",
"modern_charlm_g1": "https://github.com/konstantinkozhin/manuscript-ocr/releases/download/v0.1.0/modern_charlm_g1.json",
}
lexicon_registry = {
"prereform_words": "https://github.com/konstantinkozhin/manuscript-ocr/releases/download/v0.1.0/prereform_words.txt",
"modern_words": "https://github.com/konstantinkozhin/manuscript-ocr/releases/download/v0.1.0/modern_words.txt",
}
default_lexicon_for_model = {
"prereform_charlm_g1": "prereform_words",
"modern_charlm_g1": "modern_words",
}
[docs]
def __init__(
self,
weights: Optional[Union[str, Path]] = None,
vocab: Optional[Union[str, Path]] = None,
lexicon: Optional[Union[str, Path, set]] = None,
device: Optional[str] = None,
mask_threshold: float = 0.05,
apply_threshold: float = 0.95,
max_edits: int = 2,
min_word_len: int = 4,
max_len: int = 32,
**kwargs,
):
if weights is None and self.default_weights_name is not None:
weights = self.default_weights_name
# Remember original weights name for lexicon resolution
self._weights_preset = weights if weights in self.pretrained_registry else None
if weights is None:
self.device = device or "cpu"
self.weights = None
self.extra_config = kwargs
self.session = None
else:
super().__init__(weights=weights, device=device, **kwargs)
self.mask_threshold = mask_threshold
self.apply_threshold = apply_threshold
self.max_edits = max_edits
self.min_word_len = min_word_len
self.max_len = max_len
self._word_pattern = re.compile(r"(\w+)|(\W+)", re.UNICODE)
self.vocab_path = self._resolve_vocab(vocab) if self.weights else None
self.c2i = {}
self.i2c = {}
self.lexicon = None
if lexicon is not None:
if isinstance(lexicon, set):
self.lexicon = frozenset(w.lower() for w in lexicon)
else:
lexicon_path = self._resolve_lexicon(lexicon)
if lexicon_path:
self._load_lexicon(lexicon_path)
elif (
self._weights_preset
and self._weights_preset in self.default_lexicon_for_model
):
# Auto-load default lexicon for known model presets
default_lexicon_name = self.default_lexicon_for_model[self._weights_preset]
lexicon_path = self._resolve_lexicon(default_lexicon_name)
if lexicon_path:
self._load_lexicon(lexicon_path)
self.onnx_session = None
if self.weights and self.vocab_path:
self._load_vocab()
def _resolve_vocab(self, vocab: Optional[str]) -> Optional[str]:
"""Resolve vocab path, inferring from weights location if needed."""
if vocab is not None:
if Path(vocab).exists():
return str(Path(vocab).absolute())
if vocab in self.vocab_registry:
return self._resolve_extra_artifact(
vocab,
default_name=None,
registry=self.vocab_registry,
description="vocab",
)
# Use actual weights preset (if any), otherwise fall back to default
preset_to_use = self._weights_preset or self.default_weights_name
if preset_to_use and preset_to_use in self.vocab_registry:
return self._resolve_extra_artifact(
preset_to_use,
default_name=None,
registry=self.vocab_registry,
description="vocab",
)
if self.weights:
weights_path = Path(self.weights)
vocab_candidate = weights_path.parent / "vocab.json"
if vocab_candidate.exists():
return str(vocab_candidate.absolute())
return None
def _resolve_lexicon(self, lexicon: str) -> Optional[str]:
"""Resolve lexicon path from registry or local file."""
if Path(lexicon).exists():
return str(Path(lexicon).absolute())
if lexicon in self.lexicon_registry:
return self._resolve_extra_artifact(
lexicon,
default_name=None,
registry=self.lexicon_registry,
description="lexicon",
)
return None
def _load_vocab(self):
"""Load vocabulary from JSON file."""
if not self.vocab_path or not Path(self.vocab_path).exists():
return
with open(self.vocab_path, "r", encoding="utf-8") as f:
chars = json.load(f)
self.c2i = {c: i for i, c in enumerate(chars)}
self.i2c = {i: c for c, i in self.c2i.items()}
def _load_lexicon(self, lexicon_path: str):
"""Load lexicon (word list) from text file."""
if not Path(lexicon_path).exists():
return
with open(lexicon_path, "r", encoding="utf-8") as f:
words = set(line.strip().lower() for line in f if line.strip())
self.lexicon = frozenset(words)
def _initialize_session(self):
"""Initialize ONNX Runtime session (lazy loading)."""
if self.onnx_session is not None:
return
if self.weights is None:
raise ValueError("No weights provided for CharLM corrector")
providers = self.runtime_providers()
self.onnx_session = ort.InferenceSession(str(self.weights), providers=providers)
self._log_device_info(self.onnx_session)
[docs]
def predict(self, page: Page) -> Page:
"""
Apply character-level correction to a Page.
Parameters
----------
page : Page
Input Page object with recognized text.
Returns
-------
Page
Corrected Page object with updated word texts.
"""
if self.weights is None or not self.c2i:
return page.model_copy(deep=True)
if self.onnx_session is None:
self._initialize_session()
result = page.model_copy(deep=True)
for block in result.blocks:
for line in block.lines:
for word in line.words:
if word.text:
corrected = self._correct_word(word.text)
word.text = corrected
return result
def _correct_word(self, text: str) -> str:
"""Correct a single word using the CharLM model."""
tokens = []
for m in self._word_pattern.finditer(text):
word_part, other_part = m.groups()
if word_part:
tokens.append((word_part, True))
else:
tokens.append((other_part, False))
result_parts = []
for token, is_word in tokens:
if not is_word:
result_parts.append(token)
continue
word_lower = token.lower()
if len(word_lower) < self.min_word_len:
result_parts.append(token)
continue
if self.lexicon and word_lower in self.lexicon:
result_parts.append(token)
continue
corrected = self._correct_single_word(word_lower)
if corrected != word_lower:
corrected = self._restore_case(token, corrected)
else:
corrected = token
result_parts.append(corrected)
return "".join(result_parts)
def _correct_single_word(self, word: str) -> str:
"""Apply MLM-based correction to a single lowercase word."""
chars = list(word[: self.max_len])
L = len(chars)
if L == 0:
return word
# Use 0 as fallback for unknown tokens (safer than potentially out-of-bounds unk)
unk = self.c2i.get("<UNK>", 0)
mask = self.c2i.get("<MASK>", 1)
pad = self.c2i.get("<PAD>", 0)
batch = []
for i in range(L):
ids = [
(self.c2i.get(ch, unk) if j != i else mask)
for j, ch in enumerate(chars)
]
ids += [pad] * (self.max_len - len(ids))
batch.append(ids)
batch_array = np.array(batch, dtype=np.int64)
input_name = self.onnx_session.get_inputs()[0].name
output_name = self.onnx_session.get_outputs()[0].name
try:
logits = self.onnx_session.run([output_name], {input_name: batch_array})[0]
except Exception as e:
# If ONNX inference fails (e.g., vocab mismatch), log warning and return original word
import warnings
warnings.warn(
f"Corrector inference error for word '{word}': {e}. "
"Returning original text. This may indicate a vocab/weights mismatch.",
RuntimeWarning,
)
return word
probs = self._softmax(logits, axis=-1)
vocab_size = probs.shape[-1]
# Track which positions have unknown characters (should not be corrected)
unknown_positions = set()
for i, ch in enumerate(chars):
if ch not in self.c2i or self.c2i[ch] >= vocab_size:
unknown_positions.add(i)
confidences = []
for i in range(L):
# Skip unknown characters
if i in unknown_positions:
continue
char_id = self.c2i[chars[i]]
p_cur = probs[i, i, char_id]
prob_vec = probs[i, i]
confidences.append((i, p_cur, prob_vec))
candidates = sorted(
[(i, p, v) for i, p, v in confidences if p < self.mask_threshold],
key=lambda x: x[1],
)
edits = 0
for i, p_cur, prob_vec in candidates:
if edits >= self.max_edits:
break
best_id = int(np.argmax(prob_vec))
best_p = float(prob_vec[best_id])
best_char = self.i2c.get(best_id, "<UNK>")
if best_char in ("<UNK>", "<PAD>", "<MASK>"):
applied = False
elif best_char == chars[i]:
applied = False
elif best_p < self.apply_threshold:
applied = False
elif best_char.lower() != best_char or chars[i].lower() != chars[i]:
if best_char.lower() == chars[i].lower():
applied = False
else:
test_chars = chars.copy()
test_chars[i] = best_char
test_word = "".join(test_chars)
if self.lexicon and test_word in self.lexicon:
applied = True
elif self.lexicon:
applied = False
else:
applied = True
else:
test_chars = chars.copy()
test_chars[i] = best_char
test_word = "".join(test_chars)
if self.lexicon and test_word in self.lexicon:
applied = True
elif self.lexicon:
applied = False
else:
applied = True
if applied:
chars[i] = best_char
edits += 1
return "".join(chars)
def _restore_case(self, original: str, corrected: str) -> str:
"""Restore original case pattern to corrected word."""
result = []
for i, ch in enumerate(corrected):
if i < len(original) and original[i].isupper():
result.append(ch.upper())
else:
result.append(ch)
return "".join(result)
@staticmethod
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
"""Compute softmax along axis."""
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
[docs]
@staticmethod
def train(
words_path: Optional[Union[str, Path]] = None,
text_path: Optional[Union[str, Path]] = None,
pairs_path: Optional[Union[str, Path]] = None,
charset_path: Optional[Union[str, Path]] = None,
*,
exp_dir: str = "exp_charlm",
max_words: int = 1_500_000,
max_pairs_edits: int = 3,
max_len: int = 32,
emb_size: int = 192,
n_layers: int = 8,
n_heads: int = 6,
ffn_size: int = 1024,
dropout: float = 0.1,
batch_size: int = 256,
accumulation_steps: int = 2,
use_amp: bool = True,
compile_model: bool = False,
epochs: int = 50,
lr: float = 1e-3,
weight_decay: float = 0.01,
grad_clip: float = 1.0,
min_len: int = 3,
mask_prob: float = 0.3,
span_min: int = 1,
span_max: int = 3,
spans_min: int = 1,
spans_max: int = 2,
pairs_ratio: float = 0.8,
eval_ratio: float = 0.01,
seed: int = 42,
checkpoint: Optional[str] = None,
**extra_config,
) -> str:
"""
Train CharLM character-level language model.
Parameters
----------
words_path : str or Path, optional
Path to words file (one word per line).
text_path : str or Path, optional
Path to text file for n-gram dataset.
pairs_path : str or Path, optional
Path to CSV file with incorrect/correct pairs.
charset_path : str or Path
Path to charset file (allowed characters).
exp_dir : str, optional
Experiment directory. Default is "exp_charlm".
max_words : int, optional
Maximum words to use from words file. Default is 1_500_000.
max_pairs_edits : int, optional
Maximum number of character edits in pairs to include. Default is 3.
max_len : int, optional
Maximum sequence length. Default is 32.
emb_size : int, optional
Embedding size. Default is 192.
n_layers : int, optional
Number of transformer layers. Default is 8.
n_heads : int, optional
Number of attention heads. Default is 6.
ffn_size : int, optional
Feed-forward network size. Default is 1024.
dropout : float, optional
Dropout rate. Default is 0.1.
batch_size : int, optional
Batch size. Default is 256.
accumulation_steps : int, optional
Gradient accumulation steps. Default is 2.
use_amp : bool, optional
Use automatic mixed precision (AMP). Default is True.
compile_model : bool, optional
Use torch.compile for faster training. Default is False.
epochs : int, optional
Number of epochs. Default is 50.
lr : float, optional
Learning rate. Default is 1e-3.
weight_decay : float, optional
Weight decay. Default is 0.01.
grad_clip : float, optional
Gradient clipping. Default is 1.0.
min_len : int, optional
Minimum word length. Default is 3.
mask_prob : float, optional
Probability of using span masking. Default is 0.3.
span_min : int, optional
Minimum span length for masking. Default is 1.
span_max : int, optional
Maximum span length for masking. Default is 3.
spans_min : int, optional
Minimum number of spans. Default is 1.
spans_max : int, optional
Maximum number of spans. Default is 2.
pairs_ratio : float, optional
Ratio of OCR pairs in mixed dataset (0.8 = 80% pairs, 20% ngrams). Default is 0.8.
eval_ratio : float, optional
Evaluation set ratio. Default is 0.01.
seed : int, optional
Random seed. Default is 42.
checkpoint : str, optional
Path to checkpoint to resume from.
**extra_config
Additional config options.
Returns
-------
str
Path to the final checkpoint.
"""
from .train import train as run_training
from .config import DEFAULT_CONFIG
config = {
**DEFAULT_CONFIG,
"exp_dir": exp_dir,
"words_path": words_path,
"text_path": text_path,
"pairs_path": pairs_path,
"charset_path": charset_path,
"max_words": max_words,
"max_pairs_edits": max_pairs_edits,
"max_len": max_len,
"emb_size": emb_size,
"n_layers": n_layers,
"n_heads": n_heads,
"ffn_size": ffn_size,
"dropout": dropout,
"batch_size": batch_size,
"accumulation_steps": accumulation_steps,
"use_amp": use_amp,
"compile_model": compile_model,
"epochs": epochs,
"lr": lr,
"weight_decay": weight_decay,
"grad_clip": grad_clip,
"min_len": min_len,
"mask_prob": mask_prob,
"span_min": span_min,
"span_max": span_max,
"spans_min": spans_min,
"spans_max": spans_max,
"pairs_ratio": pairs_ratio,
"eval_ratio": eval_ratio,
"seed": seed,
"checkpoint": checkpoint,
**extra_config,
}
run_training(config)
return os.path.join(exp_dir, "checkpoints", f"charlm_epoch_{epochs}.pt")
[docs]
@staticmethod
def export(
weights_path: Union[str, Path],
vocab_path: Union[str, Path],
output_path: Union[str, Path],
max_len: int = 32,
emb_size: int = 192,
n_layers: int = 8,
n_heads: int = 6,
ffn_size: int = 1024,
opset_version: int = 14,
simplify: bool = True,
) -> None:
"""
Export CharLM PyTorch model to ONNX format.
Parameters
----------
weights_path : str or Path
Path to PyTorch checkpoint (.pt file).
vocab_path : str or Path
Path to vocabulary JSON file.
output_path : str or Path
Path to save ONNX model.
max_len : int, optional
Maximum sequence length. Default is 32.
emb_size : int, optional
Embedding size. Default is 192.
n_layers : int, optional
Number of transformer layers. Default is 8.
n_heads : int, optional
Number of attention heads. Default is 6.
ffn_size : int, optional
Feed-forward network size. Default is 1024.
opset_version : int, optional
ONNX opset version. Default is 14.
simplify : bool, optional
Apply ONNX simplification. Default is True.
"""
import torch
from .model import CharTransformerMLM
weights_path = Path(weights_path)
vocab_path = Path(vocab_path)
output_path = Path(output_path)
if not weights_path.exists():
raise FileNotFoundError(f"Weights not found: {weights_path}")
if not vocab_path.exists():
raise FileNotFoundError(f"Vocab not found: {vocab_path}")
print(f"Loading vocab from {vocab_path}...")
with open(vocab_path, "r", encoding="utf-8") as f:
chars = json.load(f)
vocab_size = len(chars)
c2i = {c: i for i, c in enumerate(chars)}
pad_idx = c2i.get("<PAD>", 0)
print(f"Vocab size: {vocab_size}")
print(f"Loading checkpoint from {weights_path}...")
checkpoint = torch.load(str(weights_path), map_location="cpu")
if isinstance(checkpoint, dict) and "model" in checkpoint:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
model = CharTransformerMLM(
vocab_size=vocab_size,
emb_size=emb_size,
max_len=max_len,
n_layers=n_layers,
n_heads=n_heads,
ffn_size=ffn_size,
dropout=0.0,
pad_idx=pad_idx,
)
model.load_state_dict(state_dict)
model.eval()
print(f"\n=== CharLM ONNX Export ===")
print(f"Max length: {max_len}")
print(f"Embedding size: {emb_size}")
print(f"Layers: {n_layers}")
print(f"Heads: {n_heads}")
dummy_input = torch.zeros(1, max_len, dtype=torch.long)
output_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\nExporting to ONNX...")
torch.onnx.export(
model,
dummy_input,
str(output_path),
input_names=["input"],
output_names=["logits"],
dynamic_axes={
"input": {0: "batch_size"},
"logits": {0: "batch_size"},
},
opset_version=opset_version,
do_constant_folding=True,
)
print(f"[OK] ONNX model saved to: {output_path}")
if simplify:
try:
import onnx
from onnxsim import simplify as onnx_simplify
print("Simplifying ONNX model...")
onnx_model = onnx.load(str(output_path))
simplified, check = onnx_simplify(onnx_model)
if check:
onnx.save(simplified, str(output_path))
print("[OK] Model simplified successfully")
else:
print("[WARN] Simplification check failed, keeping original")
except ImportError:
print("[SKIP] onnxsim not installed, skipping simplification")
except Exception as e:
print(f"[WARN] Simplification failed: {e}")
print(f"\nExport complete: {output_path}")