Fused transformer kernels as defn-callable helpers.
Each function here emits an optional-expression node whose op name
matches a custom callback on Emily.Backend. Under Emily the
Defn evaluator dispatches directly to the MLX mx::fast::*
kernel; under any other backend the defn-composed fallback runs
and produces a mathematically equivalent result. That means
Bumblebee models rewritten to use these helpers (via the test-only
Emily.Bumblebee.FastKernels shim) keep running on
Nx.BinaryBackend / EXLA for conformance — just without the
fusion speedup.
Hook mechanism
Each helper wraps Nx.Defn.Expr.optional(name, in_args, fallback),
which creates an :optional Expr node. At eval time the Nx
evaluator looks for function_exported?(backend, name, length(in_args) + 1) on the active backend and, if present, calls
backend.name(out, args...) directly. Otherwise the fallback defn
runs. This is Nx's extension point for vendor-fused kernels (the
same pattern EXLA uses for its native ops).
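The dispatch above can be illustrated with a small Python analogue (names like FastBackend are hypothetical stand-ins, not the Nx API; the real mechanism checks arity with function_exported?):

```python
# Illustrative sketch of optional-node dispatch: if the active backend
# exports the named callback, call it; otherwise run the composed fallback.

class ReferenceBackend:
    pass  # exports no fused kernels, so the fallback always runs

class FastBackend:
    def rms_norm(self, x, weight, eps):
        # stand-in for a fused native kernel; same math as the fallback
        return [xi * w / (sum(v * v for v in x) / len(x) + eps) ** 0.5
                for xi, w in zip(x, weight)]

def rms_norm(backend, x, weight, eps=1e-6):
    if hasattr(backend, "rms_norm"):          # ~ function_exported?/3
        return backend.rms_norm(x, weight, eps)
    ms = sum(v * v for v in x) / len(x)       # composed fallback path
    return [xi * w / (ms + eps) ** 0.5 for xi, w in zip(x, weight)]
```

Either path must produce the same result; only the execution route differs.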
Tensor vs option arguments
The optional-expression contract splits in_args at the first
list: every leading non-list argument is treated as a tensor param;
the list (typically a keyword list) is passed through as opts. Every
Emily.Fast.* function's final argument is a keyword list of
scalars (dims, epsilons, flags), and all tensor inputs come before
it.
Covered kernels
rms_norm/3 — mx::fast::rms_norm.
layer_norm/4 — mx::fast::layer_norm.
rope/3 — mx::fast::rope with the standard geometric-progression theta schedule.
rope_with_freqs/4 — mx::fast::rope with a precomputed inverse-frequency table (for Llama-3 / LongRoPE / linear / dynamic scaling).
scaled_dot_product_attention/4 — mx::fast::sdpa, without mask or with causal mask. An optional :sinks opt threads a per-head sinks tensor through the softmax denominator (StreamingLLM).
scaled_dot_product_attention_with_mask/5 — the same with an additive bias tensor; also supports :sinks.
einsum/2 — mx::einsum (variadic operands, path-optimised by MLX). Eager-only, not defn-callable: it takes refs directly off Emily-backed tensors and raises on any other backend. There is no defn fallback because writing a correct einsum-string parser (diagonals, ellipsis, contraction ordering) is non-trivial work we defer until a user asks for cross-backend compatibility.
Usage
Call these from inside a defn or Nx.Defn.jit-traced function,
alongside regular Nx ops:
defn block(x, w, b) do
  x
  |> Emily.Fast.layer_norm(w, b, eps: 1.0e-5)
  |> Nx.multiply(0.5)
end

Under Emily.Compiler the layer_norm node dispatches to
mx::fast::layer_norm; under Nx.Defn.Evaluator + any other
backend it runs the composed fallback.
Summary
Functions
Variadic-operand einsum computed by MLX's path-optimised
mx::einsum kernel.
Fused LayerNorm: Welford-style mean+variance of the last axis, then
affine (x - mean) / sqrt(var + eps) * weight + bias.
Fused RMSNorm: x * rsqrt(mean(x², axis=-1) + eps) * weight.
Fused RoPE with the standard geometric-progression theta schedule.
RoPE with a precomputed inverse-frequency table.
Fused scaled-dot-product attention without an additive-bias mask.
SDPA with an additive mask tensor broadcasting across QKᵀ.
Functions
@spec einsum(String.t(), [Nx.Tensor.t()]) :: Nx.Tensor.t()
Variadic-operand einsum computed by MLX's path-optimised
mx::einsum kernel.
subscripts is a standard Einstein-summation equation (e.g.
"ij,jk->ik", "bij,bjk->bik", "bhid,bhjd->bhij",
"ij,jk,kl->il"). operands is the corresponding list of tensors.
Eager-only, not defn-callable
Unlike the other helpers in this module, einsum/2 does not
emit an Nx.Defn.Expr optional/3 node. It takes refs directly off
Emily-backed tensors and calls the NIF eagerly, in the same
"direct-call helper" style as Emily.Quantization.quantized_matmul/2.
Every operand must live on Emily.Backend; anything else raises
ArgumentError. Writing a correct einsum-string parser (for
diagonals, ellipsis, and contraction ordering) is deferred until a
user needs cross-backend compatibility.
Examples
iex> a = Nx.iota({2, 3}, backend: Emily.Backend, type: :f32)
iex> b = Nx.iota({3, 4}, backend: Emily.Backend, type: :f32)
iex> y = Emily.Fast.einsum("ij,jk->ik", [a, b])
iex> Nx.shape(y)
{2, 4}
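As a reference for what the example above contracts, "ij,jk->ik" written out as explicit loops in plain Python (semantics only; MLX path-optimises the actual execution):

```python
def einsum_ij_jk_ik(a, b):
    # "ij,jk->ik": sum over the shared index j, i.e. a matrix multiply.
    rows, inner, cols = len(a), len(a[0]), len(b[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for k in range(cols):
            for j in range(inner):
                out[i][k] += a[i][j] * b[j][k]
    return out
```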
@spec layer_norm(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused LayerNorm: Welford-style mean+variance of the last axis, then
affine (x - mean) / sqrt(var + eps) * weight + bias.
weight and bias must both have shape {axis_size(x, -1)}.
opts:
:eps — small constant added inside the sqrt. Default 1.0e-5.
Examples
iex> x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], backend: Emily.Backend)
iex> w = Nx.tensor([1.0, 1.0, 1.0, 1.0], backend: Emily.Backend)
iex> b = Nx.tensor([0.0, 0.0, 0.0, 0.0], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, w, b -> Emily.Fast.layer_norm(x, w, b, eps: 1.0e-5) end,
...> [x, w, b]
...> )
iex> Nx.shape(y)
{1, 4}
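The formula above, sketched for one row in plain Python (reference math only, not the fused kernel):

```python
import math

def layer_norm_row(x, weight, bias, eps=1e-5):
    # mean and variance over the last axis, then the affine transform
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv * w + b for v, w, b in zip(x, weight, bias)]
```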
@spec rms_norm(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused RMSNorm: x * rsqrt(mean(x², axis=-1) + eps) * weight.
Normalises the last axis of x. weight must have shape
{axis_size(x, -1)} and broadcasts across the preceding dims.
opts:
:eps — small constant added inside the rsqrt. Default 1.0e-6.
Examples
iex> x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], backend: Emily.Backend)
iex> w = Nx.tensor([1.0, 1.0, 1.0, 1.0], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, w -> Emily.Fast.rms_norm(x, w, eps: 1.0e-5) end,
...> [x, w]
...> )
iex> Nx.shape(y)
{1, 4}
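The formula above, sketched for one row in plain Python (reference math only, not the fused kernel):

```python
import math

def rms_norm_row(x, weight, eps=1e-6):
    # x * rsqrt(mean(x^2) + eps) * weight over the last axis
    ms = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(ms + eps)
    return [v * inv * w for v, w in zip(x, weight)]
```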
@spec rope(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused RoPE with the standard geometric-progression theta schedule.
Rotates dims feature dimensions of the last axis of x (typically
head_dim) in position-indexed planes. offset is a scalar integer
tensor (usually Nx.tensor(0) for prompt processing, or the KV-cache
length for incremental decode).
opts:
:dims — number of feature dimensions (elements of the last axis) to rotate. Required.
:traditional — if true, use the paired-interleave layout (MLX / Meta convention). If false, the split-half layout (HuggingFace convention). Default false.
:base — theta base. Default 10_000.0.
:scale — position scale multiplier. Default 1.0.
For scaled variants (Llama-3, LongRoPE, linear, dynamic) use
rope_with_freqs/4 with a precomputed inverse-frequency table.
Examples
iex> x = Nx.iota({1, 1, 4, 8}, backend: Emily.Backend, type: :f32)
iex> offset = Nx.tensor(0, backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, o -> Emily.Fast.rope(x, o, dims: 8) end,
...> [x, offset]
...> )
iex> Nx.shape(y)
{1, 1, 4, 8}
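For the split-half layout (:traditional false), the per-position rotation can be sketched for a single feature vector in plain Python (an illustration of the theta schedule under assumed pos/scale handling, not the MLX kernel):

```python
import math

def rope_split_half(x, offset, dims, base=10000.0, scale=1.0):
    # Geometric-progression theta schedule: inv_freq_i = base^(-2i/dims).
    # Split-half layout: element i pairs with element i + dims // 2.
    half = dims // 2
    inv_freq = [base ** (-(2.0 * i) / dims) for i in range(half)]
    pos = offset * scale
    out = list(x)
    for i in range(half):
        angle = pos * inv_freq[i]
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out
```

At offset 0 every angle is zero and the input passes through unchanged; each rotation is a plane rotation, so pairwise norms are preserved.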
@spec rope_with_freqs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
RoPE with a precomputed inverse-frequency table.
Use this overload when the model applies a non-standard scaling
strategy to the base frequencies (e.g. Llama-3, LongRoPE, linear,
dynamic). freqs must be a 1-D :f32 tensor of length dims / 2.
opts:
:dims — number of feature dimensions (elements of the last axis) to rotate. Required.
:traditional — see rope/3. Default false.
:scale — position scale multiplier. Default 1.0.
Examples
iex> x = Nx.iota({1, 1, 4, 8}, backend: Emily.Backend, type: :f32)
iex> offset = Nx.tensor(0, backend: Emily.Backend)
iex> freqs = Nx.tensor([1.0, 0.1, 0.01, 0.001], backend: Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn x, o, f -> Emily.Fast.rope_with_freqs(x, o, f, dims: 8) end,
...> [x, offset, freqs]
...> )
iex> Nx.shape(y)
{1, 1, 4, 8}
@spec scaled_dot_product_attention(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Fused scaled-dot-product attention without an additive-bias mask.
Expects {batch, heads, seq, head_dim} layout on Q, K, V.
opts:
:scale — multiplier on QKᵀ before softmax. Default 1 / sqrt(head_dim).
:causal — if true, apply MLX's built-in upper-triangular mask. Default false.
:sinks — optional per-head "null destination" tensor. Shape {heads} (or broadcastable to {1, heads, 1, 1}). When present the sink entries participate in the softmax denominator only, contributing zero to the numerator — the StreamingLLM trick for long-context decode. When absent the helper emits the same node as before (bitwise source-compatible).
Examples
iex> q = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> k = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> v = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> y = Nx.Defn.jit_apply(
...> fn q, k, v -> Emily.Fast.scaled_dot_product_attention(q, k, v) end,
...> [q, k, v]
...> )
iex> Nx.shape(y)
{1, 2, 4, 8}
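The :sinks behaviour described above, sketched for a single row of scaled logits in plain Python (a toy illustration of the softmax accounting, not the fused kernel):

```python
import math

def softmax_with_sink(logits, sink_logit=None):
    # A sink participates in the denominator only: it absorbs probability
    # mass but contributes nothing to the attention output.
    exps = [math.exp(v) for v in logits]
    denom = sum(exps)
    if sink_logit is not None:
        denom += math.exp(sink_logit)
    return [e / denom for e in exps]
```

With a sink present the returned weights sum to less than one; the missing mass is what the sink absorbed.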
@spec scaled_dot_product_attention_with_mask(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
SDPA with an additive mask tensor broadcasting across QKᵀ.
mask should match (or broadcast to) shape
{batch_or_1, heads_or_1, q_len, k_len} and is added to QKᵀ after
scaling. Use Nx.Constants.min_finite/1 on positions to mask out.
opts:
:scale — see scaled_dot_product_attention/4. Default 1 / sqrt(head_dim).
:sinks — see scaled_dot_product_attention/4. Optional.
Examples
iex> q = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> k = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> v = Nx.iota({1, 2, 4, 8}, backend: Emily.Backend, type: :f32)
iex> mask = Nx.broadcast(Nx.tensor(0.0), {1, 1, 4, 4}) |> Nx.backend_transfer(Emily.Backend)
iex> y = Nx.Defn.jit_apply(
...> fn q, k, v, m -> Emily.Fast.scaled_dot_product_attention_with_mask(q, k, v, m) end,
...> [q, k, v, mask]
...> )
iex> Nx.shape(y)
{1, 2, 4, 8}