Emily.Bumblebee.FastKernels (emily v0.3.5)


Rewrites the RMSNorm, LayerNorm, RoPE, and SDPA Axon layers of a Bumblebee model so that they call Emily.Fast.* helpers instead of the stock defn implementations. When the rewritten model is then evaluated under Emily.Compiler, those Emily.Fast.* calls dispatch to fused MLX kernels via the :optional-node mechanism (see Emily.Fast's moduledoc). On any other backend the helpers fall back to defn composition and produce mathematically equivalent results, so applying the shim is safe even if the model is later evaluated on Nx.BinaryBackend or EXLA.
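For intuition, the defn fallback for RMSNorm composes the usual primitives. A minimal sketch of the reference semantics (not Emily's actual source; the module name and epsilon default are illustrative):

defmodule RmsNormSketch do
  import Nx.Defn

  # Normalize by the root-mean-square over the last axis, then scale by
  # the learned weight. This is the math the fused kernel must match.
  defn rms_norm_reference(x, weight, opts \\ []) do
    opts = keyword!(opts, epsilon: 1.0e-6)

    rms =
      x
      |> Nx.pow(2)
      |> Nx.mean(axes: [-1], keep_axes: true)
      |> Nx.add(opts[:epsilon])
      |> Nx.sqrt()

    x / rms * weight
  end
end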

Optional dependency

This module depends on :axon and :bumblebee, both declared as optional: true in Emily's mix.exs. Consumers who don't pull those deps into their own project still get a clean build: the whole module definition is wrapped in a Code.ensure_loaded?/1 check and is compiled out entirely when either dep is missing. To use the shim, add both to your own deps/0:

{:bumblebee, "~> 0.6"},
{:axon, "~> 0.7"}
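
When either dep is absent, the guard behaves roughly as below (a sketch; the real module body is elided):

if Code.ensure_loaded?(Axon) and Code.ensure_loaded?(Bumblebee) do
  defmodule Emily.Bumblebee.FastKernels do
    # layer rewrites are only compiled when both deps are present
  end
end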

Usage

{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})

model_info =
  update_in(model_info.model, &Emily.Bumblebee.FastKernels.apply/1)

# then proceed with Bumblebee.Audio.speech_to_text_whisper/5
# (or Bumblebee.Text.generation/4, etc.) as usual.
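
The fused paths only engage when the model is evaluated under Emily.Compiler. A sketch of wiring that into a Whisper serving, assuming Emily.Compiler is passed like any other Nx defn compiler (the load_* calls are standard Bumblebee usage, shown for completeness):

{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})

{:ok, generation_config} =
  Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
  Bumblebee.Audio.speech_to_text_whisper(
    model_info,
    featurizer,
    tokenizer,
    generation_config,
    defn_options: [compiler: Emily.Compiler]
  )

Nx.Serving.run(serving, {:file, "sample.wav"})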

Coverage

  • :rms_norm (Bumblebee's rms_norm/2).
  • :layer_norm (Axon's built-in normalization layer).
  • Bumblebee.Layers.apply_rotary_embedding/5, supporting the default schedule plus the :linear, :dynamic, :longrope, and :llama3 scaling strategies. Inverse frequencies are precomputed on the Elixir side using Bumblebee's own helpers and passed to mx::fast::rope via its freqs-override overload.
  • Bumblebee.Layers.attention_output_impl/3, coalesced with its sibling attention_weights_impl/7 into a single mx::fast::scaled_dot_product_attention dispatch. The causal, window, key, head, and bias masks are collapsed into one additive array mask (see the sketch after this list).
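
An illustrative reconstruction of that mask translation, assuming a boolean key mask of shape {batch, kv_len} and q_len == kv_len for simplicity (names and shapes are ours, not Emily's internals):

defmodule MaskSketch do
  import Nx.Defn

  defn collapse_mask(key_mask) do
    kv_len = Nx.axis_size(key_mask, -1)

    # Lower-triangular causal structure: query i may attend to keys <= i.
    causal = Nx.greater_equal(Nx.iota({kv_len, 1}), Nx.iota({1, kv_len}))

    allowed =
      key_mask
      |> Nx.new_axis(1)
      |> Nx.new_axis(1)
      |> Nx.logical_and(causal)

    # Additive form expected by a fused SDPA kernel:
    # 0 where attention is allowed, -inf where it is not.
    Nx.select(allowed, 0.0, Nx.Constants.neg_infinity())
  end
end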

What's not rewritten

  • Norms with a :channel_index other than -1. These are uncommon outside vision-CNN heads and don't fit the fused kernel's last-axis-only contract, so the original layer is left in place.
  • Attention layers with dropout_rate > 0. Inference paths run with dropout_rate: 0 everywhere, so this is a no-op in practice; training paths continue to use the composed defn implementation. Both checks are sketched after this list.
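
The two skip rules reduce to simple predicates. A hypothetical sketch over a layer's option list (the real pass inspects Axon graph nodes, typically via a traversal such as Axon.map_nodes/2):

defmodule EligibilitySketch do
  # Fused norms require normalization over the last axis only.
  def fusable_norm?(layer_opts) do
    Keyword.get(layer_opts, :channel_index, -1) == -1
  end

  # Fused attention requires a dropout-free (inference) configuration.
  def fusable_attention?(layer_opts) do
    Keyword.get(layer_opts, :dropout_rate, 0.0) == 0.0
  end
end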

Summary

Functions

Apply every available rewrite to model.

Functions

apply(model)

@spec apply(Axon.t()) :: Axon.t()

Apply every available rewrite to model.