Fused transformer kernels on Emily

Copy Markdown View Source
# Install notebook dependencies. Setting `default_backend: Emily.Backend`
# means every tensor below is allocated on MLX (Apple Metal) from the
# start — no explicit backend transfers are needed anywhere in this
# notebook.
Mix.install(
  [
    {:emily, "~> 0.3"},      # MLX bindings: Emily.Backend, Emily.Compiler, Emily.Fast
    {:bumblebee, "~> 0.6"},  # pretrained models (whisper-tiny in Part 2)
    {:tokenizers, "~> 0.5"}, # HF tokenizers, required by Bumblebee
    {:axon, "~> 0.7"},       # model graphs that FastKernels rewrites
    {:nx, "~> 0.10"},        # tensor primitives used by the composed kernels
    {:kino, "~> 0.14"},      # Livebook widgets
    {:req, "~> 0.5"}         # HTTP client for fetching the JFK sample
  ],
  config: [
    nx: [default_backend: Emily.Backend]
  ]
)

Overview

Emily.Fast wraps four MLX fused kernels: rms_norm, layer_norm, rope, and scaled_dot_product_attention. Every helper also ships a pure-defn fallback so the calls work on any backend — the fused path is only taken under Emily.Compiler.

This notebook has two parts:

  1. Low-level: call Emily.Fast.rms_norm and Emily.Fast.scaled_dot_product_attention directly, and compare against the same math written by hand with Nx primitives. Confirms both paths produce the same tensors; then times both.
  2. High-level: apply Emily.Bumblebee.FastKernels.apply/1 to a whisper-tiny Bumblebee model and transcribe the JFK sample from the Whisper notebook end-to-end, with and without the shim. Shows what the fused kernels are worth on a real workload.

Part 1: Low-level kernels

Each kernel's composed version is the textbook formula written with Nx primitives. Both versions are wrapped with Nx.Defn.jit/2 + Emily.Compiler so the graph is traced, JIT-compiled, and dispatched to MLX — the only difference is how many kernels the graph decomposes into on the way down.

# RMSNorm: normalise to unit RMS over the last axis, scale by `weight`.
# Written entirely with Nx primitives so the compiler sees many small
# kernels rather than one fused op.
composed_rms_norm = fn x, weight ->
  mean_square = Nx.mean(Nx.pow(x, 2), axes: [-1], keep_axes: true)
  denom = Nx.sqrt(Nx.add(mean_square, 1.0e-6))
  Nx.multiply(Nx.divide(x, denom), weight)
end

# Same contract as the composed version, but dispatched as a single
# fused MLX kernel (same eps so both paths are directly comparable).
fused_rms_norm = &Emily.Fast.rms_norm(&1, &2, eps: 1.0e-6)

# SDPA: softmax(QKᵀ / sqrt(d)) · V. Subtracting the row max before the
# exp stabilises the softmax; without it, large logits overflow.
composed_sdpa = fn q, k, v ->
  head_dim = Nx.axis_size(q, -1)

  scaled_logits =
    q
    |> Nx.dot([-1], [0, 1], k, [-1], [0, 1])
    |> Nx.multiply(1.0 / :math.sqrt(head_dim))

  row_max = Nx.reduce_max(scaled_logits, axes: [-1], keep_axes: true)
  weights = Nx.exp(Nx.subtract(scaled_logits, row_max))
  probs = Nx.divide(weights, Nx.sum(weights, axes: [-1], keep_axes: true))
  Nx.dot(probs, [-1], [0, 1], v, [-2], [0, 1])
end

# The fused counterpart: one MLX attention kernel, default options.
fused_sdpa = &Emily.Fast.scaled_dot_product_attention/3

# JIT each closure under Emily.Compiler so its traced graph is lowered
# to MLX — fusion only happens on this path.
jit = fn fun -> Nx.Defn.jit(fun, compiler: Emily.Compiler) end

[composed_rms_norm_c, fused_rms_norm_c, composed_sdpa_c, fused_sdpa_c] =
  Enum.map([composed_rms_norm, fused_rms_norm, composed_sdpa, fused_sdpa], jit)

Correctness

# RMSNorm correctness check: batch=4, seq=128, hidden=1024 —
# representative for a small transformer block.
rms_x = Nx.divide(Nx.iota({4, 128, 1024}, type: :f32), 100.0)
rms_weight = Nx.as_type(Nx.broadcast(1.0, {1024}), :f32)

rms_composed = composed_rms_norm_c.(rms_x, rms_weight)
rms_fused = fused_rms_norm_c.(rms_x, rms_weight)

# Largest element-wise disagreement between the two paths.
rms_max_diff =
  Nx.to_number(Nx.reduce_max(Nx.abs(Nx.subtract(rms_composed, rms_fused))))

IO.puts("RMSNorm max |composed − fused|: #{rms_max_diff}")
# SDPA correctness check: batch=2, heads=8, seq=128, head_dim=64 — a
# small attention block at realistic proportions.
sdpa_shape = {2, 8, 128, 64}

# q, k and v get the same deterministic, small-valued contents.
[q, k, v] =
  for _ <- 1..3 do
    sdpa_shape |> Nx.iota(type: :f32) |> Nx.divide(10_000.0)
  end

sdpa_composed = composed_sdpa_c.(q, k, v)
sdpa_fused = fused_sdpa_c.(q, k, v)

# Largest element-wise disagreement between the two paths.
sdpa_max_diff =
  Nx.to_number(Nx.reduce_max(Nx.abs(Nx.subtract(sdpa_composed, sdpa_fused))))

IO.puts("SDPA max |composed − fused|: #{sdpa_max_diff}")

Both diffs should be ≲ 1e-5 — round-off, not a bug. The fused kernels compute the same math with different reduction orders.

Timing

MLX is lazy: ops queue onto a Metal command buffer and only run when something forces evaluation. The helper below calls Emily.eval/1 on each iteration's output inside the :timer.tc window, forcing the queued Metal work to run — so what we measure is actual GPU execution rather than just queueing overhead.

# Times `iters` calls of `fun` applied to `args`, forcing MLX evaluation
# after every call so queued GPU work is included in the measurement.
# Prints and returns the per-call time in milliseconds.
bench = fn label, fun, args, iters ->
  # Warmup: the first call compiles the defn closure; evaluating it
  # keeps compilation cost out of the timed loop.
  apply(fun, args).data.ref |> Emily.eval()

  {elapsed_us, :ok} =
    :timer.tc(fn ->
      for _ <- 1..iters do
        apply(fun, args).data.ref |> Emily.eval()
      end

      :ok
    end)

  ms = elapsed_us / iters / 1000
  IO.puts("#{label}: #{Float.round(ms, 3)} ms/call")
  ms
end

iters = 50

# Prints the composed/fused ratio for one kernel pair.
report_speedup = fn composed_ms, fused_ms ->
  IO.puts("  speedup: #{Float.round(composed_ms / fused_ms, 2)}×")
end

IO.puts("\nRMSNorm  (batch=4, seq=128, hidden=1024):")
r_composed = bench.("  composed", composed_rms_norm_c, [rms_x, rms_weight], iters)
r_fused = bench.("  fused   ", fused_rms_norm_c, [rms_x, rms_weight], iters)
report_speedup.(r_composed, r_fused)

IO.puts("\nSDPA     (batch=2, heads=8, seq=128, head_dim=64):")
s_composed = bench.("  composed", composed_sdpa_c, [q, k, v], iters)
s_fused = bench.("  fused   ", fused_sdpa_c, [q, k, v], iters)
report_speedup.(s_composed, s_fused)

The absolute numbers depend heavily on machine, batch size, and whether MLX is running AOT or JIT; what matters is that the fused path is consistently at least as fast as the composed one, with the gap widening on larger tensors (where the savings from avoiding intermediate allocations dominate kernel-launch overhead).

Part 2: FastKernels shim on whisper-tiny

Emily.Bumblebee.FastKernels.apply/1 walks a Bumblebee Axon graph and rewrites its RMSNorm / LayerNorm / RoPE / SDPA nodes to call the fused kernels from Part 1. The rest of the model is untouched, so the same params work on both graphs.

defmodule WAV do
  @moduledoc false

  @doc "Decode 16-bit PCM mono 16 kHz WAV to a 1-D f32 Nx tensor in [-1, 1]."
  def decode!(<<"RIFF", _size::little-32, "WAVE", rest::binary>>) do
    # Assert the scanned format: anything other than 16-bit mono 16 kHz
    # raises a MatchError here instead of silently decoding garbage.
    {%{bits: 16, channels: 1, rate: 16_000}, data} = scan_chunks(rest, nil, nil)
    # NOTE(review): Nx.from_binary/2 with :s16 reads native byte order;
    # WAV samples are little-endian, so this holds on little-endian hosts
    # (x86-64 / Apple silicon) — confirm if targeting big-endian.
    data |> Nx.from_binary(:s16) |> Nx.divide(32_768.0) |> Nx.as_type(:f32)
  end

  # "fmt " chunk: extract channels/rate/bits from the PCM format header
  # (layout: audio_format:16, channels:16, sample_rate:32, byte_rate:32,
  # block_align:16, bits_per_sample:16) and keep scanning for "data".
  defp scan_chunks(<<"fmt ", size::little-32, fmt::binary-size(size), rest::binary>>, _, data) do
    <<_::little-16, channels::little-16, rate::little-32, _::little-32,
      _::little-16, bits::little-16, _rest::binary>> = fmt

    scan_chunks(rest, %{bits: bits, channels: channels, rate: rate}, data)
  end

  # "data" chunk: stop scanning and return the raw samples together with
  # whatever format metadata was seen so far (nil if "data" preceded
  # "fmt " — decode!/1 would then fail its format assertion).
  defp scan_chunks(<<"data", size::little-32, data::binary-size(size), _::binary>>, fmt, _),
    do: {fmt, data}

  # Any other chunk id (LIST, fact, …): skip it and keep scanning.
  # NOTE(review): RIFF chunks are word-aligned — an odd-sized chunk is
  # followed by a pad byte this clause does not skip; fine for typical
  # WAV output, worth confirming for arbitrary files.
  defp scan_chunks(
         <<_id::binary-4, size::little-32, _::binary-size(size), rest::binary>>,
         fmt,
         data
       ),
       do: scan_chunks(rest, fmt, data)
end
# whisper-tiny from the Hugging Face hub: model weights, audio
# featurizer, tokenizer, and default generation settings.
repo = {:hf, "openai/whisper-tiny"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Fetch the whisper.cpp JFK sample (16-bit mono 16 kHz WAV) and decode
# it with the WAV module above into a 1-D f32 tensor.
jfk_audio =
  Req.get!(
    "https://raw.githubusercontent.com/ggml-org/whisper.cpp/master/samples/jfk.wav"
  ).body
  |> WAV.decode!()

# Cell result: the sample count of the decoded audio.
Nx.shape(jfk_audio)
# Builds a speech-to-text serving around the given model; the
# featurizer, tokenizer, generation config, and compiler are shared, so
# the only variable between servings is the model graph itself.
build_serving = fn model_info ->
  Bumblebee.Audio.speech_to_text_whisper(
    model_info,
    featurizer,
    tokenizer,
    generation_config,
    defn_options: [compiler: Emily.Compiler]
  )
end

baseline_serving = build_serving.(model_info)

# Rewrite the Axon graph to use the fused kernels; the params are left
# untouched, so the same weights drive both servings.
fast_model_info = Map.update!(model_info, :model, &Emily.Bumblebee.FastKernels.apply/1)
fast_serving = build_serving.(fast_model_info)

Correctness on a real transcription

Both servings should produce the same words on the same audio. Greedy decoding is deterministic, but the two graphs differ by tiny floating-point amounts, which can occasionally flip a near-tie token. A wholesale divergence here means the shim has rewritten a layer incorrectly.

# Runs one transcription through a serving and flattens the chunk texts
# into a single trimmed string.
transcribe = fn serving ->
  serving
  |> Nx.Serving.run(jfk_audio)
  |> Map.fetch!(:chunks)
  |> Enum.map_join(& &1.text)
  |> String.trim()
end

baseline_text = transcribe.(baseline_serving)
fast_text = transcribe.(fast_serving)

# Cell result: both transcripts side by side for eyeballing.
{baseline_text, fast_text}

Timing

# Wall-clock of one full serving run, in milliseconds.
time_ms = fn serving, audio ->
  {elapsed_us, _result} = :timer.tc(Nx.Serving, :run, [serving, audio])
  elapsed_us / 1000
end

# Both servings were already compiled and run once by the correctness
# cell above, so no separate warmup is needed here.
baseline_ms = time_ms.(baseline_serving, jfk_audio)
fast_ms = time_ms.(fast_serving, jfk_audio)

speedup = Float.round(baseline_ms / fast_ms, 2)

IO.puts("whisper-tiny on #{inspect(Nx.shape(jfk_audio))} samples:")
IO.puts("  baseline      : #{Float.round(baseline_ms, 1)} ms")
IO.puts("  with FastKernels: #{Float.round(fast_ms, 1)} ms  (#{speedup}×)")

Whisper-tiny is small — the decoder dominates and each decoder step touches an RMSNorm and an SDPA. Larger models (Whisper-base / -small, Qwen3, Llama) see bigger wins because the fused kernels replace more of the hot path.

For inspection, the rewritten graph has the same topology as the original; only specific node op fields change:

# Tally op kinds across the rewritten graph's nodes and show the ten
# most frequent — the fused ops should be visible near the top.
fast_ops =
  fast_model_info.model
  |> Axon.properties()
  |> Enum.frequencies_by(fn {_name, op} -> op end)
  |> Enum.sort_by(fn {_op, count} -> count end, :desc)
  |> Enum.take(10)

IO.puts("Top op frequencies in the rewritten Axon graph:")
Enum.each(fast_ops, fn {op, n} -> IO.puts("  #{op}: #{n}") end)

What to do next

  • Try the shim on Qwen3-0.6B via the qwen3_quantized notebook — add one update_in call on model_info.model before building the generation serving. RoPE and SDPA win harder there.
  • For training, leave the shim off: the rewritten nodes intentionally drop the dropout_rate > 0 branch, so gradient training paths need the composed-defn version to stay correct.
  • See Emily.Fast for the full fused-kernel catalogue and the additive-mask variant of SDPA.