Source code for pyaota.ocr.digit_ocr
"""
MNIST-style digit OCR using a CNN.
"""
from __future__ import annotations
from importlib.resources import files
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
import cv2
import tensorflow as tf
# Short alias for the Keras loader; load_digit_model() calls it to read the
# model file from disk.
load_model = tf.keras.models.load_model
# Module-level cache for the classifier returned by load_digit_model();
# stays None until a model has been loaded.
_MODEL: Optional[tf.keras.Model] = None
[docs]
# String form of the path the cached ``_MODEL`` was loaded from; ``None``
# while nothing has been loaded through ``load_digit_model`` yet.
_MODEL_PATH: Optional[str] = None


def load_digit_model(model_path: str | Path | None = None) -> tf.keras.Model:
    """
    Load (or reuse) the MNIST-style digit classifier.

    Parameters
    ----------
    model_path:
        Location of a saved Keras model. When ``None``, the model that is
        currently cached is returned; if nothing is cached yet, the default
        model shipped in the package (``data/models/mnist_digit_cnn.keras``)
        is loaded.

    Returns
    -------
    The loaded model. It is kept in a module-level cache, so repeated calls
    for the same model do not hit the disk again.

    Raises
    ------
    FileNotFoundError
        If the model file to load does not exist.

    Notes
    -----
    An explicit ``model_path`` is always honoured: if the cached model came
    from a different path, the requested file is loaded and replaces the
    cache instead of the stale model being returned silently.
    """
    global _MODEL, _MODEL_PATH

    if model_path is None:
        # No specific model requested: whichever model is current will do.
        if _MODEL is not None:
            return _MODEL
        # Resolve the packaged default only when it is actually needed, so
        # callers that pass their own path never depend on package data.
        resolved = (
            files("pyaota") / "data" / "models" / "mnist_digit_cnn.keras"
        )
    else:
        resolved = Path(model_path)
        # Reuse the cache only when it was loaded from this very path.
        if _MODEL is not None and _MODEL_PATH == str(resolved):
            return _MODEL

    if not resolved.exists():
        raise FileNotFoundError(
            f"Digit model not found at {resolved}. "
            f"Train it with train_mnist_digit_model.py first."
        )

    _MODEL = load_model(resolved)
    _MODEL_PATH = str(resolved)
    return _MODEL
[docs]
def preprocess_digit_crop(
    img_gray: np.ndarray,
    target_size: Tuple[int, int] = (28, 28),
) -> np.ndarray:
    """
    Preprocess a digit crop for the CNN.

    Steps:
    - bring the pixel data to 8-bit (required by Otsu thresholding)
    - collapse colour inputs to a single grayscale channel
    - Otsu threshold with inversion: pixels at or below the threshold
      (dark strokes) become 255, brighter pixels (background) become 0
    - resize to ``target_size``
    - normalize to [0, 1]

    Parameters
    ----------
    img_gray:
        2-D grayscale image, or a 3-D image with 1, 3 (BGR) or 4 (BGRA)
        channels. ``uint8`` data is used as-is; any other dtype is min-max
        rescaled to the 0..255 range first.
    target_size:
        Passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.

    Returns
    -------
    ``float32`` array of shape ``(1, target_size[1], target_size[0], 1)``
    with values in [0, 1].
    """
    gray = np.asarray(img_gray)

    # Otsu's method only works on 8-bit data, so float / bool / wider-int
    # crops are rescaled to 0..255 up front. Empty arrays are left for
    # OpenCV to reject.
    if gray.dtype != np.uint8 and gray.size:
        scaled = gray.astype("float32")
        lo = float(scaled.min())
        hi = float(scaled.max())
        span = hi - lo
        if span > 0:
            scaled = (scaled - lo) * (255.0 / span)
        else:
            # Flat image: there is no contrast to preserve.
            scaled = np.zeros_like(scaled)
        gray = np.rint(scaled).astype(np.uint8)

    if gray.ndim == 3:
        channels = gray.shape[2]
        if channels == 1:
            # Already grayscale, just carrying a trailing channel axis.
            gray = np.ascontiguousarray(gray[..., 0])
        elif channels == 4:
            gray = cv2.cvtColor(gray, cv2.COLOR_BGRA2GRAY)
        else:
            gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)

    # THRESH_BINARY_INV turns dark ink on a light background into bright
    # strokes on a black background.
    _, binary = cv2.threshold(
        gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )

    resized = cv2.resize(
        binary,
        target_size,
        interpolation=cv2.INTER_AREA,
    )

    # Scale to [0, 1], then add the trailing channel axis and the leading
    # batch axis.
    norm = resized.astype("float32") / 255.0
    norm = norm[..., None]
    return np.expand_dims(norm, axis=0)
[docs]
def ocr_digit_nn(
    img_gray: np.ndarray,
    model: Optional[tf.keras.Model] = None,
    target_size: Tuple[int, int] = (28, 28),
    min_confidence: float = 0.0,
) -> Tuple[Optional[str], float]:
    """
    Run the CNN digit classifier on a single grayscale crop.

    Parameters
    ----------
    img_gray:
        Digit crop; it is passed through ``preprocess_digit_crop``.
    model:
        Classifier to use. When ``None``, ``load_digit_model()`` supplies it.
    target_size:
        Forwarded to ``preprocess_digit_crop``.
    min_confidence:
        Predictions whose confidence is below this value are reported as
        unreadable. The default of 0.0 never rejects a prediction.

    Returns
    -------
    (digit_str, confidence)
    digit_str: index of the highest-scoring class as a string ('0'..'9'
        for a 10-class model), or None if the confidence is below
        ``min_confidence``
    confidence: score of the predicted class (a softmax probability in
        0..1, assuming the model ends in a softmax layer)
    """
    if model is None:
        model = load_digit_model()

    batch = preprocess_digit_crop(img_gray, target_size=target_size)

    # The batch holds exactly one crop, so keep only the first row of scores.
    scores = model.predict(batch, verbose=0)[0]
    best = int(np.argmax(scores))
    confidence = float(scores[best])

    # The confidence is still returned so callers can log or inspect
    # rejected predictions.
    if confidence < min_confidence:
        return None, confidence
    return str(best), confidence