Emily.MixedPrecision (emily v0.3.2)


Mixed-precision training utilities.

Standard recipe for memory-efficient training: bf16 activations with f32 master weights and dynamic loss scaling. Keeps the full-precision copy of parameters for numerically stable optimizer updates while running the forward and backward pass in half precision.

Worked example

alias Emily.MixedPrecision, as: MP

# init_params/0, forward/3, sgd_step/3, batches, and lr are user-supplied.
# f32 master weights — the optimizer's ground truth.
master_params = init_params()
scaler = MP.loss_scale()

for {x, y} <- batches, reduce: {master_params, scaler} do
  {params, scaler} ->
    # Backward pass: grad is taken w.r.t. the f32 master params, but the
    # forward graph runs in bf16 thanks to the cast_params/2 call inside
    # the closure (the cast must happen inside so it is part of the
    # differentiated graph).
    grads =
      Nx.Defn.grad(params, fn p ->
        p
        |> MP.cast_params({:bf, 16})
        |> forward(x, y)
        |> MP.scale_loss(scaler)
      end)

    # Unscale, detect overflow, adjust scaler.
    {grads, overflow?} = MP.unscale(grads, scaler)
    scaler = MP.update(scaler, overflow?)

    if overflow? do
      {params, scaler}
    else
      f32_grads = MP.accumulate_grad(grads, {:f, 32})
      {sgd_step(params, f32_grads, lr), scaler}
    end
end

Container traversal

cast_params/2, accumulate_grad/2, and has_overflow?/1 traverse plain maps, tuples, and lists of Nx.Tensor leaves. For Axon.ModelState, access the .data field first:

MP.cast_params(model_state.data, {:bf, 16})
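
For a full Axon training step you typically cast the nested .data map and write the result back into the state struct. A minimal sketch, assuming model_state is an %Axon.ModelState{} and that a plain struct update is acceptable here:

# Hypothetical helper: cast a model state's parameters while keeping the
# rest of the struct (frozen params, layer state, ...) intact.
cast_state = fn %Axon.ModelState{data: data} = state, type ->
  %{state | data: MP.cast_params(data, type)}
end

bf16_state = cast_state.(model_state, {:bf, 16})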

Examples

iex> params = %{w: Nx.tensor([1.0, 2.0], type: :f32)}
iex> Emily.MixedPrecision.cast_params(params, {:bf, 16}).w.type
{:bf, 16}

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> scaler.scale
1024.0
iex> Emily.MixedPrecision.update(scaler, true).scale
512.0

Summary

Functions

accumulate_grad(grads, type)
Upcast float tensors in a nested gradient structure to type.

cast_params(params, type)
Downcast float tensors in a nested structure to type.

has_overflow?(structure)
Check whether any tensor in a nested structure contains inf or nan.

loss_scale(opts \\ [])
Create a new dynamic loss scaler.

scale_loss(loss, loss_scaler)
Scale the loss by the scaler's current factor.

unscale(grads, loss_scaler)
Unscale gradients and detect overflow.

update(scaler, overflow)
Update the scaler after a training step.

Functions

accumulate_grad(grads, type)

Upcast float tensors in a nested gradient structure to type.

Semantically identical to cast_params/2 — exists for readability at the call site (the direction of the cast is part of the name).

Examples

iex> grads = %{w: Nx.tensor([0.5, 0.25], type: {:bf, 16})}
iex> Emily.MixedPrecision.accumulate_grad(grads, {:f, 32}).w.type
{:f, 32}
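
A sketch of where the name earns its keep: summing per-micro-batch gradients into an f32 accumulator before the optimizer step. grad_fn, micro_batches, and the Map.merge-based sum are illustrative placeholders, not part of this module:

alias Emily.MixedPrecision, as: MP

f32_acc =
  Enum.reduce(micro_batches, nil, fn batch, acc ->
    # Each micro-batch gradient comes back in half precision; upcast before summing.
    g = batch |> grad_fn.() |> MP.accumulate_grad({:f, 32})

    case acc do
      nil -> g
      acc -> Map.merge(acc, g, fn _key, a, b -> Nx.add(a, b) end)
    end
  end)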

cast_params(params, type)

Downcast float tensors in a nested structure to type.

Integer and predicate tensors are left unchanged.

Examples

iex> params = %{w: Nx.tensor([1.0, 2.0], type: :f32)}
iex> Emily.MixedPrecision.cast_params(params, {:bf, 16}).w.type
{:bf, 16}
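
Illustrating the integer/predicate claim above (assuming a :u8 mask leaf), only float leaves are downcast:

iex> params = %{mask: Nx.tensor([1, 0, 1], type: :u8), w: Nx.tensor([1.0, 2.0])}
iex> Emily.MixedPrecision.cast_params(params, {:bf, 16}).mask.type
{:u, 8}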

has_overflow?(structure)

Check whether any tensor in a nested structure contains inf or nan.

Examples

iex> Emily.MixedPrecision.has_overflow?(%{w: Nx.tensor([1.0, 2.0])})
false

iex> Emily.MixedPrecision.has_overflow?(%{w: Nx.tensor([1.0, :nan])})
true
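
An :infinity leaf is flagged the same way (the inf half of the documented check):

iex> Emily.MixedPrecision.has_overflow?(%{w: Nx.tensor([:infinity, 1.0])})
true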

loss_scale(opts \\ [])

Create a new dynamic loss scaler.

Options

  • :scale — initial scale factor (default 65_536.0)
  • :growth_factor — multiply scale by this on growth (default 2.0)
  • :backoff_factor — multiply scale by this on overflow (default 0.5)
  • :growth_interval — successful steps before growing (default 2000)
  • :min_scale — floor for the scale (default 1.0)

Examples

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> scaler.scale
1024.0
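
Passing several of the options above together (the chosen values are arbitrary):

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 32_768.0, growth_interval: 500, min_scale: 2.0)
iex> scaler.scale
32768.0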

scale_loss(loss, loss_scaler)

Scale the loss by the scaler's current factor.

Call this inside the function passed to Nx.Defn.grad so that the backward pass produces scaled gradients.

Examples

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 8.0)
iex> loss = Nx.tensor(2.5)
iex> Emily.MixedPrecision.scale_loss(loss, scaler) |> Nx.to_number()
20.0
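
In context, the call sits at the end of the grad closure so the scaled loss is what gets differentiated. forward/3, x, and y are placeholders, as in the worked example above:

grads =
  Nx.Defn.grad(params, fn p ->
    p
    |> forward(x, y)
    |> Emily.MixedPrecision.scale_loss(scaler)
  end)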

unscale(grads, loss_scaler)

Unscale gradients and detect overflow.

Divides every float tensor in grads by scaler.scale, then checks for inf/nan. Returns {unscaled_grads, overflow?}.

Examples

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 4.0)
iex> grads = %{w: Nx.tensor([8.0, 16.0])}
iex> {unscaled, overflow?} = Emily.MixedPrecision.unscale(grads, scaler)
iex> {Nx.to_flat_list(unscaled.w), overflow?}
{[2.0, 4.0], false}
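
When a gradient leaf has already overflowed to inf, the same call reports it:

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 4.0)
iex> grads = %{w: Nx.tensor([:infinity, 16.0])}
iex> {_unscaled, overflow?} = Emily.MixedPrecision.unscale(grads, scaler)
iex> overflow?
true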

update(scaler, overflow)

Update the scaler after a training step.

On overflow: halves the scale (floored at min_scale), resets the counter. On success: increments the counter; doubles the scale after growth_interval consecutive successes.

Examples

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> Emily.MixedPrecision.update(scaler, true).scale
512.0

iex> scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0)
iex> Emily.MixedPrecision.update(scaler, false).counter
1
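
A sketch of the growth side, with the interval shrunk so the doubling is visible after two clean steps. This assumes the scale doubles exactly on the growth_interval-th consecutive success:

scaler = Emily.MixedPrecision.loss_scale(scale: 1024.0, growth_interval: 2)

scaler =
  Enum.reduce(1..2, scaler, fn _, s ->
    Emily.MixedPrecision.update(s, false)
  end)

scaler.scale
#=> 2048.0 (assuming the doubling fires on the second consecutive success)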