Source code for pyaota.ocr.train_mnist_digit_model
# train_mnist_digit_model.py
import tensorflow as tf
layers = tf.keras.layers
models = tf.keras.models
from pathlib import Path
[docs]
def build_cnn_model(input_shape=(28, 28, 1)) -> tf.keras.Model:
model = models.Sequential(
[
layers.Conv2D(32, (3, 3), activation="relu", input_shape=input_shape),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.Flatten(),
layers.Dense(128, activation="relu"),
layers.Dense(10, activation="softmax"),
]
)
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
return model
[docs]
def main():
# Load MNIST
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Normalize to [0,1] and add channel dim
x_train = (x_train.astype("float32") / 255.0)[..., None]
x_test = (x_test.astype("float32") / 255.0)[..., None]
model = build_cnn_model()
model.summary()
# Train
model.fit(
x_train,
y_train,
epochs=5,
batch_size=64,
validation_split=0.1,
)
# Evaluate
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")
# this script lives at PACKAGE_ROOT/src/package_name/ocr/train_mnist_digit_model.py
# model data lives at PACKAGE_ROOT/data/models/mnist_digit_cnn.keras
PACKAGE_ROOT = Path(__file__).parents[2]
out_path = PACKAGE_ROOT / "data" / "models" / "mnist_digit_cnn.keras"
model.save(out_path)
print(f"Saved digit model to {out_path}")
if __name__ == "__main__":
main()