Mix.install(
[
{:emily, "~> 0.3"},
{:bumblebee, "~> 0.6"},
{:tokenizers, "~> 0.5"},
{:axon, "~> 0.7"},
{:nx, "~> 0.10"},
{:kino, "~> 0.14"},
{:req, "~> 0.5"}
],
config: [
nx: [default_backend: Emily.Backend]
]
)
Overview
Emily.Fast wraps four MLX fused kernels: rms_norm, layer_norm,
rope, and scaled_dot_product_attention. Every helper also ships a
pure-defn fallback so the calls work on any backend — the fused
path is only taken under Emily.Compiler.
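As a quick check of the fallback claim, here is a minimal sketch (shapes and values are illustrative): calling rms_norm on explicit Nx.BinaryBackend tensors, where only the pure-defn path can run:
x = Nx.tensor([[1.0, 2.0, 3.0]], backend: Nx.BinaryBackend)
w = Nx.tensor([1.0, 1.0, 1.0], backend: Nx.BinaryBackend)
Emily.Fast.rms_norm(x, w, eps: 1.0e-6)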
This notebook has two parts:
- Low-level: call Emily.Fast.rms_norm and Emily.Fast.scaled_dot_product_attention directly, and compare against the same math written by hand with Nx primitives. Confirms both paths produce the same tensors; then times both.
- High-level: apply Emily.Bumblebee.FastKernels.apply/1 to a whisper-tiny Bumblebee model and transcribe the JFK sample from the Whisper notebook end-to-end, with and without the shim. Shows what the fused kernels are worth on a real workload.
Part 1: Low-level kernels
Each kernel's composed version is the textbook formula written with
Nx primitives. Both versions are wrapped with Nx.Defn.jit/2 +
Emily.Compiler so the graph is traced, JIT-compiled, and dispatched
to MLX — the only difference is how many kernels the graph decomposes
into on the way down.
# RMSNorm: normalise to unit RMS over the last axis, scale by `weight`.
composed_rms_norm = fn x, weight ->
rms =
x
|> Nx.pow(2)
|> Nx.mean(axes: [-1], keep_axes: true)
|> Nx.add(1.0e-6)
|> Nx.sqrt()
x |> Nx.divide(rms) |> Nx.multiply(weight)
end
fused_rms_norm = fn x, weight ->
Emily.Fast.rms_norm(x, weight, eps: 1.0e-6)
end
# SDPA: softmax(QKᵀ / sqrt(d)) · V. The max-subtract in the composed
# version stabilises the softmax; without it, large logits overflow.
composed_sdpa = fn q, k, v ->
scale = 1.0 / :math.sqrt(Nx.axis_size(q, -1))
logits = Nx.dot(q, [-1], [0, 1], k, [-1], [0, 1]) |> Nx.multiply(scale)
stable = Nx.subtract(logits, Nx.reduce_max(logits, axes: [-1], keep_axes: true))
exp = Nx.exp(stable)
probs = Nx.divide(exp, Nx.sum(exp, axes: [-1], keep_axes: true))
Nx.dot(probs, [-1], [0, 1], v, [-2], [0, 1])
end
fused_sdpa = fn q, k, v ->
Emily.Fast.scaled_dot_product_attention(q, k, v)
end
jit = &Nx.Defn.jit(&1, compiler: Emily.Compiler)
composed_rms_norm_c = jit.(composed_rms_norm)
fused_rms_norm_c = jit.(fused_rms_norm)
composed_sdpa_c = jit.(composed_sdpa)
fused_sdpa_c = jit.(fused_sdpa)
Correctness
# RMSNorm: batch=4, seq=128, hidden=1024 — representative for a
# small transformer block.
rms_x = Nx.iota({4, 128, 1024}, type: :f32) |> Nx.divide(100.0)
rms_weight = Nx.broadcast(1.0, {1024}) |> Nx.as_type(:f32)
rms_composed = composed_rms_norm_c.(rms_x, rms_weight)
rms_fused = fused_rms_norm_c.(rms_x, rms_weight)
rms_max_diff =
rms_composed
|> Nx.subtract(rms_fused)
|> Nx.abs()
|> Nx.reduce_max()
|> Nx.to_number()
IO.puts("RMSNorm max |composed − fused|: #{rms_max_diff}")# SDPA: batch=2, heads=8, seq=128, head_dim=64 — a small attention
# block at realistic proportions.
sdpa_shape = {2, 8, 128, 64}
q = Nx.iota(sdpa_shape, type: :f32) |> Nx.divide(10_000.0)
k = Nx.iota(sdpa_shape, type: :f32) |> Nx.divide(10_000.0)
v = Nx.iota(sdpa_shape, type: :f32) |> Nx.divide(10_000.0)
sdpa_composed = composed_sdpa_c.(q, k, v)
sdpa_fused = fused_sdpa_c.(q, k, v)
sdpa_max_diff =
sdpa_composed
|> Nx.subtract(sdpa_fused)
|> Nx.abs()
|> Nx.reduce_max()
|> Nx.to_number()
IO.puts("SDPA max |composed − fused|: #{sdpa_max_diff}")Both diffs should be ≲ 1e-5 — round-off, not a bug. The fused kernels compute the same math with different reduction orders.
Timing
MLX is lazy: ops queue onto a Metal command buffer and only run
when something forces evaluation. The helper below calls Emily.eval/1
to flush the graph before each tc boundary, so what we measure is
actually GPU work rather than queueing overhead.
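Before the helper, a tiny illustration of that laziness (hedged: the t.data.ref handle simply mirrors what the bench helper below uses):
t = Nx.multiply(Nx.iota({2048, 2048}, type: :f32), 2.0)
# The line above returns immediately; the multiply is only queued.
Emily.eval(t.data.ref)
# Only now has the Metal work actually run.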
bench = fn label, fun, args, iters ->
# Warmup: the first call compiles the defn closure.
warmup = apply(fun, args)
Emily.eval(warmup.data.ref)
{time_us, _} =
:timer.tc(fn ->
Enum.each(1..iters, fn _ ->
out = apply(fun, args)
Emily.eval(out.data.ref)
end)
end)
per_call_ms = time_us / iters / 1000
IO.puts("#{label}: #{Float.round(per_call_ms, 3)} ms/call")
per_call_ms
end
iters = 50
IO.puts("\nRMSNorm (batch=4, seq=128, hidden=1024):")
r_composed = bench.(" composed", composed_rms_norm_c, [rms_x, rms_weight], iters)
r_fused = bench.(" fused ", fused_rms_norm_c, [rms_x, rms_weight], iters)
IO.puts(" speedup: #{Float.round(r_composed / r_fused, 2)}×")
IO.puts("\nSDPA (batch=2, heads=8, seq=128, head_dim=64):")
s_composed = bench.(" composed", composed_sdpa_c, [q, k, v], iters)
s_fused = bench.(" fused ", fused_sdpa_c, [q, k, v], iters)
IO.puts(" speedup: #{Float.round(s_composed / s_fused, 2)}×")The absolute numbers depend heavily on machine, batch size, and whether MLX is running AOT or JIT; what matters is that the fused path is consistently at least as fast as the composed one, with the gap widening on larger tensors (where the savings from avoiding intermediate allocations dominate kernel-launch overhead).
Part 2: FastKernels shim on whisper-tiny
Emily.Bumblebee.FastKernels.apply/1 walks a Bumblebee Axon graph
and rewrites its RMSNorm / LayerNorm / RoPE / SDPA nodes to call the
fused kernels from Part 1. The rest of the model is untouched, so
the same params work on both graphs.
defmodule WAV do
@moduledoc false
@doc "Decode 16-bit PCM mono 16 kHz WAV to a 1-D f32 Nx tensor in [-1, 1]."
def decode!(<<"RIFF", _size::little-32, "WAVE", rest::binary>>) do
{%{bits: 16, channels: 1, rate: 16_000}, data} = scan_chunks(rest, nil, nil)
data |> Nx.from_binary(:s16) |> Nx.divide(32_768.0) |> Nx.as_type(:f32)
end
defp scan_chunks(<<"fmt ", size::little-32, fmt::binary-size(size), rest::binary>>, _, data) do
<<_::little-16, channels::little-16, rate::little-32, _::little-32,
_::little-16, bits::little-16, _rest::binary>> = fmt
scan_chunks(rest, %{bits: bits, channels: channels, rate: rate}, data)
end
defp scan_chunks(<<"data", size::little-32, data::binary-size(size), _::binary>>, fmt, _),
do: {fmt, data}
defp scan_chunks(
<<_id::binary-4, size::little-32, _::binary-size(size), rest::binary>>,
fmt,
data
),
do: scan_chunks(rest, fmt, data)
end
repo = {:hf, "openai/whisper-tiny"}
{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
jfk_audio =
Req.get!(
"https://raw.githubusercontent.com/ggml-org/whisper.cpp/master/samples/jfk.wav"
).body
|> WAV.decode!()
Nx.shape(jfk_audio)
build_serving = fn model_info ->
Bumblebee.Audio.speech_to_text_whisper(
model_info,
featurizer,
tokenizer,
generation_config,
defn_options: [compiler: Emily.Compiler]
)
end
baseline_serving = build_serving.(model_info)
fast_model_info = update_in(model_info.model, &Emily.Bumblebee.FastKernels.apply/1)
fast_serving = build_serving.(fast_model_info)
Correctness on a real transcription
Both servings should produce the same words on the same audio; tiny float differences can at most flip the odd token under greedy decoding. A divergence beyond that means the shim has rewritten a layer incorrectly.
baseline_text =
baseline_serving
|> Nx.Serving.run(jfk_audio)
|> Map.fetch!(:chunks)
|> Enum.map_join(& &1.text)
|> String.trim()
fast_text =
fast_serving
|> Nx.Serving.run(jfk_audio)
|> Map.fetch!(:chunks)
|> Enum.map_join(& &1.text)
|> String.trim()
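# Hedged hard check (the threshold is an illustrative choice): the two
# greedy decodes should agree nearly verbatim; small float drift can
# still flip the odd token.
true = String.jaro_distance(baseline_text, fast_text) > 0.9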
{baseline_text, fast_text}
Timing
time_ms = fn serving, audio ->
{us, _} = :timer.tc(fn -> Nx.Serving.run(serving, audio) end)
us / 1000
end
# Warmups are already done (correctness cell above ran each serving
# once, which triggers compile).
baseline_ms = time_ms.(baseline_serving, jfk_audio)
fast_ms = time_ms.(fast_serving, jfk_audio)
speedup = Float.round(baseline_ms / fast_ms, 2)
IO.puts("whisper-tiny on #{inspect(Nx.shape(jfk_audio))} samples:")
IO.puts(" baseline : #{Float.round(baseline_ms, 1)} ms")
IO.puts(" with FastKernels: #{Float.round(fast_ms, 1)} ms (#{speedup}×)")Whisper-tiny is small — the decoder dominates and each decoder step touches an RMSNorm and an SDPA. Larger models (Whisper-base / -small, Qwen3, Llama) see bigger wins because the fused kernels replace more of the hot path.
For inspection, the rewritten graph has the same topology as the
original; only specific node op fields change:
fast_ops =
fast_model_info.model
|> Axon.get_op_counts()
|> Enum.sort_by(&elem(&1, 1), :desc)
|> Enum.take(10)
IO.puts("Top op frequencies in the rewritten Axon graph:")
Enum.each(fast_ops, fn {op, n} -> IO.puts(" #{op}: #{n}") end)
What to do next
- Try the shim on Qwen3-0.6B via the qwen3_quantized notebook — add one update_in call on model_info.model before building the generation serving (see the sketch after this list). RoPE and SDPA win harder there.
- For training, leave the shim off: the rewritten nodes intentionally drop the dropout_rate > 0 branch, so training paths that take gradients need the composed-defn version to stay correct.
- See Emily.Fast for the full fused-kernel catalogue and the additive-mask variant of SDPA.
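A minimal sketch of that first suggestion, assuming the Qwen/Qwen3-0.6B repo id and that Bumblebee loads it the way the qwen3_quantized notebook expects; the rewrite itself is the same one-liner as Part 2:
{:ok, qwen_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-0.6B"})
fast_qwen_info = update_in(qwen_info.model, &Emily.Bumblebee.FastKernels.apply/1)
# ...then build the generation serving from fast_qwen_info as in that notebook.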