A minimal, model-agnostic decode-loop driver for autoregressive generation on Emily's native compiler.
Emily is an Nx backend, not a model library — so this module supplies
only the mechanism. It JIT-compiles a caller-supplied shape-stable
per-token forward with the single-NIF native compiler (see
Emily.Compiler) and drives the token loop from Elixir: offset
bookkeeping, KV-cache threading, the stop conditions, next-token
selection, and streaming. The caller owns the model: it provides the
forward and the (pre-filled) cache.
This is the "loop in Elixir" half of the generation story — it preserves
per-token streaming and host-side control. (Bumblebee.Text.generation
compiles its own defn while loop instead; that path is handled by the
native while opcode, not this driver.)
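The responsibilities listed above (offset bookkeeping, cache threading, stop conditions, selection, streaming) can be modelled in plain Elixir. The sketch below is not the module's implementation: DecodeLoopSketch is a hypothetical name, logits are plain lists rather than tensors, nothing is compiled, and it assumes that first_token is fed to the forward but not included in the returned list.

```elixir
defmodule DecodeLoopSketch do
  # Hypothetical, list-based model of a host-side decode loop.
  def run(forward, opts) do
    ctx = %{
      forward: forward,
      params: Keyword.fetch!(opts, :params),
      # Accept a single stop id or a list of ids.
      eos: opts |> Keyword.get(:eos, []) |> List.wrap(),
      select: Keyword.get(opts, :select, &argmax/1),
      on_token: Keyword.get(opts, :on_token, fn _id -> :ok end)
    }

    loop(
      ctx,
      Keyword.fetch!(opts, :first_token),
      Keyword.get(opts, :offset, 0),
      Keyword.fetch!(opts, :cache),
      Keyword.fetch!(opts, :max_new_tokens),
      []
    )
  end

  # Stop condition 1: the token budget is exhausted.
  defp loop(_ctx, _token, _offset, _cache, 0, acc), do: Enum.reverse(acc)

  defp loop(ctx, token, offset, cache, remaining, acc) do
    # One step: run the forward, thread the updated cache, pick the next id.
    {logits, cache} = ctx.forward.(token, offset, cache, ctx.params)
    next = ctx.select.(logits)
    # Streaming hook: called as soon as the id is known.
    ctx.on_token.(next)
    acc = [next | acc]

    # Stop condition 2: a stop token was produced (it stays in the result).
    if next in ctx.eos do
      Enum.reverse(acc)
    else
      loop(ctx, next, offset + 1, cache, remaining - 1, acc)
    end
  end

  # Index of the largest logit; ties resolve to the lowest index.
  defp argmax(logits) do
    logits |> Enum.with_index() |> Enum.max_by(&elem(&1, 0)) |> elem(1)
  end
end

# Toy forward over a 4-token vocabulary: always favours token + 1 (mod 4).
toy_forward = fn token, _offset, cache, _params ->
  logits = for i <- 0..3, do: if(i == rem(token + 1, 4), do: 1.0, else: 0.0)
  {logits, cache + 1}
end

tokens =
  DecodeLoopSketch.run(toy_forward,
    cache: 0,
    params: nil,
    first_token: 0,
    max_new_tokens: 10,
    eos: 3
  )

IO.inspect(tokens) # [1, 2, 3]
```

Keeping the loop on the host is what allows the on_token callback to run between steps; a loop compiled into a single program would only hand back the finished sequence.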
The forward contract
The forward is an arity-4 function fn token, offset, cache, params -> {logits, cache} end, traceable by Nx.Defn and shape-stable: its
argument and result shapes must not depend on the runtime value of
offset, so a single compiled program serves every position.
Concretely:
- token: an s32 {1} tensor (the id to decode at this step),
- offset: an s32 scalar tensor (the absolute position; thread it as a runtime input, e.g. a dynamic Nx.put_slice into a fixed-size KV buffer plus a length mask, rather than a growing slice),
- cache: an Nx.Container of fixed-shape KV buffers,
- params: an Nx.Container of the model weights,
returning {logits, cache} where logits is the last position's logit
vector and cache has the same structure/shapes as the input.
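The "dynamic put_slice plus length mask" idea from the contract can be sketched as follows. This assumes Nx is already a dependency of the project; ShapeStableSketch, write_at, and length_mask are hypothetical names, and a real forward would also compute attention and logits. The point is only that the buffer shape never changes while offset varies at runtime.

```elixir
defmodule ShapeStableSketch do
  import Nx.Defn

  # buffer: {max_len, d} fixed-size KV buffer, row: {1, d}, offset: s32 scalar.
  # The start index is a runtime tensor, so the output shape stays {max_len, d}.
  defn write_at(buffer, offset, row) do
    Nx.put_slice(buffer, [offset, 0], row)
  end

  # 1 for slots filled so far (index <= offset), 0 for the unused tail.
  defn length_mask(buffer, offset) do
    Nx.iota({Nx.axis_size(buffer, 0)}) <= offset
  end
end

buffer = Nx.broadcast(0.0, {4, 2})
row = Nx.tensor([[1.0, 2.0]])
offset = Nx.tensor(1, type: :s32)

updated = ShapeStableSketch.write_at(buffer, offset, row)
mask = ShapeStableSketch.length_mask(buffer, offset)

IO.inspect(Nx.shape(updated)) # {4, 2}
IO.inspect(Nx.to_flat_list(mask)) # [1, 1, 0, 0]
```

A growing slice such as buffer[0..offset] would instead change shape with every step, so a program compiled once could not be reused across positions.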
The driver does not bound offset against the cache window. Sizing
offset + max_new_tokens to fit the fixed KV buffer is the caller's
responsibility: overflowing it does not raise; the out-of-bounds
put_slice silently corrupts the cache.
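Because the driver does not perform this check, a caller may want a guard before starting the loop. A minimal sketch, assuming the caller knows the length of the cache's sequence axis; CacheWindowSketch, check!, and cache_len are hypothetical names, not part of this module.

```elixir
defmodule CacheWindowSketch do
  # Raises unless positions offset..(offset + max_new_tokens - 1) all fit
  # inside a KV buffer holding cache_len positions.
  def check!(offset, max_new_tokens, cache_len) do
    needed = offset + max_new_tokens

    if needed > cache_len do
      raise ArgumentError,
            "decode needs #{needed} cache positions, but the buffer holds #{cache_len}"
    end

    :ok
  end
end

# A 10-token prompt plus 64 new tokens fits a 128-slot buffer.
:ok = CacheWindowSketch.check!(10, 64, 128)
```

The check is deliberately conservative: it reserves a slot for every one of the max_new_tokens steps, even though the loop may stop earlier on a stop token.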
params is a required argument rather than a closure on purpose: Nx
rejects mixing closed-over Emily.Backend tensors with the traced
Nx.Defn.Expr, and passing them as an argument also hands their refs to
the compiled program zero-copy (captured once).
Example
# `forward`, `cache0`, and `params` come from your model.
tokens =
  Emily.Generation.stream(forward,
    cache: cache0,
    params: params,
    first_token: bos_id,
    offset: prompt_len,
    max_new_tokens: 64,
    eos: [eos_id],
    on_token: fn id -> send(self(), {:token, id}) end
  )

Returns the list of generated token ids (including the stop token, if
one is hit). :select defaults to greedy argmax; pass your own
fn logits -> token_tensor end for sampling.
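The example's on_token callback sends {:token, id} messages to the calling process. One way to collect them afterwards is to drain the mailbox, sketched below in plain Elixir; TokenInbox and drain are hypothetical names, and the two send calls stand in for a real decode loop.

```elixir
defmodule TokenInbox do
  # Collects every {:token, id} message already in the mailbox, in arrival
  # order, and returns as soon as none is left (after 0 means do not wait).
  def drain(acc \\ []) do
    receive do
      {:token, id} -> drain([id | acc])
    after
      0 -> Enum.reverse(acc)
    end
  end
end

# Stand-in for what an on_token callback would have sent during decoding.
send(self(), {:token, 11})
send(self(), {:token, 42})

IO.inspect(TokenInbox.drain()) # [11, 42]
```

For live streaming to a UI, the callback would more likely send to a different process than the one running the loop, since the calling process is busy until the loop returns.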
Functions
@spec greedy(Nx.Tensor.t()) :: Nx.Tensor.t()
Greedy next-token selector: argmax over the vocabulary (last) axis.
@spec stream( (Nx.Tensor.t(), Nx.Tensor.t(), Nx.Container.t(), Nx.Container.t() -> {Nx.Tensor.t(), Nx.Container.t()}), keyword() ) :: [integer()]
Drive an autoregressive decode loop over a shape-stable forward.
Options
- :cache (required): the initial (pre-filled) KV-cache container.
- :params (required): the model weights container, passed to the forward each step.
- :first_token (required): the first token id to decode.
- :max_new_tokens (required): the maximum number of tokens to emit.
- :offset: the starting absolute position (default 0).
- :eos: a stop token id or list of ids (default []).
- :select: fn logits -> token_tensor end (default greedy/1).
- :on_token: fn token_id -> any end, called with each generated id as it is produced (default no-op).
- :defn_options: options for Nx.Defn.compile/3 (default [compiler: Emily.Compiler, native: true]). Override to disable native compilation or pick a different compiler.
Returns the list of generated token ids.