Nx.Defn.Compiler implementation that runs defn computations on
Emily.Backend.
The compiler walks Nx.Defn.Expr in Elixir and dispatches each node
through the active backend — exactly what Nx.Defn.Evaluator already
does — with two adjustments specific to Emily:
- __to_backend__/1 returns {Emily.Backend, [device: …]} so
  Nx.Defn.to_backend/1 (and the callers that consult it, including
  Nx.Serving) allocate inputs and outputs on Emily rather than the
  process-default backend.
- __partitions_options__/1 always returns a single partition. MLX's
  Metal runtime was historically unsafe for concurrent kernel dispatch
  from multiple OS threads. :max_concurrency is accepted for API
  compatibility with Nx.Serving but capped at 1. For concurrent
  inference on a shared model use Emily.Stream.
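The two adjustments can be pictured with a small stand-alone sketch. The module `CallbackSketch` and its function names are hypothetical stand-ins, not Emily's source; the sketch assumes only the return shapes described above (a `{backend, options}` tuple and a one-element list of partition options) and the `:gpu` device default listed under Options.

```elixir
defmodule CallbackSketch do
  # Hypothetical stand-in for the two callbacks; not Emily's actual code.

  # Mirrors the documented return shape {Emily.Backend, [device: ...]}.
  # Emily.Backend is only used as an atom here, so the library need not
  # be loaded for this sketch to run.
  def to_backend(opts) do
    device = Keyword.get(opts, :device, :gpu)
    {Emily.Backend, [device: device]}
  end

  # Always one partition, whatever :max_concurrency was requested.
  # Returning the options untouched is an assumption of this sketch.
  def partitions_options(opts) do
    [opts]
  end
end

{backend, backend_opts} = CallbackSketch.to_backend(device: :cpu)
partitions = CallbackSketch.partitions_options(device: :cpu, max_concurrency: 4)
IO.inspect({backend, backend_opts, length(partitions)})
```

The point of the second function is that a caller asking for four partitions still gets exactly one entry back.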
Public API
Users do not call this module directly. Install it as the default
compiler and Nx.Serving / Bumblebee picks it up:
Nx.Defn.global_default_options(compiler: Emily.Compiler)
Or attach it per-call:
Nx.Defn.jit(&my_fn/1, compiler: Emily.Compiler).(input)
The four callbacks on Nx.Defn.Compiler (__jit__/5,
__compile__/4, __partitions_options__/1, __to_backend__/1)
are invoked by Nx.Defn on your behalf.
Design notes
__jit__/5 and __compile__/4 delegate to Nx.Defn.Evaluator
after filtering the option list down to the keys this module
consumes. There is no external JIT cache beyond the
closure Nx.Defn.compile/3 already returns: Bumblebee and
Nx.Serving hold that closure on warmup, so subsequent calls skip
the walk.
The compiler does not wrap mlx::core::compile by default. The
single-NIF replay is the load-bearing win (it collapses the per-op
BEAM↔worker round-trips); mx::compile is exposed as an opt-in
compiled eval mode on the program resource, which fuses the
elementwise runs the replay leaves separate. On a decode-shaped
transformer block, bench/program_compile.exs measures a ~1.6× speedup
over the sync replay (kernel-launch and intermediate-memory overhead
dominates at small sequence lengths, and fusion removes it), at the
cost of last-few-ULP f32 reassociation and a shape-stability
requirement. Hence it is opt-in, not the default for the general compiler.
Options
- :device - :gpu (default) or :cpu. Forwarded to Emily.Backend via the
  __to_backend__/1 callback.
- :hooks, :debug_options, :garbage_collect - passed through to
  Nx.Defn.Evaluator unchanged. See its moduledoc.
- :max_concurrency - accepted for Nx.Serving compatibility, but
  multi-partition serving is rejected because MLX kernel dispatch isn't
  thread-safe. Pass 1 (the default) to silence. For concurrent
  inference see Emily.Stream.
- :batch_keys, :cache - accepted and ignored. Nx.Serving propagates
  :batch_keys to the compiler via defn_options for arity-1 serving
  builders (e.g. Bumblebee.Audio.speech_to_text_whisper/5), and
  Bumblebee passes :cache through for its own per-scope cache
  suffixing. Neither is used by the Evaluator walk, but rejecting them
  would break those servings.
- :native - true (the default) compiles the traced Nx.Defn.Expr to a
  flat IR and replays the whole graph in a single NIF call per
  invocation; false runs the op-by-op Evaluator walk instead. The
  default is read from config :emily, :native (itself defaulting to
  true), so config :emily, native: false opts every defn out of the
  native lane application-wide, e.g. on a memory-constrained host where
  the one-shot compile peak is too large. The per-call option wins over
  the app env, so an explicit native: false overrides a global
  config :emily, native: true and vice versa. A non-boolean raises
  ArgumentError.
- :native_fallback - :eval (default) or :raise. Controls what happens
  when native: true but the expression contains an op or construct the
  IR can't lower yet. :eval routes the whole defn through
  Nx.Defn.Evaluator (each op then dispatches through Emily.Backend,
  with its own per-op via_binary fallback) and fires a one-shot
  [:emily, :compiler, :fallback] event, so installing
  compiler: Emily.Compiler, native: true globally is safe on any model.
  :raise re-raises the lowering error instead; use it in CI to prove a
  model lowers fully native. The per-call option wins over
  config :emily, :native_fallback, :eval | :raise.
- :fuse - true evals the compiled program in the mx::compile'd mode
  instead of the plain replay. For a while-free forward this fuses the
  elementwise runs the replay leaves separate (the CM6 win); for a
  Bumblebee.Text.generation defn while, it keeps the decode loop
  host-controlled but fuses each loop body under mx::compile, replaying
  the cached fused callable every token. Defaults to false; a
  non-boolean raises ArgumentError. Opt-in because the fusion
  reassociates f32 to within a few ULP, so logits are not bit-identical
  to the evaluator. Greedy argmax is robust to that drift (greedy token
  ids matched the evaluator in our tests), but the match is empirical,
  not guaranteed: any discrete decision the drift can tip (argmax on a
  near-tie, or a while trip count whose condition reads a reassociated
  reduction) diverges once it flips. Sampling strategies diverge from
  the evaluator under fusion even with a fixed seed. Only the native
  path consults it, so it is ignored unless native: true.
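The precedence described for :native (per-call option first, then the :emily application environment, then a built-in default of true, with a non-boolean raising ArgumentError) can be sketched as below. `NativeOptionSketch` and `resolve_native/1` are hypothetical names for illustration; how Emily.Compiler actually implements the lookup is not shown in this document.

```elixir
defmodule NativeOptionSketch do
  # Hypothetical helper; it illustrates the documented precedence only.

  # A per-call :native wins. Otherwise fall back to the :emily app env,
  # which itself defaults to true.
  def resolve_native(opts) do
    value =
      case Keyword.fetch(opts, :native) do
        {:ok, v} -> v
        :error -> Application.get_env(:emily, :native, true)
      end

    if is_boolean(value) do
      value
    else
      raise ArgumentError, ":native must be a boolean, got: #{inspect(value)}"
    end
  end
end

# Simulates `config :emily, native: false` at runtime.
Application.put_env(:emily, :native, false)
IO.inspect(NativeOptionSketch.resolve_native([]))            # app env applies
IO.inspect(NativeOptionSketch.resolve_native(native: true))  # per-call wins

Application.delete_env(:emily, :native)
IO.inspect(NativeOptionSketch.resolve_native([]))            # built-in default
```

Using `Keyword.fetch/2` rather than `Keyword.get/3` keeps an explicit `native: false` distinct from the option being absent, which is what lets the per-call value override the app env in both directions.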
Any other option is silently dropped. This matches how
Nx.Defn.Evaluator and EXLA handle their own option lists, and is
the contract higher-level libraries rely on when they forward
caller-supplied options to the JIT compiler — e.g. Axon.build/2,
whose docs state that "all other options are forwarded to the
underlying JIT compiler".
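A minimal sketch of this "keep what you know, drop the rest" contract, assuming the split is done with plain `Keyword.take/2`. The module `OptionFilterSketch` is hypothetical, and the two key lists are copied from the Options section rather than from Emily's source; `client: :host` stands in for an option meant for some other compiler.

```elixir
defmodule OptionFilterSketch do
  # Hypothetical illustration; the grouping of keys is an assumption.
  @evaluator_keys [:hooks, :debug_options, :garbage_collect]
  @own_keys [
    :device,
    :max_concurrency,
    :batch_keys,
    :cache,
    :native,
    :native_fallback,
    :fuse
  ]

  # Returns {options for this compiler, options for the evaluator}.
  # Anything in neither list is dropped without raising.
  def split(opts) do
    {Keyword.take(opts, @own_keys), Keyword.take(opts, @evaluator_keys)}
  end
end

{own, evaluator} =
  OptionFilterSketch.split(device: :cpu, garbage_collect: true, client: :host)

IO.inspect(own)        # [device: :cpu]
IO.inspect(evaluator)  # [garbage_collect: true]
```

Dropping rather than raising is what lets a caller forward one option list to whichever compiler happens to be installed.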
Examples
Process-global installation (typical for Nx.Serving / Bumblebee):
Nx.global_default_backend(Emily.Backend)
Nx.Defn.global_default_options(compiler: Emily.Compiler)
Per-call:
add_one = Nx.Defn.jit(fn x -> Nx.add(x, 1) end, compiler: Emily.Compiler)
add_one.(Nx.tensor([1.0, 2.0]))
# => #Nx.Tensor<f32[2] [2.0, 3.0]> on Emily.Backend