LlamaCppEx.MTP (LlamaCppEx v0.8.10)


Multi-Token Prediction (MTP) speculative decoding.

Drives a target/draft speculative loop where the draft model is the MTP head embedded in the same GGUF as the target. On Qwen 3.6 with n_draft: 3 this typically yields ~2x token-generation throughput at ~75% draft acceptance.
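The accept/reject step at the heart of such a loop can be pictured with greedy-match verification, a common rule in speculative decoding. Draft tokens are accepted while they match what the target model would have produced, and the target's own token at the first mismatch (or after the last draft) is always emitted. This is a conceptual sketch, not LlamaCppEx's or llama.cpp's code. The module name `SpecVerifySketch` is hypothetical, and the target's predictions are passed in as a plain list; a real loop obtains them from one batched target decode over the draft.

```elixir
defmodule SpecVerifySketch do
  # Hypothetical illustration of greedy-match speculative verification.

  @doc """
  `draft` holds the tokens proposed by the draft head. `target` holds the
  target model's choice at each position and is one element longer than
  `draft` (the extra entry is the token that follows a fully accepted draft).
  Returns {number_of_accepted_drafts, tokens_to_emit}.
  """
  def verify(draft, target) when length(target) == length(draft) + 1 do
    accepted =
      draft
      |> Enum.zip(target)
      # Accept drafts only while they agree with the target.
      |> Enum.take_while(fn {d, t} -> d == t end)
      |> Enum.map(fn {d, _t} -> d end)

    n = length(accepted)
    # The target's token at position n is always kept, so every
    # iteration emits at least one token.
    {n, accepted ++ [Enum.at(target, n)]}
  end
end
```

For example, `SpecVerifySketch.verify([1, 2, 3], [1, 2, 9, 4])` accepts two drafts and emits `[1, 2, 9]`. In this greedy setting the output is the same as what the target alone would have produced; a poor draft costs speed, not correctness.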

Usage

```elixir
:ok = LlamaCppEx.init()

# The GGUF must include the MTP head layers.
{:ok, model} = LlamaCppEx.load_model("Qwen3.6-35B-A3B-MTP-Q4_K_M.gguf",
                                      n_gpu_layers: 999)

{:ok, mtp} = LlamaCppEx.MTP.init(model, n_draft: 3, n_ctx: 8192)

# Print each text piece as soon as it is accepted.
mtp
|> LlamaCppEx.MTP.stream("Write a haiku about the sea:", max_tokens: 200)
|> Stream.each(&IO.write/1)
|> Stream.run()

stats = LlamaCppEx.MTP.stats(mtp)
IO.puts("acceptance: #{Float.round(stats.acceptance_rate * 100, 1)}%")
```

The model GGUF must contain MTP head layers (e.g. ggml-org/Qwen3.6-35B-A3B-MTP-GGUF). Calling init/2 with a quant that lacks them returns {:error, _}, propagated from common_speculative_init.

Upstream llama.cpp currently requires n_parallel = 1 for MTP, and this module follows that restriction: a single MTP session decodes one sequence at a time. Reuse the same %MTP{} value across calls to stream/3 or generate/3 to avoid rebuilding the contexts; the KV caches are cleared on each call.
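If several processes need to share one session, one option, assuming you want to keep a single set of contexts loaded, is to put the %MTP{} value behind a GenServer so requests are handled one at a time. `MTPSessionServer` is a hypothetical module, not part of LlamaCppEx. The generate function is passed in (normally `&LlamaCppEx.MTP.generate/3`) so the sketch can be exercised without a model.

```elixir
defmodule MTPSessionServer do
  # Hypothetical wrapper, not part of LlamaCppEx: serializes access to one
  # MTP session so concurrent callers take turns instead of sharing contexts.
  use GenServer

  # `generate_fun` has the shape of LlamaCppEx.MTP.generate/3:
  # (session, prompt, opts) -> {:ok, text} | {:error, reason}.
  def start_link(session, generate_fun) when is_function(generate_fun, 3) do
    GenServer.start_link(__MODULE__, {session, generate_fun})
  end

  # :infinity because a long generation can outlast GenServer.call/3's
  # 5-second default timeout.
  def generate(server, prompt, opts \\ [], timeout \\ :infinity) do
    GenServer.call(server, {:generate, prompt, opts}, timeout)
  end

  @impl true
  def init(state), do: {:ok, state}

  @impl true
  def handle_call({:generate, prompt, opts}, _from, {session, fun} = state) do
    # Requests are processed one at a time; the rest wait in the mailbox.
    {:reply, fun.(session, prompt, opts), state}
  end
end
```

With a real session this would be started as `MTPSessionServer.start_link(mtp, &LlamaCppEx.MTP.generate/3)` and then called with `MTPSessionServer.generate(server, prompt, max_tokens: 200)` from any process. Choose a finite timeout if callers must not block indefinitely.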

Summary

Functions

Synchronously generates text. Equivalent to running stream/3 and joining the pieces into a single binary.

Initializes an MTP speculative session: builds the target context, the MTP draft context (ctx_type: :mtp), and the underlying common_speculative state.

Writes upstream's own speculative stats summary to stdout (via llama.cpp logging). Useful when cross-checking acceptance rates against the upstream llama-server benchmark output.

Returns the current MTP statistics snapshot (lock-free read of atomic counters). Safe to call at any time — including from another process while a stream is in flight.

Returns a lazy stream of generated text pieces.

Like stream/3, but yields the raw event tuples emitted by the NIF.

Types

t()

@type t() :: %LlamaCppEx.MTP{
  main_ctx: LlamaCppEx.Context.t(),
  mtp_ctx: LlamaCppEx.Context.t(),
  n_draft: pos_integer(),
  spec_ref: reference()
}

Functions

generate(mtp, prompt, opts \\ [])

@spec generate(t(), String.t(), keyword()) :: {:ok, String.t()} | {:error, term()}

Synchronously generates text. Equivalent to running stream/3 and joining the pieces into a single binary.

Accepts the same options as stream/3.

init(model, opts \\ [])

@spec init(
  LlamaCppEx.Model.t(),
  keyword()
) :: {:ok, t()} | {:error, term()}

Initializes an MTP speculative session: builds the target context, the MTP draft context (ctx_type: :mtp), and the underlying common_speculative state.

Options

  • :n_draft - Maximum number of draft tokens proposed per iteration. Defaults to 3. Larger values can save more target-model passes when drafts are accepted, but the acceptance rate tends to fall as drafts get longer; 2–4 is the sweet spot in practice.
  • :n_ctx - Context size for both contexts. Defaults to 2048.
  • Any LlamaCppEx.Context option (e.g. :n_threads, :flash_attn, :type_k/:type_v, :offload_kqv). The same options are applied to both the target and draft contexts.

Returns {:ok, %MTP{}} or {:error, reason}.
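The option routing described above can be pictured as follows: :n_draft and :n_ctx are taken out with their documented defaults, and everything that remains goes to both contexts, with only the draft context marked ctx_type: :mtp. This is a hypothetical sketch (`MTPOptsSketch` is not a LlamaCppEx module) meant to show which options end up where, not the library's actual code.

```elixir
defmodule MTPOptsSketch do
  # Hypothetical illustration; only the defaults come from the docs.

  @doc "Returns {n_draft, target_ctx_opts, draft_ctx_opts}."
  def split(opts) do
    {n_draft, rest} = Keyword.pop(opts, :n_draft, 3)
    {n_ctx, ctx_opts} = Keyword.pop(rest, :n_ctx, 2048)

    # Both contexts share n_ctx and every remaining Context option
    # (e.g. :n_threads, :flash_attn); only the draft is tagged :mtp.
    target = Keyword.put(ctx_opts, :n_ctx, n_ctx)
    draft = Keyword.put(target, :ctx_type, :mtp)
    {n_draft, target, draft}
  end
end
```

For instance, `MTPOptsSketch.split(n_threads: 8)` keeps n_draft at 3 and gives both contexts n_ctx: 2048 and n_threads: 8.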

stats(mtp)

@spec stats(t()) :: map()

Returns the current MTP statistics snapshot (lock-free read of atomic counters). Safe to call at any time — including from another process while a stream is in flight.

Returns a map with keys:

  • :iters - speculative loop iterations completed
  • :drafts_generated - draft tokens proposed by the MTP head
  • :drafts_accepted - draft tokens accepted by the target model
  • :acceptance_rate - drafts_accepted / drafts_generated (0.0–1.0)
  • :tokens_emitted - tokens streamed back to the caller
  • :tokens_per_sec - throughput over the active generation window
  • :timing_us - %{draft: μs, verify: μs, sample: μs, total: μs}
  • :n_draft - max draft length configured at init

Counters are cumulative across all stream/3 and generate/3 calls on this MTP value.
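Because the snapshot is a plain map, derived metrics are easy to compute on the caller side. The sketch below recomputes :acceptance_rate from the raw counters (guarding the zero-draft case) and adds the average number of tokens emitted per iteration. `MTPStatsSketch` is hypothetical; only the key names come from the list above.

```elixir
defmodule MTPStatsSketch do
  # Hypothetical helpers over the map returned by LlamaCppEx.MTP.stats/1.

  # Recomputes drafts_accepted / drafts_generated, returning 0.0
  # before any drafts have been proposed.
  def acceptance_rate(%{drafts_generated: 0}), do: 0.0
  def acceptance_rate(%{drafts_accepted: a, drafts_generated: g}), do: a / g

  # Average tokens emitted per speculative iteration. Assuming one target
  # verification pass per iteration, values above 1.0 mean drafting is
  # saving target passes.
  def tokens_per_iter(%{iters: 0}), do: 0.0
  def tokens_per_iter(%{tokens_emitted: t, iters: i}), do: t / i

  def summary(stats) do
    pct = Float.round(acceptance_rate(stats) * 100, 1)
    per_iter = Float.round(tokens_per_iter(stats), 2)
    "acceptance #{pct}%, #{per_iter} tokens/iter"
  end
end
```

With made-up counters of 10 iterations, 30 drafts generated, 24 accepted and 34 tokens emitted, `MTPStatsSketch.summary/1` returns `"acceptance 80.0%, 3.4 tokens/iter"`.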

stream(mtp, prompt, opts \\ [])

@spec stream(t(), String.t(), keyword()) :: Enumerable.t()

Returns a lazy stream of generated text pieces.

Options

  • :max_tokens - Maximum tokens to generate (default 256).
  • :emit_stats_every - When > 0, also emits a {:stats, snapshot_map} event every N tokens (N being this value) on the underlying message stream. These events are filtered out of this String.t() stream; use stream_events/3 to consume them. Default 0 (off).
  • :timeout - Receive timeout in milliseconds (default 60_000).
  • Any sampler option from LlamaCppEx.Sampler.create/2 (:temp, :top_k, :top_p, :min_p, :seed, :penalty_*, :grammar, etc.).

Each emitted element is the text piece for one accepted token. The stream ends at end-of-generation, when :max_tokens is reached, or on error.
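The relationship between this stream and stream_events/3 can be sketched as a lazy filter over event tuples: keep the text of each {:token, ...} event, drop {:stats, ...} snapshots, and stop at the first terminal event. `MTPStreamSketch` is hypothetical, not the library's implementation, and the {:error, reason} shape is an assumption (only the :error tag is documented).

```elixir
defmodule MTPStreamSketch do
  # Hypothetical illustration; not LlamaCppEx's implementation.

  @terminal [:done, :eog, :error]

  @doc "Turns stream_events/3-style tuples into a lazy stream of text pieces."
  def text_pieces(events) do
    Stream.transform(events, nil, fn
      # Accepted token: emit its text piece.
      {:token, _id, piece}, acc -> {[piece], acc}
      # Periodic stats snapshot: drop it from the text stream.
      {:stats, _snapshot}, acc -> {[], acc}
      # Terminal event ({:error, reason} shape assumed): end the stream.
      {tag, _payload}, acc when tag in @terminal -> {:halt, acc}
    end)
  end
end
```

Since Stream.transform/3 is lazy, nothing is consumed until the result is enumerated, which mirrors the lazy behaviour of stream/3.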

stream_events(mtp, prompt, opts \\ [])

@spec stream_events(t(), String.t(), keyword()) :: Enumerable.t()

Like stream/3, but yields the raw event tuples emitted by the NIF:

  • {:token, token_id, text_piece} - one accepted token
  • {:stats, snapshot_map} - periodic stats (only when :emit_stats_every > 0)
  • {:done, final_stats_map} - generation completed normally
  • {:eog, nil} - model emitted an end-of-generation token

The stream halts after a :done, :eog, or :error event. The final stats map remains available via stats/1 on the MTP struct after the stream ends.
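A consumer that needs both the generated text and the periodic snapshots can fold over the events with Enum.reduce_while/3, stopping at the terminal tuple. `MTPEventsSketch` is hypothetical, and it assumes errors arrive as {:error, reason}; the other shapes are the ones listed above.

```elixir
defmodule MTPEventsSketch do
  # Hypothetical consumer for stream_events/3-style tuples; not part of
  # LlamaCppEx. Assumes errors arrive as {:error, reason}.

  @doc """
  Folds `events` into {:ok, text, outcome} or {:error, reason, partial_text}.
  `on_stats` is called with each {:stats, snapshot} payload.
  """
  def collect(events, on_stats \\ fn _snapshot -> :ok end) do
    result =
      Enum.reduce_while(events, [], fn
        {:token, _id, piece}, acc ->
          {:cont, [piece | acc]}

        {:stats, snapshot}, acc ->
          on_stats.(snapshot)
          {:cont, acc}

        {:done, _final_stats}, acc ->
          {:halt, {:ok, finish(acc), :done}}

        {:eog, nil}, acc ->
          {:halt, {:ok, finish(acc), :eog}}

        {:error, reason}, acc ->
          {:halt, {:error, reason, finish(acc)}}
      end)

    case result do
      # Input ran out without a terminal event (e.g. a hand-built list).
      pieces when is_list(pieces) -> {:ok, finish(pieces), :ended}
      tagged -> tagged
    end
  end

  # Pieces were prepended, so reverse before joining.
  defp finish(pieces), do: pieces |> Enum.reverse() |> IO.iodata_to_binary()
end
```

Against a live session this would be called as `MTPEventsSketch.collect(LlamaCppEx.MTP.stream_events(mtp, prompt, emit_stats_every: 16), &IO.inspect/1)`, assuming :emit_stats_every is passed the same way as for stream/3.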