Emily.Generation (emily v0.7.1)


A minimal, model-agnostic decode-loop driver for autoregressive generation on Emily's native compiler.

Emily is an Nx backend, not a model library — so this module supplies only the mechanism. It JIT-compiles a caller-supplied shape-stable per-token forward with the single-NIF native compiler (see Emily.Compiler) and drives the token loop from Elixir: offset bookkeeping, KV-cache threading, the stop conditions, next-token selection, and streaming. The caller owns the model: it provides the forward and the (pre-filled) cache.

This is the "loop in Elixir" half of the generation story — it preserves per-token streaming and host-side control. (Bumblebee.Text.generation compiles its own defn while loop instead; that path is handled by the native while opcode, not this driver.)

The forward contract

The forward is an arity-4 function fn token, offset, cache, params -> {logits, cache} end, traceable by Nx.Defn and shape-stable: its argument and result shapes must not depend on the runtime value of offset, so a single compiled program serves every position. Concretely:

  • token — an s32 {1} tensor (the id to decode at this step),
  • offset — an s32 scalar tensor (the absolute position; thread it as a runtime input, e.g. a dynamic Nx.put_slice into a fixed-size KV buffer plus a length mask, rather than a growing slice),
  • cache — an Nx.Container of fixed-shape KV buffers,
  • params — an Nx.Container of the model weights,

returning {logits, cache} where logits is the last position's logit vector and cache has the same structure/shapes as the input.
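As a rough illustration of the contract above, here is a minimal sketch of a shape-stable forward. The `ToyForward` module, its single attention-like read, and the tiny shapes are all made up for illustration and are not part of Emily. It runs on Nx's default backend and assumes Nx can be fetched from Hex. The point is only that `offset` enters as a runtime tensor (a dynamic `Nx.put_slice` start index plus a length mask), so the result shapes are the same at every position.

```elixir
# Assumes Nx is available from Hex; Emily itself is not needed for this sketch.
Mix.install([{:nx, "~> 0.9"}])

defmodule ToyForward do
  import Nx.Defn

  # Hypothetical toy model: one attention-like read over a fixed KV window.
  # token: s32 {1}, offset: s32 scalar, cache: %{k: f32 {window, dim}},
  # params: %{embed: f32 {vocab, dim}, out: f32 {dim, vocab}}.
  defn forward(token, offset, cache, params) do
    window = Nx.axis_size(cache.k, 0)

    # Embed the current token: {1, dim}.
    x = Nx.take(params.embed, token)

    # Write into row `offset` of the fixed-size buffer. The start index is a
    # runtime tensor, so the result shape never depends on its value.
    k = Nx.put_slice(cache.k, [offset, 0], x)

    # Length mask: only rows 0..offset hold real entries.
    valid = Nx.iota({window}) <= offset

    # Masked softmax over the window.
    scores = Nx.dot(k, Nx.squeeze(x, axes: [0]))
    masked = Nx.select(valid, scores, -1.0e9)
    weights = Nx.exp(masked - Nx.reduce_max(masked))
    weights = weights / Nx.sum(weights)

    # Weighted read of the cache, then project to vocabulary logits: {vocab}.
    logits = weights |> Nx.dot(k) |> Nx.dot(params.out)

    {logits, %{k: k}}
  end
end

params = %{embed: Nx.iota({4, 2}, type: :f32), out: Nx.iota({2, 4}, type: :f32)}
cache0 = %{k: Nx.broadcast(0.0, {8, 2})}

# Two steps at different offsets; argument and result shapes are identical.
{_logits1, cache1} =
  ToyForward.forward(Nx.tensor([1], type: :s32), Nx.tensor(0, type: :s32), cache0, params)

{logits2, cache2} =
  ToyForward.forward(Nx.tensor([2], type: :s32), Nx.tensor(1, type: :s32), cache1, params)

IO.inspect({Nx.shape(logits2), Nx.shape(cache2.k)})
```

A growing slice such as `cache.k[0..offset]` would make the shape depend on the value of `offset`, which is what the contract rules out; the mask keeps the unused rows from contributing instead.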

The driver does not bound offset against the cache window. Sizing offset + max_new_tokens to fit the fixed KV buffer is the caller's responsibility: overflowing the buffer does not raise; the out-of-bounds put_slice silently corrupts the cache.
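Since the driver performs no such check, a caller may want to guard the call site. The following is a minimal sketch of one way to do that in plain Elixir. `CacheBudget`, `check!/3`, and `clamp/3` are hypothetical names, not part of Emily, and `window` stands for the number of positions in the caller's own KV buffer.

```elixir
defmodule CacheBudget do
  # Hypothetical helper, not part of Emily. `window` is the number of
  # positions in the caller's fixed KV buffer.

  # Raises unless offset + max_new_tokens fits inside the buffer.
  def check!(offset, max_new_tokens, window)
      when is_integer(offset) and offset >= 0 and
             is_integer(max_new_tokens) and max_new_tokens >= 0 and
             is_integer(window) and window > 0 do
    if offset + max_new_tokens > window do
      raise ArgumentError,
            "offset (#{offset}) + max_new_tokens (#{max_new_tokens}) " <>
              "exceeds the KV window (#{window})"
    end

    :ok
  end

  # Largest token budget that still fits, capped at the requested amount
  # and never negative.
  def clamp(offset, requested, window) do
    requested |> min(window - offset) |> max(0)
  end
end

# A 16-slot window with a 10-token prompt leaves room for 6 new tokens.
:ok = CacheBudget.check!(10, 6, 16)
budget = CacheBudget.clamp(10, 64, 16)
IO.inspect(budget)
```

Either variant would run before the call to `stream/2`: `check!/3` fails loudly, while `clamp/3` quietly shortens the generation to what the buffer can hold.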

params is deliberately a required argument rather than a closed-over variable: Nx rejects mixing closed-over Emily.Backend tensors with the traced Nx.Defn.Expr, and passing the weights as an argument also hands their refs to the compiled program zero-copy, captured once.

Example

# `forward`, `cache0`, and `params` come from your model.
tokens =
  Emily.Generation.stream(forward,
    cache: cache0,
    params: params,
    first_token: bos_id,
    offset: prompt_len,
    max_new_tokens: 64,
    eos: [eos_id],
    on_token: fn id -> send(self(), {:token, id}) end
  )

Returns the list of generated token ids (including the stop token, if one is hit). :select defaults to greedy argmax; pass your own fn logits -> token_tensor end for sampling.

Summary

Functions

greedy(logits)

Greedy next-token selector: argmax over the vocabulary (last) axis.

stream(forward, opts)

Drive an autoregressive decode loop over a shape-stable forward.

Functions

greedy(logits)

@spec greedy(Nx.Tensor.t()) :: Nx.Tensor.t()

Greedy next-token selector: argmax over the vocabulary (last) axis.

stream(forward, opts)

Drive an autoregressive decode loop over a shape-stable forward.

Options

  • :cache (required) — the initial (pre-filled) KV-cache container.
  • :params (required) — the model weights container, passed to the forward each step.
  • :first_token (required) — the first token id to decode.
  • :max_new_tokens (required) — the maximum number of tokens to emit.
  • :offset — the starting absolute position (default 0).
  • :eos — a stop token id or list of ids (default []).
  • :select — fn logits -> token_tensor end (default greedy/1).
  • :on_token — fn token_id -> any end, called with each generated id as it is produced (default no-op).
  • :defn_options — options for Nx.Defn.compile/3 (default [compiler: Emily.Compiler, native: true]). Override to disable native compilation or pick a different compiler.
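As a hedged illustration of the :defn_options override, the fragment below swaps in Nx's built-in `Nx.Defn.Evaluator`. It is not standalone: `forward`, `cache0`, `params`, and `bos_id` are placeholders for values from your model, and whether the evaluator path suits a given model is an assumption to verify rather than something this page states.

```elixir
# Fragment only: `forward`, `cache0`, `params`, and `bos_id` come from your model.
tokens =
  Emily.Generation.stream(forward,
    cache: cache0,
    params: params,
    first_token: bos_id,
    max_new_tokens: 16,
    # Use Nx's built-in evaluator instead of the default Emily.Compiler.
    defn_options: [compiler: Nx.Defn.Evaluator]
  )
```

This kind of override can be handy when comparing natively compiled output against a reference run while debugging a forward.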

Returns the list of generated token ids.
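To make the loop semantics concrete, here is a simplified model of a decode loop in plain Elixir, with integers and lists standing in for tensors. It is a sketch under assumptions, not Emily's implementation: `DecodeLoopSketch` is a hypothetical module, and the exact ordering of the callback and the stop check, as well as the one-position advance of the offset per step, are guesses consistent with the description on this page (the stop token is included in the result, and at most max_new_tokens ids are emitted).

```elixir
defmodule DecodeLoopSketch do
  # Hypothetical stand-in for a decode-loop driver; not Emily's source.
  def run(forward, opts) do
    eos = List.wrap(Keyword.get(opts, :eos, []))
    on_token = Keyword.get(opts, :on_token, fn _id -> :ok end)
    select = Keyword.get(opts, :select, &argmax/1)

    state = %{
      token: Keyword.fetch!(opts, :first_token),
      offset: Keyword.get(opts, :offset, 0),
      cache: Keyword.fetch!(opts, :cache),
      out: []
    }

    params = Keyword.fetch!(opts, :params)
    budget = Keyword.fetch!(opts, :max_new_tokens)
    loop(forward, params, select, on_token, eos, budget, state)
  end

  # Budget exhausted: return what was emitted, oldest first.
  defp loop(_f, _params, _select, _on_token, _eos, 0, state), do: Enum.reverse(state.out)

  defp loop(f, params, select, on_token, eos, remaining, state) do
    # One step: run the forward, pick the next id, stream it to the callback.
    {logits, cache} = f.(state.token, state.offset, state.cache, params)
    next = select.(logits)
    on_token.(next)

    # Thread the new cache and advance the absolute position by one (assumed).
    state = %{state | token: next, offset: state.offset + 1, cache: cache, out: [next | state.out]}

    if next in eos do
      # The stop token itself is part of the result.
      Enum.reverse(state.out)
    else
      loop(f, params, select, on_token, eos, remaining - 1, state)
    end
  end

  # Greedy selection over a plain list of logits: index of the largest value.
  def argmax(logits) do
    logits |> Enum.with_index() |> Enum.max_by(fn {value, _index} -> value end) |> elem(1)
  end
end

# Toy forward: always predicts (token + 1) mod vocab and records the token
# it saw at each offset. Here `params` is just the vocabulary size.
forward = fn token, offset, cache, vocab ->
  target = rem(token + 1, vocab)
  logits = for i <- 0..(vocab - 1), do: if(i == target, do: 1.0, else: 0.0)
  {logits, Map.put(cache, offset, token)}
end

tokens =
  DecodeLoopSketch.run(forward,
    cache: %{},
    params: 5,
    first_token: 0,
    offset: 3,
    max_new_tokens: 10,
    eos: 3,
    on_token: fn id -> IO.puts("token #{id}") end
  )

IO.inspect(tokens)
```

In this toy run the loop stops on the stop id well before the budget of 10 is used; without an :eos option it would stop only when the budget runs out.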