Mixed-precision training utilities.
Standard recipe for memory-efficient training: bf16 activations with f32 master weights and dynamic loss scaling. Keeps a full-precision copy of the parameters for numerically stable optimizer updates while running the forward and backward passes in half precision.
Worked example
alias Emily.MixedPrecision, as: MP
alias Emily.MixedPrecision.LossScaler

# init_params/0, batches, forward/3, sgd_step/3, and lr are user-supplied.

# f32 master weights — the optimizer's ground truth.
master_params = init_params()
scaler = MP.loss_scale()

{trained_params, _scaler} =
  for {x, y} <- batches, reduce: {master_params, scaler} do
    {params, scaler} ->
      # Backward pass: grad w.r.t. the f32 master params, but the
      # forward graph runs in bf16 thanks to the cast_params call
      # inside the closure.
      grads =
        Nx.Defn.grad(params, fn p ->
          p
          |> MP.cast_params({:bf, 16})
          |> forward(x, y)
          |> MP.scale_loss(scaler)
        end)

      # Unscale, detect overflow, adjust the scaler.
      {grads, overflow?} = MP.unscale(grads, scaler)
      scaler = MP.update(scaler, overflow?)

      if overflow? do
        # Overflow: skip the optimizer step; the scaler has already backed off.
        {params, scaler}
      else
        # Upcast gradients and step against the f32 master weights.
        f32_grads = MP.accumulate_grad(grads, {:f, 32})
        {sgd_step(params, f32_grads, lr), scaler}
      end
  end

Container traversal
cast_params/2, accumulate_grad/2, and has_overflow?/1 traverse
plain maps, tuples, and lists of Nx.Tensor leaves. For
Axon.ModelState, access the .data field first:
MP.cast_params(model_state.data, {:bf, 16})
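A minimal sketch of casting a whole model state while keeping the struct shape, assuming nothing beyond the .data field needs to change for the forward pass:

# Sketch: downcast an Axon.ModelState's parameters for the bf16 forward pass
# while keeping the struct itself (and its other fields) intact.
bf16_state = %{model_state | data: MP.cast_params(model_state.data, {:bf, 16})}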
Examples
iex> params = %{w: Nx.tensor([1.0, 2.0], type: :f32)}
iex> Emily.MixedPrecision.cast_params(params, {:bf, 16}).w.type
{:bf, 16}
iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> scaler.scale
1024.0
iex> Emily.MixedPrecision.update(scaler, true).scale
512.0
Summary
Functions
accumulate_grad/2 — Upcast float tensors in a nested gradient structure to type.
cast_params/2 — Downcast float tensors in a nested structure to type.
has_overflow?/1 — Check whether any tensor in a nested structure contains inf or nan.
loss_scale/1 — Create a new dynamic loss scaler.
scale_loss/2 — Scale the loss by the scaler's current factor.
unscale/2 — Unscale gradients and detect overflow.
update/2 — Update the scaler after a training step.
Functions
accumulate_grad/2
Upcast float tensors in a nested gradient structure to type.
Semantically identical to cast_params/2 — exists for readability
at the call site (the direction of the cast is part of the name).
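In a training step it typically sits right before the optimizer update, lifting the gradients back to f32 master precision (sgd_step/3, lr, grads, and params are placeholders, as in the worked example above):

f32_grads = Emily.MixedPrecision.accumulate_grad(grads, {:f, 32})
params = sgd_step(params, f32_grads, lr)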
Examples
iex> grads = %{w: Nx.tensor([0.5, 0.25], type: {:bf, 16})}
iex> Emily.MixedPrecision.accumulate_grad(grads, {:f, 32}).w.type
{:f, 32}
cast_params/2
Downcast float tensors in a nested structure to type.
Integer and predicate tensors are left unchanged.
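A small sketch of that passthrough rule (values are illustrative): only the float leaf is downcast, while the u8 mask keeps its type.

mixed = %{w: Nx.tensor([1.0, 2.0]), mask: Nx.tensor([1, 0], type: {:u, 8})}
casted = Emily.MixedPrecision.cast_params(mixed, {:bf, 16})
# casted.w.type    => {:bf, 16}
# casted.mask.type => {:u, 8} (left unchanged)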
Examples
iex> params = %{w: Nx.tensor([1.0, 2.0], type: :f32)}
iex> Emily.MixedPrecision.cast_params(params, {:bf, 16}).w.type
{:bf, 16}
has_overflow?/1
Check whether any tensor in a nested structure contains inf or nan.
Examples
iex> Emily.MixedPrecision.has_overflow?(%{w: Nx.tensor([1.0, 2.0])})
false
iex> Emily.MixedPrecision.has_overflow?(%{w: Nx.tensor([1.0, :nan])})
true
loss_scale/1
Create a new dynamic loss scaler.
Options
:scale — initial scale factor (default 65_536.0)
:growth_factor — multiply scale by this on growth (default 2.0)
:backoff_factor — multiply scale by this on overflow (default 0.5)
:growth_interval — successful steps before growing (default 2000)
:min_scale — floor for the scale (default 1.0)
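For example, spelling out every option with its default value from the list above:

scaler =
  Emily.MixedPrecision.loss_scale(
    scale: 65_536.0,
    growth_factor: 2.0,
    backoff_factor: 0.5,
    growth_interval: 2000,
    min_scale: 1.0
  )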
Examples
iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> scaler.scale
1024.0
scale_loss/2
Scale the loss by the scaler's current factor.
Call this inside the function passed to Nx.Defn.grad so that the
backward pass produces scaled gradients.
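A minimal sketch of that placement (params, batch, scaler, and forward/2 are placeholders, not part of this module):

grads =
  Nx.Defn.grad(params, fn p ->
    p
    |> forward(batch)
    |> Emily.MixedPrecision.scale_loss(scaler)
  end)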
Examples
iex> scaler = Emily.MixedPrecision.loss_scale(scale: 8.0)
iex> loss = Nx.tensor(2.5)
iex> Emily.MixedPrecision.scale_loss(loss, scaler) |> Nx.to_number()
20.0
unscale/2
Unscale gradients and detect overflow.
Divides every float tensor in grads by scaler.scale, then checks
for inf/nan. Returns {unscaled_grads, overflow?}.
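Typical usage pairs it with update/2 and skips the optimizer step on overflow (a sketch; apply_update/2, grads, params, and scaler are placeholders):

{grads, overflow?} = Emily.MixedPrecision.unscale(grads, scaler)
scaler = Emily.MixedPrecision.update(scaler, overflow?)
params = if overflow?, do: params, else: apply_update(params, grads)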
Examples
iex> scaler = Emily.MixedPrecision.loss_scale(scale: 4.0)
iex> grads = %{w: Nx.tensor([8.0, 16.0])}
iex> {unscaled, overflow?} = Emily.MixedPrecision.unscale(grads, scaler)
iex> {Nx.to_flat_list(unscaled.w), overflow?}
{[2.0, 4.0], false}
update/2
Update the scaler after a training step.
On overflow: multiplies the scale by backoff_factor (halving it with the default of 0.5), floors it at min_scale, and resets the success counter. On success: increments the counter and, after growth_interval consecutive successes, multiplies the scale by growth_factor (doubling it with the default of 2.0).
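Growth can be seen by folding update/2 over a run of successful steps, here with a small growth_interval (a sketch; it assumes the scale grows on the growth_interval-th consecutive success):

scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0, growth_interval: 3)

Enum.reduce(1..3, scaler, fn _step, s ->
  Emily.MixedPrecision.update(s, false)
end).scale
#=> 2048.0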
Examples
iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> Emily.MixedPrecision.update(scaler, true).scale
512.0
iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> Emily.MixedPrecision.update(scaler, false).counter
1