Emily.Quantization (emily v0.5.1)


Quantized inference primitives.

Public API

  • quantized_matmul/2: eager-mode fused kernel over a materialized %Emily.QuantizedWeight{} and an Nx.Tensor. It calls the MLX quantized_matmul C++ kernel directly, and its result matches Nx.dot(x, to_dense(qw) |> Nx.transpose()) within quantization tolerance.
  • dequantize_defn/1: defn-native analogue of Emily.QuantizedWeight.to_dense/1, composed from Nx.right_shift / Nx.bitwise_and / multiply / add. Use it inside Nx.Defn.jit-traced Axon forward passes where a fused quantized_matmul node isn't available. Nx.dot(x, dequantize_defn(qw) |> Nx.transpose()) then runs in two kernels (dequantize, then dense matmul) instead of one.
  • defn_supported_bits/0: enumerates the bit widths the defn-native path supports ([2, 3, 4, 6, 8]).

See Emily.QuantizedWeight for the container struct and Emily.QuantizedWeight.from_dense/2 for building one.

Summary

Functions

defn_supported_bits()

Bit widths supported by dequantize_defn/1 (and therefore by Emily.Quantization.Layers.quantized_dense/4 and any Axon graph rewrite that wires it in).

dequantize_defn(qw)

Reconstruct a dense tensor from a QuantizedWeight, built entirely from Nx primitives so it composes inside defn traces.

quantized_matmul(x, qw)

Compute x @ W^T where W is represented as a QuantizedWeight.

Functions

defn_supported_bits()

@spec defn_supported_bits() :: [pos_integer()]

Bit widths supported by dequantize_defn/1 (and therefore by Emily.Quantization.Layers.quantized_dense/4 and any Axon graph rewrite that wires it in).

bits ∈ {3, 6} use cross-u32 lane packing, meaning a lane can straddle two u32 words. These widths therefore take a more involved unpacking path than the widths that fit a whole number of lanes in each u32 (2, 4, 8).
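To make the distinction concrete, here is a small pure-Elixir sketch. PackingSketch and its functions are hypothetical illustrations, not part of Emily. It classifies each supported width by whether 32 is divisible by bits, and checks whether a given lane straddles a u32 boundary:

```elixir
defmodule PackingSketch do
  # Hypothetical helper, not Emily API: returns {:integral, lanes_per_u32}
  # when a whole number of lanes fits in one u32, or :cross_u32 otherwise.
  def path(bits) when bits in [2, 3, 4, 6, 8] do
    if rem(32, bits) == 0, do: {:integral, div(32, bits)}, else: :cross_u32
  end

  # True when lane i (starting at bit offset i * bits) does not fit
  # inside a single u32 word.
  def straddles?(i, bits), do: rem(i * bits, 32) + bits > 32
end
```

For example, with bits = 3, lane 10 starts at bit 30 and ends at bit 32, so it needs bits from two words. No lane ever crosses a word boundary for 2, 4, or 8.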

Examples

iex> Emily.Quantization.defn_supported_bits()
[2, 3, 4, 6, 8]

dequantize_defn(qw)

@spec dequantize_defn(Emily.QuantizedWeight.t()) :: Nx.Tensor.t()

Reconstruct a dense tensor from a QuantizedWeight, built entirely from Nx primitives so it composes inside defn traces.

This is the defn-compatible analogue of QuantizedWeight.to_dense/1. The math is identical to MLX's dequantize. Lane i is extracted from the packed u32 stream at bit offset i * bits and masked with 2^bits - 1, which keeps only its low-order bits. In "affine" mode the element is then w[i] = lane * scales[g] + biases[g], where g = div(i, group_size) is the group index along the last axis.
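The affine formula above can be checked against a minimal pure-Elixir reference. This sketch assumes the packed weight is a plain list of u32 integers with lane 0 in the lowest bits of word 0. The AffineDequantSketch module is hypothetical, not Emily API. Because Elixir integers are arbitrary-precision, the whole u32 stream can be folded into one integer. "Bit offset i * bits" then becomes a single shift, whether or not the lane crosses a word boundary:

```elixir
defmodule AffineDequantSketch do
  import Bitwise

  # Fold the u32 words into one arbitrary-precision integer, word k
  # occupying bits 32k..32k+31, so lane i is a plain shift away.
  defp to_stream(words) do
    words
    |> Enum.with_index()
    |> Enum.reduce(0, fn {w, k}, acc -> acc ||| (w <<< (32 * k)) end)
  end

  # Hypothetical reference: returns n floats with
  # w[i] = lane_i * scales[g] + biases[g], g = div(i, group_size),
  # where lane_i is the bits-wide field at bit offset i * bits.
  def dequantize(words, n, bits, group_size, scales, biases) do
    stream = to_stream(words)
    mask = (1 <<< bits) - 1

    for i <- 0..(n - 1)//1 do
      lane = (stream >>> (i * bits)) &&& mask
      g = div(i, group_size)
      lane * Enum.at(scales, g) + Enum.at(biases, g)
    end
  end
end
```

For example, the 4-bit word 0x76543210 holds lanes 0 through 7. With group_size 4, scales [1.0, 2.0] and biases [0.0, -1.0], dequantize([0x76543210], 8, 4, 4, [1.0, 2.0], [0.0, -1.0]) returns [0.0, 1.0, 2.0, 3.0, 7.0, 9.0, 11.0, 13.0].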

Supported: bits ∈ [2, 3, 4, 6, 8]. One of two unpack paths is picked at trace time:

  • bits ∈ {2, 4, 8} — integral lanes per u32, broadcast-shift w[..., :, lane] = (w_q[..., :] >> (lane * bits)) & mask.
  • bits ∈ {3, 6} — lanes cross u32 boundaries, so we read adjacent u32 pairs as a u64, then shift by rem(i * bits, 32) and mask.
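The two unpack paths can be sketched on plain integer lists. This is a hypothetical illustration: UnpackPathsSketch is not Emily API, and it assumes lane 0 sits in the lowest bits of word 0. The defn path expresses the same shifts and masks as broadcast Nx.right_shift / Nx.bitwise_and operations instead of list comprehensions:

```elixir
defmodule UnpackPathsSketch do
  import Bitwise

  # bits in {2, 4, 8}: exactly 32 / bits lanes fit in each u32, so lane j
  # of word w is (w >>> (j * bits)) &&& mask. Returns every lane, in order.
  def unpack_integral(words, bits) when bits in [2, 4, 8] do
    mask = (1 <<< bits) - 1
    lanes_per_word = div(32, bits)

    for w <- words, j <- 0..(lanes_per_word - 1) do
      (w >>> (j * bits)) &&& mask
    end
  end

  # bits in {3, 6}: a lane can straddle two u32s, so read words k and k + 1
  # as one 64-bit value (low word first), shift by rem(i * bits, 32), and
  # mask. Returns the first n lanes.
  def unpack_cross(words, bits, n) when bits in [3, 6] do
    mask = (1 <<< bits) - 1

    for i <- 0..(n - 1)//1 do
      offset = i * bits
      k = div(offset, 32)
      lo = Enum.at(words, k)
      # The last word has no successor; treat the missing high half as 0.
      hi = Enum.at(words, k + 1, 0)
      pair = (hi <<< 32) ||| lo
      (pair >>> rem(offset, 32)) &&& mask
    end
  end
end
```

For example, unpack_integral([0x76543210], 4) returns [0, 1, 2, 3, 4, 5, 6, 7]. With bits = 3, unpack_cross/3 recovers lane 10 from bits 30 and 31 of word 0 plus bit 0 of word 1.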

Supported modes: "affine", "mxfp4", "mxfp8", and "nvfp4". Every MLX QuantizationMode value runs through the defn-native path. For the floating-point modes:

  • Lane decode: a 16-entry FP4-E2M1 LUT for mxfp4 and nvfp4, and a 256-entry FP8-E4M3 LUT (matching MLX's FromFP8 bit-trick) for mxfp8.
  • Scale decode: a 256-entry FP8-E8M0 LUT (2^(s - 127)) for mxfp4 and mxfp8, and the same FP8-E4M3 LUT for nvfp4's finer-grained per-group scales.

Output dtype is bf16 to match QuantizedWeight.to_dense/1.
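The lookup tables are small enough to tabulate directly from the bit layouts of the formats. In the sketch below, MicroscalingSketch is a hypothetical module, not Emily API. It decodes FP8-E4M3 with a plain exponent/mantissa formula rather than MLX's FromFP8 bit-trick. It also assumes that the floating-point modes apply a purely multiplicative scale with no bias, as in the OCP microscaling formats:

```elixir
defmodule MicroscalingSketch do
  import Bitwise

  # FP4 E2M1: sign bit 3, exponent bits 2..1 (bias 1), mantissa bit 0.
  # Exponent 0 is subnormal (value = mantissa * 0.5), giving magnitudes
  # 0, 0.5, 1, 1.5, 2, 3, 4, 6.
  def fp4_e2m1(code) when code in 0..15 do
    sign = if (code &&& 0b1000) != 0, do: -1.0, else: 1.0
    exp = (code >>> 1) &&& 0b11
    man = code &&& 0b1

    mag =
      if exp == 0 do
        man * 0.5
      else
        :math.pow(2, exp - 1) * (1 + man / 2)
      end

    sign * mag
  end

  # FP8 E4M3 (no infinities): sign bit 7, exponent bits 6..3 (bias 7),
  # mantissa bits 2..0. Exponent 0 is subnormal; 0x7F and 0xFF are NaN.
  def fp8_e4m3(code) when code in 0..255 do
    sign = if (code &&& 0x80) != 0, do: -1.0, else: 1.0
    exp = (code >>> 3) &&& 0xF
    man = code &&& 0x7

    cond do
      exp == 0xF and man == 0x7 -> :nan
      exp == 0 -> sign * :math.pow(2, -6) * (man / 8)
      true -> sign * :math.pow(2, exp - 7) * (1 + man / 8)
    end
  end

  # FP8 E8M0 scale: exponent only, value 2^(s - 127). The OCP MX spec
  # reserves 0xFF as NaN; this sketch just applies the formula.
  def e8m0(s) when s in 0..255, do: :math.pow(2, s - 127)

  # The LUTs are these decoders tabulated over every code.
  def fp4_lut, do: Enum.map(0..15, &fp4_e2m1/1)
  def e8m0_lut, do: Enum.map(0..255, &e8m0/1)

  # mxfp4 element: FP4 lane times its group's E8M0 scale (assumed: no bias).
  def mxfp4_value(code, scale_code), do: fp4_e2m1(code) * e8m0(scale_code)
end
```

For instance, FP4 code 0b0111 decodes to 6.0 and E8M0 scale 126 decodes to 0.5, so mxfp4_value(7, 126) is 3.0.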

Examples

iex> w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w, group_size: 64, bits: 4)
iex> dense = Emily.Quantization.dequantize_defn(qw)
iex> Nx.shape(dense)
{4, 64}

quantized_matmul(x, qw)

@spec quantized_matmul(Nx.Tensor.t(), Emily.QuantizedWeight.t()) :: Nx.Tensor.t()

Compute x @ W^T where W is represented as a QuantizedWeight.

With qw.transpose == true (the default from QuantizedWeight.from_dense/2), the result matches Nx.dot(x, QuantizedWeight.to_dense(qw) |> Nx.transpose()) within MLX's quantization tolerance. That expression is a dense-kernel dot against the dequantized, transposed weight. With transpose == false, MLX treats the packed layout as already transposed (the AWQ convention) and computes x @ W directly.
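The layout contract can be checked against a plain list-based reference. MatmulSketch is hypothetical, not Emily API. The transpose == false case assumes the dequantized weight is stored as {in, out}:

```elixir
defmodule MatmulSketch do
  # transpose == true contract: x is {rows, in}, W is {out, in} (the layout
  # to_dense/1 returns), and y = x @ W^T is {rows, out}.
  def x_wt(x, w) do
    for row <- x do
      for w_row <- w, do: dot(row, w_row)
    end
  end

  # transpose == false (assumed {in, out} storage): y = x @ W directly,
  # which equals x @ (W^T)^T, so reuse x_wt/2 on the transposed weight.
  def x_w(x, w), do: x_wt(x, transpose(w))

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  # Transpose a list-of-rows matrix into a list of columns.
  defp transpose(m), do: Enum.zip_with(m, & &1)
end
```

For example, with x = [[1, 2], [3, 4]] and the {3, 2} weight [[1, 0], [0, 1], [1, 1]], x_wt/2 returns [[1, 2, 3], [3, 4, 7]]. Passing the {2, 3} transpose of that weight to x_w/2 gives the same result.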

Both operands are evaluated on Emily.Backend. Tensors (or scalars) coming from Nx.BinaryBackend are transferred to it automatically. The input dtype must match qw.scales.type (typically f16, bf16, or f32).

Examples

iex> w = Nx.iota({4, 128}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w)
iex> x = Nx.iota({3, 128}, backend: Emily.Backend, type: :f32)
iex> y = Emily.Quantization.quantized_matmul(x, qw)
iex> Nx.shape(y)
{3, 4}