Training an MNIST classifier on Emily

Mix.install(
  [
    {:emily, "~> 0.3"},
    {:axon, "~> 0.7"},
    {:scidata, "~> 0.1"},
    {:nx, "~> 0.10"},
    {:kino, "~> 0.14"}
  ],
  config: [
    nx: [default_backend: Emily.Backend]
  ]
)

Overview

The other notebooks in this repo run pretrained models forward. This one trains from scratch — a small Axon MLP on MNIST, lowered through Emily.Compiler so every step of the backward pass dispatches to MLX. It's the end-to-end exercise of the Nx.Defn.grad chain that inference demos can't cover.

After the f32 baseline we repeat the same training run with an Axon mixed-precision policy (bf16 compute, f32 params). This is the canonical recipe Emily.MixedPrecision is built around — halve the activation memory, keep the master weights in f32 for stable optimizer updates.
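As a minimal sketch of what that dispatch looks like outside Axon, a gradient can be JIT-compiled through Emily.Compiler directly with Nx.Defn. The toy function and values below are illustrative only:

# Illustrative only: differentiate sum(x * x) and run the backward pass on MLX.
grad_sum_of_squares =
  Nx.Defn.jit(
    fn x -> Nx.Defn.grad(x, &Nx.sum(Nx.multiply(&1, &1))) end,
    compiler: Emily.Compiler
  )

grad_sum_of_squares.(Nx.tensor([1.0, 2.0, 3.0]))
# => the gradient 2x, i.e. [2.0, 4.0, 6.0]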

Loading the dataset

scidata caches the ~11 MB MNIST archive at ~/Library/Caches/scidata on macOS, so this step is a no-op after the first run.

batch_size = 128

# Flatten each 28x28 image into a 784-length row and scale pixels to [0.0, 1.0].
images = fn {bin, type, shape} ->
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape({elem(shape, 0), 784})
  |> Nx.divide(255.0)
end

# One-hot encode the digit labels into {n, 10} by comparing against 0..9.
labels = fn {bin, type, shape} ->
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))
end

{train_images_raw, train_labels_raw} = Scidata.MNIST.download()
{test_images_raw, test_labels_raw} = Scidata.MNIST.download_test()

train_batches =
  Stream.zip(
    images.(train_images_raw) |> Nx.to_batched(batch_size),
    labels.(train_labels_raw) |> Nx.to_batched(batch_size)
  )

test_images = images.(test_images_raw)
test_labels = labels.(test_labels_raw)

Pixels are normalised to [0.0, 1.0] and labels are one-hot encoded, so Axon.Loop.trainer/2 can use :categorical_cross_entropy directly.
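If you want to sanity-check the pipeline before training, peeking at the first batch is cheap. The expected shapes below follow from batch_size = 128:

# Optional sanity check: first batch should be {128, 784} images and {128, 10} one-hot labels.
{first_images, first_labels} = train_batches |> Enum.take(1) |> hd()
{Nx.shape(first_images), Nx.shape(first_labels)}
# => {{128, 784}, {128, 10}}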

Defining the model

A two-layer MLP — small enough that five epochs easily clear 96% test accuracy, large enough to exercise the backward pass through two dense kernels plus a softmax.

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

Axon.Display.as_graph(model, Nx.template({1, 784}, :f32))

Training

trained_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.run(train_batches, %{},
    epochs: 5,
    compiler: Emily.Compiler
  )

Axon.Loop.run accepts compiler: Emily.Compiler the same way Nx.Serving.run does. Each step's forward pass, loss, gradient, and Adam update all become MLX ops. Five epochs over 60k examples run in tens of seconds on Apple Silicon.
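If you want per-epoch feedback beyond the running loss, Axon.Loop.metric/2 composes into the same pipeline. This is an optional variation on the loop above, not a required step:

# Optional: report running accuracy alongside the loss during training.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train_batches, %{}, epochs: 5, compiler: Emily.Compiler)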

Evaluating

evaluate = fn state ->
  logits = Axon.predict(model, state, test_images, compiler: Emily.Compiler)
  predicted = Nx.argmax(logits, axis: -1)
  actual = Nx.argmax(test_labels, axis: -1)

  Nx.mean(Nx.equal(predicted, actual))
  |> Nx.backend_transfer(Nx.BinaryBackend)
  |> Nx.to_number()
end

accuracy = evaluate.(trained_state)
IO.puts("f32 test accuracy: #{Float.round(accuracy * 100, 2)}%")

Expect >96%. The :training_full canary at test/emily/training/mnist_full_test.exs asserts the same threshold as a regression check.
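Overall accuracy can hide a weak class. The snippet below is an optional per-digit breakdown built from the same predictions; it is plain Nx, nothing Emily-specific:

# Optional: per-digit accuracy on the test set.
logits = Axon.predict(model, trained_state, test_images, compiler: Emily.Compiler)
predicted = Nx.argmax(logits, axis: -1)
actual = Nx.argmax(test_labels, axis: -1)

for digit <- 0..9 do
  mask = Nx.equal(actual, digit)
  correct = Nx.logical_and(mask, Nx.equal(predicted, actual))
  acc = Nx.to_number(Nx.divide(Nx.sum(correct), Nx.sum(mask)))
  IO.puts("digit #{digit}: #{Float.round(acc * 100, 2)}%")
end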

Mixed-precision training

Axon's MixedPrecision policy casts every parameter to the compute dtype on the forward pass while keeping the master copy in params dtype for the optimizer. The policy below mirrors the bf16 recipe documented in Emily.MixedPrecision.

policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:bf, 16},
    output: {:f, 32}
  )

bf16_model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
  |> Axon.MixedPrecision.apply_policy(policy)

bf16_state =
  bf16_model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.run(train_batches, %{},
    epochs: 5,
    compiler: Emily.Compiler
  )

bf16_accuracy = evaluate.(bf16_state)
IO.puts("bf16 test accuracy: #{Float.round(bf16_accuracy * 100, 2)}%")

bf16 should land within ~0.5% of the f32 baseline. On this architecture the memory saving is modest — on a real transformer block with much larger activations, the win compounds.
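To make that comparison concrete, print the two accuracies already computed above side by side:

IO.puts("f32:  #{Float.round(accuracy * 100, 2)}%")
IO.puts("bf16: #{Float.round(bf16_accuracy * 100, 2)}%")
IO.puts("gap:  #{Float.round((accuracy - bf16_accuracy) * 100, 2)} percentage points")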

For a fully manual mixed-precision loop (dynamic loss scaling, explicit overflow detection, hand-rolled optimizer step), Emily.MixedPrecision exposes cast_params/2, scale_loss/2, unscale/2, and a LossScaler struct; see its @moduledoc for the worked example.

Telemetry

Emily emits a span at each evaluation boundary. The same pattern from the DistilBERT notebook applies here — attach on [:emily, :eval, :stop] to log per-batch timing, or on [:emily, :fallback, :stop] to spot any op that routed through Nx.BinaryBackend. See Emily.Telemetry for the full event catalogue.
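As an illustration, the handler below logs every eval span. The event name comes from the paragraph above; the :duration measurement key (in native time units) is an assumption based on the usual :telemetry.span convention, so check Emily.Telemetry for the actual payload:

# Illustrative handler; measurements.duration in native time units is assumed.
:telemetry.attach(
  "mnist-eval-logger",
  [:emily, :eval, :stop],
  fn _event, measurements, _metadata, _config ->
    ms = System.convert_time_unit(measurements.duration, :native, :millisecond)
    IO.puts("eval step took #{ms} ms")
  end,
  nil
)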

Next steps

  • Swap the MLP for the LeNet-style CNN in test/emily/training/mnist_cnn_full_test.exs to exercise MLX's conv and window-pool ops through the backward pass.
  • Drop compiler: Emily.Compiler to see the same loop run on Nx.BinaryBackend — useful for A/B numerics checks against a reference backend.