Rewrites the RMSNorm, LayerNorm, RoPE, and SDPA Axon layers of a
Bumblebee model so that they call `Emily.Fast.*` instead of their stock
`defn` implementations. When the rewritten model is then evaluated
under `Emily.Compiler`, those `Emily.Fast.*` calls dispatch to
fused MLX kernels via the `:optional_node` mechanism (see
`Emily.Fast`'s moduledoc). On any other backend the helpers fall
back to `defn` composition and produce mathematically equivalent
results, so applying the shim is safe even if the model is later
evaluated on `Nx.BinaryBackend` or `EXLA`.
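For reference, the `defn` composition that such a fallback reduces to is small. Here is a minimal sketch of the RMSNorm case in plain Nx; the module and function names are illustrative, not `Emily.Fast`'s actual API:

```elixir
defmodule Example.RmsNormFallback do
  import Nx.Defn

  # Reference semantics a fused RMSNorm kernel must match numerically:
  # scale by the reciprocal RMS over the last axis, then apply the
  # learned weight.
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)

    variance =
      x
      |> Nx.pow(2)
      |> Nx.mean(axes: [-1], keep_axes: true)

    x * Nx.rsqrt(variance + opts[:eps]) * weight
  end
end
```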
## Optional dependency

This module depends on `:axon` and `:bumblebee`, which are declared
as `optional: true` in Emily's `mix.exs`. Consumers who don't pull
those deps into their own project get a clean build: the whole
module definition is wrapped in a `Code.ensure_loaded?/1` check and
is elided entirely when either dep is missing. To use the shim, add
both to your own `deps/0`:
```elixir
{:bumblebee, "~> 0.6"},
{:axon, "~> 0.7"}
```
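The compile-time guard described above has roughly this shape; a sketch, with the module body elided:

```elixir
# Compile the shim only when both optional deps are available;
# otherwise the whole module definition is skipped and nothing
# is emitted into the build.
if Code.ensure_loaded?(Axon) and Code.ensure_loaded?(Bumblebee) do
  defmodule Emily.Bumblebee.FastKernels do
    # ... rewrites defined here ...
  end
end
```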
## Usage

```elixir
{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})

model_info =
  update_in(model_info.model, &Emily.Bumblebee.FastKernels.apply/1)

# then proceed with Bumblebee.Audio.speech_to_text_whisper/5
# (or Bumblebee.Text.generation/4, etc.) as usual
```
## Coverage

- `:rms_norm` (Bumblebee's `Bumblebee.Layers.rms_norm/2`).
- `:layer_norm` (Axon's built-in normalization layer).
- `Bumblebee.Layers.apply_rotary_embedding/5`: supports the default
  schedule plus the `:linear`, `:dynamic`, `:longrope`, and `:llama3`
  scaling strategies. Inverse frequencies are precomputed Elixir-side
  using Bumblebee's own helpers and passed to `mx::fast::rope` via the
  `freqs`-override overload.
- `Bumblebee.Layers.attention_output_impl/3`: coalesced with its
  sibling `attention_weights_impl/7` into a single
  `mx::fast::scaled_dot_product_attention` dispatch. Mask translation:
  the causal, window, and key/head/bias masks are collapsed into one
  additive `array` mask. Both the frequency precompute and the mask
  collapse are sketched after this list.
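A hedged sketch of those two translation steps in plain Nx. The module and helper names are illustrative, not Bumblebee's or Emily's actual API:

```elixir
defmodule Example.CoverageSketch do
  import Nx.Defn

  # Default-schedule inverse frequencies, computed Elixir-side as
  # base^(-2i / dim) for i in 0..(dim / 2 - 1). Illustrates the kind
  # of tensor handed to mx::fast::rope as a freqs override.
  def default_inv_freqs(dim, base \\ 10_000) do
    {div(dim, 2)}
    |> Nx.iota()
    |> Nx.multiply(2 / dim)
    |> then(&Nx.pow(base, Nx.negate(&1)))
  end

  # Collapse a boolean mask (1 = attend) and an additive bias into a
  # single additive mask: masked-out positions become -inf, then the
  # bias is added on top.
  defn to_additive_mask(bool_mask, bias) do
    neg_inf = Nx.Constants.neg_infinity(Nx.type(bias))

    Nx.equal(bool_mask, 0)
    |> Nx.select(neg_inf, 0.0)
    |> Nx.add(bias)
  end
end
```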
## What's not rewritten

- Norms with a `:channel_index` other than `-1` (uncommon outside
  vision-CNN heads; those don't fit the fused kernel's last-axis-only
  contract). The original layer is left in place (the skip logic is
  sketched after this list).
- Attention layers with `dropout_rate > 0`. The inference path is
  `dropout_rate: 0` everywhere, so this is a no-op in practice;
  training paths continue to use composed `defn`.
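A hedged sketch of how that skip could look, assuming Axon's `Axon.map_nodes/2` graph traversal; the `%Axon.Node{}` field names and the `rewrite_norm/1` helper are illustrative, not Emily's actual implementation:

```elixir
# Walk the Axon graph; only last-axis layer norms are rewritten,
# everything else is returned untouched.
def apply(%Axon{} = model) do
  Axon.map_nodes(model, fn
    %Axon.Node{op_name: :layer_norm, opts: opts} = node ->
      if Keyword.get(opts, :channel_index, -1) == -1 do
        rewrite_norm(node)
      else
        node
      end

    node ->
      node
  end)
end
```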
## Summary

### Functions

- `apply(model)`: Apply every available rewrite to `model`.