Emily.Fast (emily v0.3.5)

Fused transformer kernels as defn-callable helpers.

Each function here emits an optional-expression node whose op name matches a custom callback on Emily.Backend. Under Emily.Backend the Defn evaluator dispatches directly to the corresponding MLX mx::fast::* kernel; under any other backend the defn-composed fallback runs and produces a mathematically equivalent result. That means Bumblebee models rewritten to use these helpers (via the test-only Emily.Bumblebee.FastKernels shim) keep running on Nx.BinaryBackend / EXLA for conformance — just without the fusion speedup.

Hook mechanism

Each helper wraps Nx.Defn.Expr.optional(name, in_args, fallback), which creates an :optional Expr node. At eval time the Nx evaluator looks for function_exported?(backend, name, length(in_args) + 1) on the active backend and, if present, calls backend.name(out, args...) directly. Otherwise the fallback defn runs. This is Nx's extension point for vendor-fused kernels (the same pattern EXLA uses for its native ops).

Tensor vs option arguments

The optional-expression contract splits in_args at the first list: every leading non-list argument is treated as a tensor param; the list (typically a keyword list) is passed through as opts. Every Emily.Fast.* function's final argument is a keyword list of scalars (dims, epsilons, flags), and all tensor inputs come before it.
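
As a concrete illustration, a helper built on this contract could look roughly like the sketch below. The module name and fallback body are hypothetical; they follow the description above rather than the actual Emily.Fast source:

defmodule FastSketch do
  import Nx.Defn

  # Hypothetical rms_norm built on the optional-expression hook: the two
  # tensor arguments come first, the trailing keyword list is forwarded as opts.
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)

    Nx.Defn.Expr.optional(:rms_norm, [x, weight, opts], fn x, weight, opts ->
      # Composed fallback: runs on any backend that does not export rms_norm/4.
      ms = Nx.mean(Nx.pow(x, 2), axes: [-1], keep_axes: true)
      x * Nx.rsqrt(ms + opts[:eps]) * weight
    end)
  end
end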

Covered kernels

  • rms_norm/3 — mx::fast::rms_norm.
  • layer_norm/4 — mx::fast::layer_norm.
  • rope/3 — mx::fast::rope with the standard geometric-progression theta schedule.
  • rope_with_freqs/4 — mx::fast::rope with a precomputed inverse-frequency table (for Llama-3 / LongRoPE / linear / dynamic scaling).
  • scaled_dot_product_attention/4 — mx::fast::sdpa, without mask or with causal mask. Optional :sinks opt threads a per-head sinks tensor through the softmax denominator (StreamingLLM).
  • scaled_dot_product_attention_with_mask/5 — the same with an additive bias tensor; also supports :sinks.
  • einsum/2 — mx::einsum (variadic operands, path-optimised by MLX). Eager-only, not defn-callable: it takes refs directly off Emily-backed tensors and raises on any other backend. There is no defn fallback because writing a correct einsum-string parser (diagonals, ellipsis, contraction ordering) is a non-trivial piece of work we defer until a user asks for cross-backend compatibility.

Usage

Call these from inside a defn or Nx.Defn.jit-traced function, alongside regular Nx ops:

defn block(x, w, b) do
  x
  |> Emily.Fast.layer_norm(w, b, eps: 1.0e-5)
  |> Nx.multiply(0.5)
end

Under Emily.Compiler the layer_norm node dispatches to mx::fast::layer_norm; under Nx.Defn.Evaluator + any other backend it runs the composed fallback.

Summary

Functions

einsum(subscripts, operands)
Variadic-operand einsum computed by MLX's path-optimised mx::einsum kernel.

layer_norm(x, weight, bias, opts \\ [])
Fused LayerNorm: Welford-style mean+variance of the last axis, then affine (x - mean) / sqrt(var + eps) * weight + bias.

rms_norm(x, weight, opts \\ [])
Fused RMSNorm: x * rsqrt(mean(x², axis=-1) + eps) * weight.

rope(x, offset, opts)
Fused RoPE with the standard geometric-progression theta schedule.

rope_with_freqs(x, offset, freqs, opts)
RoPE with a precomputed inverse-frequency table.

scaled_dot_product_attention(q, k, v, opts \\ [])
Fused scaled-dot-product attention without an additive-bias mask.

scaled_dot_product_attention_with_mask(q, k, v, mask, opts \\ [])
SDPA with an additive mask tensor broadcasting across QKᵀ.

Functions

einsum(subscripts, operands)

@spec einsum(String.t(), [Nx.Tensor.t()]) :: Nx.Tensor.t()

Variadic-operand einsum computed by MLX's path-optimised mx::einsum kernel.

subscripts is a standard Einstein-summation equation (e.g. "ij,jk->ik", "bij,bjk->bik", "bhid,bhjd->bhij", "ij,jk,kl->il"). operands is the corresponding list of tensors.

Eager-only, not defn-callable

Unlike the other helpers in this module, einsum/2 does not emit an Nx.Defn.Expr optional/3 node. It takes refs directly off Emily-backed tensors and calls the NIF eagerly, in the same "direct-call helper" style as Emily.Quantization.quantized_matmul/2. Every operand must live on Emily.Backend; anything else raises ArgumentError. Writing a correct einsum-string parser (for diagonals, ellipsis, and contraction ordering) is deferred until a user needs cross-backend compatibility.

Examples

iex> a = Nx.iota({2, 3}, backend: Emily.Backend, type: :f32)
iex> b = Nx.iota({3, 4}, backend: Emily.Backend, type: :f32)
iex> y = Emily.Fast.einsum("ij,jk->ik", [a, b])
iex> Nx.shape(y)
{2, 4}
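
Since "ij,jk->ik" is ordinary matrix multiplication, the fused result can be sanity-checked against plain Nx (a and b as in the example above):

# "ij,jk->ik" contracts the shared j axis, i.e. a plain matrix product,
# so the MLX result should agree with Nx.dot/2 on the same operands.
fused = Emily.Fast.einsum("ij,jk->ik", [a, b])
expected = Nx.dot(a, b)
Nx.all_close(fused, expected)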

layer_norm(x, weight, bias, opts \\ [])

@spec layer_norm(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()

Fused LayerNorm: Welford-style mean+variance of the last axis, then affine (x - mean) / sqrt(var + eps) * weight + bias.

weight and bias must both have shape {axis_size(x, -1)}.

opts:

  • :eps — small constant added inside the sqrt. Default 1.0e-5.
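
For reference, on a backend without the fused kernel the result is equivalent to the plain-Nx computation below (an illustrative sketch of the formula above, not the actual fallback source):

defmodule LayerNormReference do
  import Nx.Defn

  # Unfused equivalent of Emily.Fast.layer_norm/4: mean and (population)
  # variance over the last axis, then the affine normalisation.
  defn layer_norm(x, weight, bias, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-5)
    mean = Nx.mean(x, axes: [-1], keep_axes: true)
    var = Nx.mean(Nx.pow(x - mean, 2), axes: [-1], keep_axes: true)
    (x - mean) / Nx.sqrt(var + opts[:eps]) * weight + bias
  end
end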

Examples

iex> x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], backend: Emily.Backend)
iex> w = Nx.tensor([1.0, 1.0, 1.0, 1.0], backend: Emily.Backend)
iex> b = Nx.tensor([0.0, 0.0, 0.0, 0.0], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...>   fn x, w, b -> Emily.Fast.layer_norm(x, w, b, eps: 1.0e-5) end,
...>   [x, w, b]
...> )
iex> Nx.shape(y)
{1, 4}

rms_norm(x, weight, opts \\ [])

@spec rms_norm(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Fused RMSNorm: x * rsqrt(mean(x², axis=-1) + eps) * weight.

Normalises the last axis of x. weight must have shape {axis_size(x, -1)} and broadcasts across the preceding dims.

opts:

  • :eps — small constant added inside the rsqrt. Default 1.0e-6.

Examples

iex> x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], backend: Emily.Backend)
iex> w = Nx.tensor([1.0, 1.0, 1.0, 1.0], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...>   fn x, w -> Emily.Fast.rms_norm(x, w, eps: 1.0e-5) end,
...>   [x, w]
...> )
iex> Nx.shape(y)
{1, 4}

rope(x, offset, opts)

@spec rope(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Fused RoPE with the standard geometric-progression theta schedule.

Rotates the trailing :dims features of the last axis of x (typically the whole head_dim) in position-indexed planes. offset is a scalar integer tensor (usually Nx.tensor(0) for prompt-processing, or the KV-cache length for incremental decode).

opts:

  • :dims — number of trailing features of the last axis to rotate. Required.
  • :traditional — if true, use the paired-interleave layout (MLX / Meta convention). If false, split-half layout (HuggingFace convention). Default false.
  • :base — theta base. Default 10_000.0.
  • :scale — position scale multiplier. Default 1.0.

For scaled variants (Llama-3, LongRoPE, linear, dynamic) use rope_with_freqs/4 with a precomputed inverse-frequency table.
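
To make the default schedule concrete, the standard geometric-progression table can be built in plain Nx as below (illustrative sketch; with dims: 8 and the default base it evaluates to [1.0, 0.1, 0.01, 0.001], the table used in the rope_with_freqs/4 example below):

dims = 8
base = 10_000.0

# Standard RoPE schedule: inv_freq[i] = base^(-2i / dims) for i = 0 .. dims/2 - 1.
inv_freq =
  Nx.iota({div(dims, 2)}, type: :f32)
  |> Nx.multiply(2.0 / dims)
  |> then(&Nx.pow(base, Nx.negate(&1)))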

Examples

iex> x = Nx.iota({1, 1, 4, 8}, backend: Emily.Backend, type: :f32)
iex> offset = Nx.tensor(0, backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...>   fn x, o -> Emily.Fast.rope(x, o, dims: 8) end,
...>   [x, offset]
...> )
iex> Nx.shape(y)
{1, 1, 4, 8}

rope_with_freqs(x, offset, freqs, opts)

@spec rope_with_freqs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()

RoPE with a precomputed inverse-frequency table.

Use this overload when the model applies a non-standard scaling strategy to the base frequencies (e.g. Llama-3, LongRoPE, linear, dynamic). freqs must be a 1-D :f32 tensor of length dims / 2.

opts:

  • :dims — number of trailing features of the last axis to rotate. Required.
  • :traditional — see rope/3. Default false.
  • :scale — position scale multiplier. Default 1.0.
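
As one worked example of a scaled table: linear (position-interpolation) scaling by a factor compresses positions by that factor, which is equivalent to dividing every inverse frequency by it. The sketch below assumes the default geometric schedule as the starting point; dims and factor are illustrative:

dims = 8
factor = 4.0

# Default geometric schedule (what rope/3 would use by itself).
inv_freq =
  Nx.iota({div(dims, 2)}, type: :f32)
  |> Nx.multiply(2.0 / dims)
  |> then(&Nx.pow(10_000.0, Nx.negate(&1)))

# Dividing the inverse frequencies by `factor` is the same as scaling every
# position index by 1 / factor.
scaled_inv_freq = Nx.divide(inv_freq, factor)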

Examples

iex> x = Nx.iota({1, 1, 4, 8}, backend: Emily.Backend, type: :f32)
iex> offset = Nx.tensor(0, backend: Emily.Backend)
iex> freqs = Nx.tensor([1.0, 0.1, 0.01, 0.001], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...>   fn x, o, f -> Emily.Fast.rope_with_freqs(x, o, f, dims: 8) end,
...>   [x, offset, freqs]
...> )
iex> Nx.shape(y)
{1, 1, 4, 8}

scaled_dot_product_attention(q, k, v, opts \\ [])

@spec scaled_dot_product_attention(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Fused scaled-dot-product attention without an additive-bias mask.

Expects {batch, heads, seq, head_dim} layout on Q, K, V.

opts:

  • :scale — multiplier on QKᵀ before softmax. Default 1 / sqrt(head_dim).
  • :causal — if true, apply MLX's built-in upper-triangular mask. Default false.
  • :sinks — optional per-head "null destination" tensor of shape {heads} (or anything broadcastable to {1, heads, 1, 1}). When present, the sink entries participate in the softmax denominator only, contributing zero to the numerator — the StreamingLLM trick for long-context decode (see the sketch below). When absent, the helper emits the same node as before, so existing callers are unaffected.
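
For intuition, the :sinks semantics are equivalent to the unfused reference below (an illustrative sketch of the math, not the fused kernel): the sink logits join the softmax normalisation but contribute nothing to the weighted sum over values.

defmodule SinksReference do
  import Nx.Defn

  # q, k, v: {batch, heads, seq, head_dim}; sinks: {heads}.
  # Pass scale: 1 / :math.sqrt(head_dim) to match the fused kernel's default.
  defn sdpa_with_sinks(q, k, v, sinks, opts \\ []) do
    opts = keyword!(opts, scale: 1.0)

    scores = Nx.dot(q, [3], [0, 1], k, [3], [0, 1]) * opts[:scale]
    sink = Nx.reshape(sinks, {1, :auto, 1, 1})

    # Stabilised softmax whose denominator also includes the sink logits;
    # the sinks never weight any value row, they only drain probability mass.
    max = Nx.max(Nx.reduce_max(scores, axes: [-1], keep_axes: true), sink)
    num = Nx.exp(scores - max)
    denom = Nx.sum(num, axes: [-1], keep_axes: true) + Nx.exp(sink - max)

    Nx.dot(num / denom, [3], [0, 1], v, [2], [0, 1])
  end
end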

Examples

iex> q = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> k = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> v = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> y = Nx.Defn.jit_apply(
...>   fn q, k, v -> Emily.Fast.scaled_dot_product_attention(q, k, v) end,
...>   [q, k, v]
...> )
iex> Nx.shape(y)
{1, 2, 4, 8}

scaled_dot_product_attention_with_mask(q, k, v, mask, opts \\ [])

@spec scaled_dot_product_attention_with_mask(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

SDPA with an additive mask tensor broadcasting across QKᵀ.

mask should match (or broadcast to) shape {batch_or_1, heads_or_1, q_len, k_len} and is added to QKᵀ after scaling. Use Nx.Constants.min_finite/1 for positions that should be masked out.
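
For instance, a causal additive mask for q_len = k_len = 4 can be built in plain Nx like this (illustrative sketch; transfer it to Emily.Backend, as in the example below, before passing it in):

q_len = 4
k_len = 4

# Allowed positions (key index <= query index) get 0.0; future positions get
# the most negative finite value so they vanish after softmax.
allowed =
  Nx.greater_equal(
    Nx.iota({q_len, k_len}, axis: 0),
    Nx.iota({q_len, k_len}, axis: 1)
  )

mask =
  allowed
  |> Nx.select(Nx.tensor(0.0, type: :f32), Nx.Constants.min_finite(:f32))
  |> Nx.reshape({1, 1, q_len, k_len})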

opts:

  • :sinks — same semantics as in scaled_dot_product_attention/4.

Examples

iex> q = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> k = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> v = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> mask = Nx.broadcast(Nx.tensor(0.0), {1, 1, 4, 4}) |> Nx.backend_transfer(Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...>   fn q, k, v, m -> Emily.Fast.scaled_dot_product_attention_with_mask(q, k, v, m) end,
...>   [q, k, v, mask]
...> )
iex> Nx.shape(y)
{1, 2, 4, 8}