Fused transformer kernels as defn-callable helpers.
Each defn-callable function here (every helper except einsum/2) emits an
Nx.block/4 node carrying an Emily.Fast.Block.* struct. Under Emily, the Defn
evaluator dispatches to Emily.Backend block/4, which calls the matching
mx::fast::* kernel directly. Under any other backend, the composed-defn
fallback runs and produces a mathematically equivalent result (up to
floating-point rounding). Bumblebee models rewritten to use these helpers
(via the test-only Emily.Bumblebee.FastKernels shim) therefore keep running
on Nx.BinaryBackend or EXLA for conformance testing, just without the
fusion speedup.
Hook mechanism
Each helper wraps Nx.block(struct, args, output, fun), where
struct is one of the Emily.Fast.Block.* structs that carries
the helper's static configuration (eps, dims, scale, …). At eval
time Emily.Backend block/4 pattern-matches on the struct and
dispatches to the matching mx::fast::* NIF; other backends fall
through to fun, which runs the composed-defn fallback. This is
the Nx 0.12 successor to Nx.Defn.Expr optional/3 and is the same
extension point EXLA uses for its native ops.
Tensor vs option arguments
Configuration lives on the struct; runtime tensors travel in the
Nx.block/4 args list. Each helper builds its struct from the
validated keyword list and threads the tensors through.
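As an illustration of the struct-building step, here is a hypothetical config struct (Sketch.LayerNormConfig is invented for this sketch and is not one of the Emily.Fast.Block.* modules). It validates the keyword list with Elixir's Keyword.validate!/2, fills in defaults, and keeps tensors out of the struct entirely.

```elixir
defmodule Sketch.LayerNormConfig do
  # Hypothetical config struct: static options only, no tensors.
  defstruct eps: 1.0e-5

  # Validate the caller's keyword list (unknown keys raise ArgumentError),
  # apply defaults, then build the struct.
  def new(opts) do
    opts = Keyword.validate!(opts, eps: 1.0e-5)

    unless is_float(opts[:eps]) and opts[:eps] > 0 do
      raise ArgumentError, ":eps must be a positive float, got: #{inspect(opts[:eps])}"
    end

    struct!(__MODULE__, opts)
  end
end

# The tensors would travel separately, e.g. as the args list [x, weight, bias].
Sketch.LayerNormConfig.new([])
# => %Sketch.LayerNormConfig{eps: 1.0e-5}
Sketch.LayerNormConfig.new(eps: 1.0e-6)
# => %Sketch.LayerNormConfig{eps: 1.0e-6}
```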
Covered kernels
- rms_norm/3 - mx::fast::rms_norm.
- layer_norm/4 - mx::fast::layer_norm.
- rope/3 - mx::fast::rope with the standard geometric-progression theta schedule.
- rope_with_freqs/4 - mx::fast::rope with a precomputed inverse-frequency table (for Llama-3, LongRoPE, linear, or dynamic scaling).
- scaled_dot_product_attention/4 - mx::fast::sdpa, with no mask or with a causal mask. The optional :sinks opt threads a per-head sinks tensor through the softmax denominator (StreamingLLM).
- scaled_dot_product_attention_with_mask/5 - the same with an additive bias tensor; also supports :sinks.
- einsum/2 - mx::einsum (variadic operands, path-optimised by MLX). Eager-only, not defn-callable: it takes refs directly off Emily-backed tensors and raises on any other backend. There is no defn fallback because writing a correct einsum-string parser (diagonals, ellipsis, contraction ordering) is a non-trivial piece of work we defer until a user asks for cross-backend compatibility.
Usage
Call these from inside a defn or Nx.Defn.jit-traced function,
alongside regular Nx ops:
defmodule MyModel do
  import Nx.Defn

  defn block(x, w, b) do
    x
    |> Emily.Fast.layer_norm(w, b, eps: 1.0e-5)
    |> Nx.multiply(0.5)
  end
end
Under Emily.Compiler the layer_norm node dispatches to
mx::fast::layer_norm; under Nx.Defn.Evaluator with any other
backend it runs the composed fallback.
Functions
@spec einsum(String.t(), [Nx.Tensor.t()]) :: Nx.Tensor.t()
Variadic-operand einsum computed by MLX's path-optimised
mx::einsum kernel.
subscripts is a standard Einstein-summation equation (e.g.
"ij,jk->ik", "bij,bjk->bik", "bhid,bhjd->bhij",
"ij,jk,kl->il"). operands is the corresponding list of tensors.
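The two-operand equation "ij,jk->ik" is ordinary matrix multiplication: the repeated index j is summed over and the output keeps i and k. The plain-Elixir reference below spells out that bookkeeping on nested lists. Sketch.Einsum.ij_jk_ik/2 is a hypothetical helper that handles only this one equation; it is not how MLX evaluates einsum. It uses the same operands as the iota example further down.

```elixir
defmodule Sketch.Einsum do
  # Reference meaning of "ij,jk->ik" on nested lists:
  # out[i][k] = sum over j of a[i][j] * b[j][k].
  # Handles only this single equation; it is not a general einsum.
  def ij_jk_ik(a, b) do
    # Transpose b so each element of b_cols is one column k.
    b_cols = Enum.zip_with(b, & &1)

    for row <- a do
      for col <- b_cols do
        Enum.zip_with(row, col, &(&1 * &2)) |> Enum.sum()
      end
    end
  end
end

a = [[0, 1, 2], [3, 4, 5]]                       # like Nx.iota({2, 3})
b = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] # like Nx.iota({3, 4})
Sketch.Einsum.ij_jk_ik(a, b)
# => [[20, 23, 26, 29], [56, 68, 80, 92]]
```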
Eager-only, not defn-callable
Unlike the other helpers in this module, einsum/2 does not
emit an Nx.Defn.Expr node. It takes refs directly off Emily-backed
tensors and calls the NIF eagerly, in the same "direct-call helper"
style as Emily.Quantization.quantized_matmul/2. Every operand must
live on Emily.Backend; anything else raises ArgumentError.
Writing a correct einsum-string parser (for diagonals, ellipsis, and
contraction ordering) is deferred until a user needs cross-backend
compatibility.
Examples
iex> a = Nx.iota({2, 3}, backend: Emily.Backend, type: :f32)
iex> b = Nx.iota({3, 4}, backend: Emily.Backend, type: :f32)
iex> y = Emily.Fast.einsum("ij,jk->ik", [a, b])
iex> Nx.shape(y)
{2, 4}
@spec layer_norm(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused LayerNorm: Welford-style mean+variance of the last axis, then
affine (x - mean) / sqrt(var + eps) * weight + bias.
weight and bias must both have shape {axis_size(x, -1)}.
opts:
- :eps - small constant added inside the sqrt. Default 1.0e-5.
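To make the formula concrete, a plain-Elixir reference for a single row (the last axis) is sketched below. Sketch.LayerNorm.normalize/4 is hypothetical and works on lists rather than tensors. It computes the mean and population variance in one Welford pass, which is one reading of "Welford-style" and not necessarily how MLX's kernel accumulates.

```elixir
defmodule Sketch.LayerNorm do
  # Reference LayerNorm over one row on plain lists:
  # (x - mean) / sqrt(var + eps) * weight + bias.
  def normalize(xs, weight, bias, eps \\ 1.0e-5) do
    # Single Welford pass: running mean and sum of squared deviations (m2).
    {n, mean, m2} =
      Enum.reduce(xs, {0, 0.0, 0.0}, fn x, {n, mean, m2} ->
        n = n + 1
        delta = x - mean
        mean = mean + delta / n
        {n, mean, m2 + delta * (x - mean)}
      end)

    # Population variance (divide by n, not n - 1).
    inv_std = 1.0 / :math.sqrt(m2 / n + eps)

    Enum.zip_with([xs, weight, bias], fn [x, w, b] -> (x - mean) * inv_std * w + b end)
  end
end

Sketch.LayerNorm.normalize([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0])
# mean 2.5, variance 1.25: roughly [-1.342, -0.447, 0.447, 1.342]
```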
Examples
iex> x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], backend: Emily.Backend)
iex> w = Nx.tensor([1.0, 1.0, 1.0, 1.0], backend: Emily.Backend)
iex> b = Nx.tensor([0.0, 0.0, 0.0, 0.0], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, w, b -> Emily.Fast.layer_norm(x, w, b, eps: 1.0e-5) end,
...> [x, w, b]
...> )
iex> Nx.shape(y)
{1, 4}
@spec rms_norm(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused RMSNorm: x * rsqrt(mean(x², axis=-1) + eps) * weight.
Normalises the last axis of x. weight must have shape
{axis_size(x, -1)} and broadcasts across the preceding dims.
opts:
- :eps - small constant added inside the rsqrt. Default 1.0e-6.
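A one-row plain-Elixir reference of the same formula follows. Sketch.RMSNorm.normalize/3 is hypothetical and list-based; unlike LayerNorm it subtracts no mean and adds no bias.

```elixir
defmodule Sketch.RMSNorm do
  # Reference RMSNorm over one row: x * rsqrt(mean(x^2) + eps) * weight.
  def normalize(xs, weight, eps \\ 1.0e-6) do
    mean_sq = Enum.sum(Enum.map(xs, &(&1 * &1))) / length(xs)
    inv_rms = 1.0 / :math.sqrt(mean_sq + eps)
    Enum.zip_with(xs, weight, fn x, w -> x * inv_rms * w end)
  end
end

Sketch.RMSNorm.normalize([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0], 1.0e-5)
# mean(x^2) = 7.5, so each entry is divided by roughly sqrt(7.5)
```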
Examples
iex> x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], backend: Emily.Backend)
iex> w = Nx.tensor([1.0, 1.0, 1.0, 1.0], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, w -> Emily.Fast.rms_norm(x, w, eps: 1.0e-5) end,
...> [x, w]
...> )
iex> Nx.shape(y)
{1, 4}
@spec rope(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused RoPE with the standard geometric-progression theta schedule.
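Rotates dims features of the last axis of x (typically all of head_dim) in
position-dependent 2-D planes; features beyond dims are left unchanged.
offset is a scalar integer tensor giving the position of the first sequence
element (usually Nx.tensor(0) for prompt processing, or the KV-cache
length for incremental decode).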
opts:
- :dims - number of features along the last axis to rotate. Required.
- :traditional - if true, use the paired-interleave layout (MLX / Meta convention); if false, use the split-half layout (HuggingFace convention). Default false.
- :base - theta base. Default 10_000.0.
- :scale - position scale multiplier. Default 1.0.
For scaled variants (Llama-3, LongRoPE, linear, dynamic) use
rope_with_freqs/4 with a precomputed inverse-frequency table.
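The standard schedule can be written out for a single position. In this sketch Sketch.RoPE.rotate/3 is hypothetical, works on one feature vector whose length equals dims (so nothing is left unrotated), and uses inverse frequencies theta_i = base^(-2i/dims) with angle position * scale * theta_i. It shows the two layouts the :traditional flag selects; it is a reference for the math, not Emily's or MLX's code.

```elixir
defmodule Sketch.RoPE do
  # Reference RoPE for one position's feature vector (length == dims, even).
  # traditional: true  -> rotate interleaved pairs (x[2i], x[2i+1])
  # traditional: false -> rotate split halves (x[i], x[i + dims/2])
  def rotate(xs, position, opts \\ []) do
    base = Keyword.get(opts, :base, 10_000.0)
    scale = Keyword.get(opts, :scale, 1.0)
    traditional = Keyword.get(opts, :traditional, false)
    dims = length(xs)
    half = div(dims, 2)

    # One rotation angle per 2-D plane.
    angles = for i <- 0..(half - 1), do: position * scale * :math.pow(base, -2 * i / dims)

    # Pick out the two coordinates of each plane according to the layout.
    {firsts, seconds} =
      if traditional do
        pairs = Enum.chunk_every(xs, 2)
        {Enum.map(pairs, &hd/1), Enum.map(pairs, &List.last/1)}
      else
        Enum.split(xs, half)
      end

    rotated =
      Enum.zip_with([firsts, seconds, angles], fn [a, b, t] ->
        {a * :math.cos(t) - b * :math.sin(t), a * :math.sin(t) + b * :math.cos(t)}
      end)

    # Write the rotated coordinates back in the same layout.
    if traditional do
      Enum.flat_map(rotated, fn {a, b} -> [a, b] end)
    else
      Enum.map(rotated, &elem(&1, 0)) ++ Enum.map(rotated, &elem(&1, 1))
    end
  end
end

Sketch.RoPE.rotate([1.0, 0.0], 1)
# => [cos(1.0), sin(1.0)], a 1-radian rotation of the single plane
```

Position 0 is the identity, and every rotation preserves the vector's norm; the two layouts differ only in which coordinates are paired.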
Examples
iex> x = Nx.iota({1, 1, 4, 8}, backend: Emily.Backend, type: :f32)
iex> offset = Nx.tensor(0, backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, o -> Emily.Fast.rope(x, o, dims: 8) end,
...> [x, offset]
...> )
iex> Nx.shape(y)
{1, 1, 4, 8}
@spec rope_with_freqs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
RoPE with a precomputed inverse-frequency table.
Use this overload when the model applies a non-standard scaling
strategy to the base frequencies (e.g. Llama-3, LongRoPE, linear,
dynamic). freqs must be a 1-D :f32 tensor of length dims / 2.
opts:
- :dims - number of features along the last axis to rotate. Required.
- :traditional - see rope/3. Default false.
- :scale - position scale multiplier. Default 1.0.
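One way to build such a table, assuming (as stated above) that freqs holds inverse frequencies: start from the standard schedule base^(-2i/dims) for i < dims/2, then optionally divide by a linear scaling factor. Sketch.RoPEFreqs is hypothetical; Llama-3 and LongRoPE use more elaborate per-frequency rules that are not shown here. For dims 8 and base 10_000.0 the standard table is approximately [1.0, 0.1, 0.01, 0.001], the same values the example below passes in.

```elixir
defmodule Sketch.RoPEFreqs do
  # Standard inverse-frequency table: base^(-2i / dims) for i in 0..dims/2 - 1.
  def standard(dims, base \\ 10_000.0) do
    for i <- 0..(div(dims, 2) - 1), do: :math.pow(base, -2 * i / dims)
  end

  # Linear (position-interpolation) scaling: dividing every inverse frequency
  # by `factor` is equivalent to dividing every position by `factor`.
  def linear(dims, factor, base \\ 10_000.0) do
    Enum.map(standard(dims, base), &(&1 / factor))
  end
end

Sketch.RoPEFreqs.standard(8)
# approximately [1.0, 0.1, 0.01, 0.001]
```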
Examples
iex> x = Nx.iota({1, 1, 4, 8}, backend: Emily.Backend, type: :f32)
iex> offset = Nx.tensor(0, backend: Emily.Backend)
iex> freqs = Nx.tensor([1.0, 0.1, 0.01, 0.001], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, o, f -> Emily.Fast.rope_with_freqs(x, o, f, dims: 8) end,
...> [x, offset, freqs]
...> )
iex> Nx.shape(y)
{1, 1, 4, 8}
@spec scaled_dot_product_attention(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused scaled-dot-product attention without an additive-bias mask.
Expects {batch, heads, seq, head_dim} layout on Q, K, V.
opts:
- :scale - multiplier applied to QKᵀ before the softmax. Default 1 / sqrt(head_dim).
- :causal - if true, apply MLX's built-in causal mask, which blocks the positions above the diagonal (future keys). Default false.
- :sinks - optional per-head "null destination" tensor of shape {heads} (or broadcastable to {1, heads, 1, 1}). When present, the sink entries participate in the softmax denominator only and contribute nothing to the numerator: the StreamingLLM trick for long-context decode. When absent, the helper emits exactly the same node as the sinks-free form, so existing call sites are unaffected.
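For a single query vector, the computation (including the sink term) can be sketched in plain Elixir. Sketch.Attention.attend/4 is hypothetical: it ignores the batch and head axes, does not implement :causal, and assumes the sink is a raw logit compared directly against the scaled scores, which is an assumption about how sinks are scaled rather than a documented fact.

```elixir
defmodule Sketch.Attention do
  # Reference attention for ONE query over lists of keys and values:
  # scores_j = scale * dot(q, k_j); p = softmax(scores); out = sum_j p_j * v_j.
  # An optional sink logit joins the softmax denominator but has no value row,
  # so it only shrinks the weights of the real keys.
  def attend(q, keys, values, opts \\ []) do
    scale = Keyword.get(opts, :scale, 1.0 / :math.sqrt(length(q)))
    sink = Keyword.get(opts, :sink)

    scores = Enum.map(keys, fn k -> scale * dot(q, k) end)

    # Subtract the max (sink included) for numerical stability.
    candidates = if sink, do: [sink | scores], else: scores
    m = Enum.max(candidates)
    exps = Enum.map(scores, &:math.exp(&1 - m))
    denom = Enum.sum(exps) + if(sink, do: :math.exp(sink - m), else: 0.0)

    # Weight each value row, then sum the rows column by column.
    values
    |> Enum.zip_with(exps, fn v, e -> Enum.map(v, &(&1 * e / denom)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()
end

Sketch.Attention.attend([1.0], [[0.0], [0.0]], [[1.0], [1.0]])
# => [1.0]  (weights 1/2 each)
Sketch.Attention.attend([1.0], [[0.0], [0.0]], [[1.0], [1.0]], sink: 0.0)
# approximately [0.667]  (weights 1/3 each; the sink absorbs the remaining 1/3)
```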
Examples
iex> q = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> k = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> v = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> y = Nx.Defn.jit_apply(
...> fn q, k, v -> Emily.Fast.scaled_dot_product_attention(q, k, v) end,
...> [q, k, v]
...> )
iex> Nx.shape(y)
{1, 2, 4, 8}
@spec scaled_dot_product_attention_with_mask(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
SDPA with an additive mask tensor broadcasting across QKᵀ.
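mask should match (or broadcast to) shape
{batch_or_1, heads_or_1, q_len, k_len}; it is added to the scaled QKᵀ
before the softmax. Leave allowed positions at 0 and fill positions to mask
out with Nx.Constants.min_finite/1 of the attention dtype (for example
Nx.Constants.min_finite({:f, 32})).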
opts:
- :scale - see scaled_dot_product_attention/4. Default 1 / sqrt(head_dim).
- :sinks - see scaled_dot_product_attention/4. Optional.
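A causal mask in this additive form can be sketched on nested lists. Sketch.Mask.causal/2 is hypothetical; it hard-codes the most negative finite f32 (the value Nx.Constants.min_finite({:f, 32}) stands for) and assumes rows are aligned to the end of the key sequence, the usual convention when a KV cache makes k_len larger than q_len. With Nx you would build the same pattern as a tensor and broadcast it to {1, 1, q_len, k_len}.

```elixir
defmodule Sketch.Mask do
  # Most negative finite f32; adding it to a score drives its softmax weight to ~0.
  @f32_min -3.4028234663852886e38

  # Additive causal mask of shape q_len x k_len. Rows are aligned to the END
  # of the key sequence, so query i may attend to keys 0..(i + k_len - q_len).
  def causal(q_len, k_len) do
    shift = k_len - q_len

    for i <- 0..(q_len - 1) do
      for j <- 0..(k_len - 1), do: if(j <= i + shift, do: 0.0, else: @f32_min)
    end
  end
end

Sketch.Mask.causal(3, 3)
# row 0 may see key 0 only; row 2 sees all three keys
Sketch.Mask.causal(1, 4)
# => [[0.0, 0.0, 0.0, 0.0]]  (a single decode step sees the whole cache)
```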
Examples
iex> q = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> k = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> v = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> mask = Nx.broadcast(Nx.tensor(0.0), {1, 1, 4, 4}) |> Nx.backend_transfer(Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn q, k, v, m -> Emily.Fast.scaled_dot_product_attention_with_mask(q, k, v, m) end,
...> [q, k, v, mask]
...> )
iex> Nx.shape(y)
{1, 2, 4, 8}