Emily.QuantizedWeight (emily v0.3.3)

Container for a matrix quantized via one of MLX's group-wise quantization schemes ("affine" int4/int8, plus the microscaled variants "mxfp4", "mxfp8", "nvfp4").

A QuantizedWeight bundles the packed integer weights with the per-group scale and bias tensors that are needed to recover (or multiply against) the original dense matrix. It derives Nx.Container so it can flow through defn transforms, parameter maps, and backend_transfer alongside regular tensors — while the group_size, bits, transpose, and mode metadata survive the traversal via Nx.Container's keep: option.
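
Concretely, the derivation described above might look like the following (an illustrative reconstruction from the field and metadata lists in this doc, not the module's literal source):

# Sketch: tensor-valued fields are traversed by Nx.Container; the
# metadata fields ride along unchanged via keep:.
@derive {Nx.Container,
         containers: [:value, :scales, :biases],
         keep: [:group_size, :bits, :transpose, :mode]}
defstruct [:value, :scales, :biases, :group_size, :bits, :transpose, :mode]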

Layout

  • :value — packed integer weights. For bits=4 the packing is 8 nibbles per uint32; the last axis therefore shrinks by a factor of 32 / bits.
  • :scales — per-group scale. For "affine" its dtype matches the source weight's dtype and its shape is (..., last_dim / group_size). Microscaled modes store a fused e8m0/e4m3 scale instead (dtype :u8).
  • :biases — per-group bias for "affine". For microscaled modes MLX's fp_quantize doesn't emit biases; the field holds a scalar-zero placeholder so Nx.Container can still traverse it, and the Native layer substitutes nil before dispatching to MLX.
  • :group_size / :bits — the parameters originally passed to MLX's quantize.
  • :transpose — true if the quantized matmul should treat :value as [out, in] (the default for MLX and for weights fresh from from_dense/2). External checkpoint formats (e.g. AWQ) may need false.
  • :mode — quantization mode string. "affine" (default) is the classical int4/int8 scheme with real biases. The microscaled modes ("mxfp4", "mxfp8", "nvfp4") trade biases for a floating-point scale format and carry fixed group_size / bits combinations (see below).
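
To make these shape rules concrete, here is a small affine walk-through (a sketch; it assumes Emily.Backend is available and follows the packing and per-group descriptions above):

w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
qw = Emily.QuantizedWeight.from_dense(w, bits: 4, group_size: 32)

Nx.shape(qw.value)   # => {4, 8}  (last axis shrinks by 32 / bits = 8)
Nx.shape(qw.scales)  # => {4, 2}  (64 / group_size = 2 groups per row)
Nx.shape(qw.biases)  # => {4, 2}  (one affine bias per group)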

Microscaled modes

MLX's microscaled variants (see QuantizationMode in vendor/mlx/mlx/primitives.h) each require a specific group_size and bits:

Mode  | group_size | bits
------|------------|-----
mxfp4 | 32         | 4
mxfp8 | 32         | 8
nvfp4 | 16         | 4

A mismatch raises a clear error in from_dense/2, before MLX is touched. dequantize_defn/1 only understands the affine format; microscaled modes must round-trip through to_dense/1 (the Native path).
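
For example (a sketch; the error message shown is illustrative, not verbatim):

w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)

# The pinned combination succeeds:
Emily.QuantizedWeight.from_dense(w, mode: "mxfp4", group_size: 32, bits: 4)

# A mismatched group_size raises before any MLX call:
Emily.QuantizedWeight.from_dense(w, mode: "nvfp4", group_size: 32, bits: 4)
# ** (ArgumentError) mode "nvfp4" requires group_size: 16 and bits: 4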

Dispatch

Use Emily.Quantization.quantized_matmul/2 to run a fused quantized matmul against a QuantizedWeight. Nx.dot/2 itself cannot accept a QuantizedWeight operand — Nx traverses containers expecting a single tensor — so the direct-call helper is the supported path.
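
A sketch of the fused path (the (input, quantized_weight) argument order is an assumption about quantized_matmul/2):

x = Nx.iota({2, 64}, backend: Emily.Backend, type: :f32)
w = Nx.iota({8, 64}, backend: Emily.Backend, type: :f32)  # [out, in]
qw = Emily.QuantizedWeight.from_dense(w)                  # transpose: true

y = Emily.Quantization.quantized_matmul(x, qw)
Nx.shape(y)  # => {2, 8} under the default [out, in] layout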

Defn-traced Axon forward passes

Emily.Quantization.dequantize_defn/1 is the defn-native analogue of to_dense/1; pair it with Emily.Quantization.Layers.quantized_dense/4 to splice a quantized linear into any Nx.Defn.jit-traced Axon forward pass. The layer performs Nx.dot(x, dequantize_defn(qw)) instead of the fused quantized_matmul kernel — two dispatches vs one, but fully integrated with the rest of Bumblebee's defn graph. Only mode: "affine" is supported on the defn path today.
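
A minimal defn sketch of that form (assumes mode: "affine" and a dequantized layout whose contraction axis lines up with x's last axis):

defmodule MyForward do
  import Nx.Defn

  # Dequantize inside the traced graph, then run a regular dot --
  # two dispatches instead of the single fused kernel.
  defn forward(x, qw) do
    Nx.dot(x, Emily.Quantization.dequantize_defn(qw))
  end
end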

Summary

Functions

from_dense(w, opts \\ [])

Quantize a dense float tensor into a QuantizedWeight.

to_dense(quantized_weight)

Reconstruct a dense Nx.Tensor from a QuantizedWeight.

Types

mode()

@type mode() :: String.t()

t()

@type t() :: %Emily.QuantizedWeight{
  biases: Nx.Tensor.t(),
  bits: pos_integer(),
  group_size: pos_integer(),
  mode: mode(),
  scales: Nx.Tensor.t(),
  transpose: boolean(),
  value: Nx.Tensor.t()
}

Functions

from_dense(w, opts \\ [])

@spec from_dense(
  Nx.Tensor.t(),
  keyword()
) :: t()

Quantize a dense float tensor into a QuantizedWeight.

The input must live on Emily.Backend (transfer first if you're coming from Nx.BinaryBackend). The last axis must be divisible by :group_size.
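
If the weight starts on the default binary backend, transfer it first (a sketch):

w =
  Nx.iota({4, 64}, type: :f32)           # lives on Nx.BinaryBackend
  |> Nx.backend_transfer(Emily.Backend)  # from_dense/2 requires this

qw = Emily.QuantizedWeight.from_dense(w)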

Options

  • :group_size — default 64. Elements per quantization group. Microscaled modes pin this to a specific value (see Emily.QuantizedWeight moduledoc).
  • :bits — default 4. One of [2, 3, 4, 6, 8] for "affine"; pinned by the mode for microscaled variants.
  • :transpose — default true. Layout flag threaded to the quantized_matmul kernel. Leave as true for weights produced here; set false when wrapping pre-packed external checkpoints.
  • :mode — default "affine". One of ["affine", "mxfp4", "mxfp8", "nvfp4"].

Examples

iex> w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w, bits: 4, group_size: 64)
iex> qw.bits
4
iex> qw.group_size
64
iex> qw.mode
"affine"
iex> Nx.shape(qw.value)
{4, 8}

to_dense(quantized_weight)

@spec to_dense(t()) :: Nx.Tensor.t()

Reconstruct a dense Nx.Tensor from a QuantizedWeight.

Useful for oracle comparisons and for transferring a quantized parameter off Emily.Backend (backend transfer traverses each container tensor individually; most consumers want the dense view).
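
A sketch of the transfer-off pattern (Nx.BinaryBackend as the destination is illustrative):

w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
qw = Emily.QuantizedWeight.from_dense(w)

dense_cpu =
  qw
  |> Emily.QuantizedWeight.to_dense()
  |> Nx.backend_transfer(Nx.BinaryBackend)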

Examples

iex> w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
iex> dense = Emily.QuantizedWeight.from_dense(w) |> Emily.QuantizedWeight.to_dense()
iex> Nx.shape(dense)
{4, 64}