Mix.install(
[
{:emily, "~> 0.3"},
{:axon, "~> 0.7"},
{:scidata, "~> 0.1"},
{:nx, "~> 0.10"},
{:kino, "~> 0.14"}
],
config: [
nx: [default_backend: Emily.Backend]
]
)
Overview
The other notebooks in this repo run pretrained models forward. This
one trains from scratch — a small Axon MLP on MNIST, lowered through
Emily.Compiler so every step of the backward pass dispatches to
MLX. It's the end-to-end exercise of the Nx.Defn.grad chain that
inference demos can't cover.
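As a minimal sketch of that chain outside Axon (the GradSketch module below is illustrative, not part of Emily), a defn gradient can be jitted through the same compiler:
defmodule GradSketch do
  import Nx.Defn

  # Toy loss: sum of sin(x · w).
  defn loss(w, x), do: Nx.sum(Nx.sin(Nx.dot(x, w)))

  # Gradient of the loss with respect to w.
  defn grad_loss(w, x), do: grad(w, fn w -> loss(w, x) end)
end

w = Nx.iota({4, 2}, type: :f32)
x = Nx.iota({3, 4}, type: :f32)

# Jit the gradient through Emily.Compiler, the same way Axon.Loop will below.
Nx.Defn.jit(&GradSketch.grad_loss/2, compiler: Emily.Compiler).(w, x)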
After the f32 baseline we repeat the same training run with an Axon
mixed-precision policy (bf16 compute, f32 params). This is the
canonical recipe Emily.MixedPrecision is built around — halve the
activation memory, keep the master weights in f32 for stable optimizer
updates.
Loading the dataset
scidata caches the ~11 MB MNIST archive at ~/Library/Caches/scidata
on macOS, so this step is a no-op after the first run.
batch_size = 128
# Flatten each 28x28 image into a 784-vector and scale pixels to [0.0, 1.0].
images = fn {bin, type, shape} ->
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.reshape({elem(shape, 0), 784})
  |> Nx.divide(255.0)
end
# One-hot encode each digit label against the classes 0..9.
labels = fn {bin, type, shape} ->
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))
end
{train_images_raw, train_labels_raw} = Scidata.MNIST.download()
{test_images_raw, test_labels_raw} = Scidata.MNIST.download_test()
train_batches =
Stream.zip(
images.(train_images_raw) |> Nx.to_batched(batch_size),
labels.(train_labels_raw) |> Nx.to_batched(batch_size)
)
test_images = images.(test_images_raw)
test_labels = labels.(test_labels_raw)
Pixels are normalised into [0.0, 1.0] and labels are one-hot encoded
so Axon.Loop.trainer/2 can use :categorical_cross_entropy
directly.
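An optional sanity check (purely illustrative) confirms the ranges on the test tensors before training, assuming these reductions run on the configured backend:
# Pixels should sit in [0.0, 1.0]; each label row should contain exactly one 1.
IO.inspect(Nx.to_number(Nx.reduce_max(test_images)), label: "max pixel")
IO.inspect(Nx.to_number(Nx.reduce_min(test_images)), label: "min pixel")
IO.inspect(Nx.to_number(Nx.sum(test_labels[0])), label: "ones in first label row")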
Defining the model
A two-layer MLP — small enough that five epochs easily clear 96% test accuracy, large enough to exercise the backward pass through two dense kernels plus a softmax.
model =
Axon.input("input", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)
|> Axon.dense(10, activation: :softmax)
Axon.Display.as_graph(model, Nx.template({1, 784}, :f32))
Training
trained_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.run(train_batches, %{},
epochs: 5,
compiler: Emily.Compiler
)
Axon.Loop.run accepts compiler: Emily.Compiler the same way
Nx.Serving.run does. Each step's forward pass, loss, gradient, and
Adam update all become MLX ops. Five epochs over 60k examples run
in tens of seconds on Apple Silicon.
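As an optional variation (a sketch, not part of the baseline above), the same loop can report a running accuracy metric alongside the loss by adding Axon.Loop.metric/2 before run/4:
# Same trainer, with a running :accuracy metric shown in the loop output.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train_batches, %{}, epochs: 1, compiler: Emily.Compiler)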
Evaluating
# Compare predicted vs. actual class indices and average into a scalar accuracy.
evaluate = fn state ->
  logits = Axon.predict(model, state, test_images, compiler: Emily.Compiler)
  predicted = Nx.argmax(logits, axis: -1)
  actual = Nx.argmax(test_labels, axis: -1)

  Nx.mean(Nx.equal(predicted, actual))
  |> Nx.backend_transfer(Nx.BinaryBackend)
  |> Nx.to_number()
end
accuracy = evaluate.(trained_state)
IO.puts("f32 test accuracy: #{Float.round(accuracy * 100, 2)}%")Expect >96%. The :training_full canary at
test/emily/training/mnist_full_test.exs asserts exactly this
threshold for regression purposes.
Mixed-precision training
Axon's MixedPrecision policy casts every parameter to the compute
dtype on the forward pass while keeping the master copy in params
dtype for the optimizer. The policy below mirrors the bf16 recipe
documented in Emily.MixedPrecision.
policy =
Axon.MixedPrecision.create_policy(
params: {:f, 32},
compute: {:bf, 16},
output: {:f, 32}
)
bf16_model =
Axon.input("input", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)
|> Axon.dense(10, activation: :softmax)
|> Axon.MixedPrecision.apply_policy(policy)
bf16_state =
bf16_model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.run(train_batches, %{},
epochs: 5,
compiler: Emily.Compiler
)
bf16_accuracy = evaluate.(bf16_state)
IO.puts("bf16 test accuracy: #{Float.round(bf16_accuracy * 100, 2)}%")bf16 should land within ~0.5% of the f32 baseline. On this architecture the memory saving is modest — on a real transformer block with much larger activations, the win compounds.
For a fully manual mixed-precision loop (dynamic loss scaling,
explicit overflow detection, hand-rolled optimizer step),
Emily.MixedPrecision exposes cast_params/2, scale_loss/2,
unscale/2, and a LossScaler struct — see its @moduledoc for
the worked example.
Telemetry
Emily emits a span at each evaluation boundary. The same pattern from
the DistilBERT notebook applies here — attach on [:emily, :eval, :stop] to log per-batch timing, or on [:emily, :fallback, :stop] to
spot any op that routed through Nx.BinaryBackend. See
Emily.Telemetry for the full event catalogue.
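A minimal handler sketch for those two events follows (handler ids and log formats are arbitrary, and it assumes Emily's spans follow the standard :telemetry.span convention of a :duration measurement in native time units):
# Log per-batch eval timing.
:telemetry.attach(
  "mnist-eval-timing",
  [:emily, :eval, :stop],
  fn _event, %{duration: duration}, _metadata, _config ->
    ms = System.convert_time_unit(duration, :native, :millisecond)
    IO.puts("eval step took #{ms} ms")
  end,
  nil
)

# Surface any op that fell back to Nx.BinaryBackend.
:telemetry.attach(
  "mnist-fallback-log",
  [:emily, :fallback, :stop],
  fn _event, _measurements, metadata, _config ->
    IO.inspect(metadata, label: "fallback")
  end,
  nil
)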
Next steps
- Swap the MLP for the LeNet-style CNN in test/emily/training/mnist_cnn_full_test.exs to exercise MLX's conv and window-pool ops through the backward pass.
- Drop compiler: Emily.Compiler to see the same loop run on Nx.BinaryBackend — useful for A/B numerics checks against a reference backend.