Quantized inference primitives.
Public API
- quantized_matmul/2: eager-mode fused kernel over a materialized %Emily.QuantizedWeight{} and an Nx.Tensor. Calls the MLX quantized_matmul C++ kernel directly and produces Nx.dot(x, to_dense(qw) |> Nx.transpose()) within quantization tolerance.
- dequantize_defn/1: defn-native analogue of Emily.QuantizedWeight.to_dense/1, composed from Nx.right_shift / Nx.bitwise_and / multiply / add. Use it inside Nx.Defn.jit-traced Axon forward passes where a fused quantized_matmul node isn't available; Nx.dot(x, dequantize_defn(qw)) then runs in two kernels (dequantize, then dense matmul) instead of one.
- defn_supported_bits/0: enumerates the bit widths the defn-native path supports ([2, 3, 4, 6, 8]).
See Emily.QuantizedWeight for the container struct and
Emily.QuantizedWeight.from_dense/2 for building one.
Summary
Functions
- defn_supported_bits/0: Bit widths supported by dequantize_defn/1 (and therefore by Emily.Quantization.Layers.quantized_dense/4 and any Axon graph rewrite that wires it in).
- dequantize_defn/1: Reconstruct a dense tensor from a QuantizedWeight, built entirely from Nx primitives so it composes inside defn traces.
- quantized_matmul/2: Compute x @ W^T where W is represented as a QuantizedWeight.
Functions
@spec defn_supported_bits() :: [pos_integer()]
Bit widths supported by dequantize_defn/1 (and therefore by
Emily.Quantization.Layers.quantized_dense/4 and any Axon graph
rewrite that wires it in).
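For bits ∈ {3, 6}, 32 is not a multiple of the bit width, so some lanes straddle a u32 boundary. These widths therefore take a more involved unpacking path than bits ∈ {2, 4, 8}, which pack a whole number of lanes into each u32.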
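A small plain-Elixir sketch of why 3 and 6 are the odd ones out. It assumes lane i occupies stream bits i * bits through i * bits + bits - 1 (the layout described under dequantize_defn/1) and lists the lanes whose first and last bit fall in different u32 words. PackingSketch is a hypothetical helper, not part of Emily.

```elixir
# Hypothetical helper module (not part of Emily's API).
defmodule PackingSketch do
  # A bit width packs a whole number of lanes per u32 iff it divides 32.
  def integral_lanes?(bits), do: rem(32, bits) == 0

  # Lane i occupies stream bits i * bits .. i * bits + bits - 1; it straddles
  # a u32 boundary when its first and last bit land in different words.
  def straddling_lanes(bits, n_lanes) do
    for i <- 0..(n_lanes - 1),
        div(i * bits, 32) != div(i * bits + bits - 1, 32),
        do: i
  end
end

for bits <- [2, 3, 4, 6, 8] do
  IO.puts("bits=#{bits} integral=#{PackingSketch.integral_lanes?(bits)} " <>
            "straddling=#{inspect(PackingSketch.straddling_lanes(bits, 32))}")
end
# bits=3 -> [10, 21]; bits=6 -> [5, 10, 21, 26]; bits 2, 4, 8 -> []
```

With 32 lanes, every supported width fills exactly `bits` u32 words, and only 3 and 6 produce straddling lanes.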
Examples
iex> Emily.Quantization.defn_supported_bits()
[2, 3, 4, 6, 8]
@spec dequantize_defn(Emily.QuantizedWeight.t()) :: Nx.Tensor.t()
Reconstruct a dense tensor from a QuantizedWeight, built entirely
from Nx primitives so it composes inside defn traces.
This is the defn-compatible analogue of QuantizedWeight.to_dense/1.
The math is identical to MLX's dequantize: lane i is extracted from the packed u32 stream at bit offset i * bits and masked with mask = 2^bits - 1. In "affine" mode, w[i] = lane * scales[g] + biases[g], where g = div(i, group_size) is the group index along the last axis.
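A minimal plain-Elixir sketch of the affine formula for a single row, under one stated assumption: the u32 words are concatenated low word first, so word k supplies stream bits 32k..32k+31 (consistent with the u64-pair read described below). It operates on integer lists rather than Nx tensors and is not Emily's implementation; AffineDequantSketch is a hypothetical name.

```elixir
defmodule AffineDequantSketch do
  import Bitwise

  # Concatenate the u32 words into one integer, low word first (assumed layout).
  def stream(words) do
    words
    |> Enum.with_index()
    |> Enum.reduce(0, fn {w, k}, acc -> acc ||| (w <<< (32 * k)) end)
  end

  # lane[i] = (stream >> (i * bits)) & (2^bits - 1)
  def lanes(words, bits, n_lanes) do
    s = stream(words)
    mask = (1 <<< bits) - 1
    for i <- 0..(n_lanes - 1), do: (s >>> (i * bits)) &&& mask
  end

  # w[i] = lane[i] * scales[g] + biases[g], with g = div(i, group_size)
  def dequantize(words, scales, biases, bits, group_size) do
    n_lanes = div(32 * length(words), bits)

    words
    |> lanes(bits, n_lanes)
    |> Enum.with_index()
    |> Enum.map(fn {lane, i} ->
      g = div(i, group_size)
      lane * Enum.at(scales, g) + Enum.at(biases, g)
    end)
  end
end

# One u32 holding the 4-bit lanes 0..7 (lane 0 in the lowest nibble).
AffineDequantSketch.dequantize([0x76543210], [2.0], [1.0], 4, 8)
# => [1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0]
```

Passing group_size: 4 with two scale/bias pairs instead shows the group index g switching halfway through the row.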
Supported: bits ∈ [2, 3, 4, 6, 8]. One of two unpack paths is picked at trace time:
- bits ∈ {2, 4, 8}: integral lanes per u32, so a broadcast shift suffices: w[..., :, lane] = (w_q[..., :] >> (lane * bits)) & mask.
- bits ∈ {3, 6}: lanes cross u32 boundaries, so adjacent u32 pairs are read as a u64, then shifted by rem(i * bits, 32) and masked.
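The two paths can be compared on plain integer lists with a sketch like the one below. It assumes the same low-word-first layout as above; pack/2, unpack_integral/2 and unpack_pairs/2 are hypothetical names, and the code mirrors the per-lane arithmetic rather than the broadcast Nx ops dequantize_defn/1 traces.

```elixir
defmodule UnpackPathsSketch do
  import Bitwise

  # Hypothetical packer: lane i goes to stream bit offset i * bits, low word first.
  def pack(lanes, bits) do
    n_words = div(length(lanes) * bits, 32)

    stream =
      lanes
      |> Enum.with_index()
      |> Enum.reduce(0, fn {v, i}, acc -> acc ||| (v <<< (i * bits)) end)

    for k <- 0..(n_words - 1), do: (stream >>> (32 * k)) &&& 0xFFFFFFFF
  end

  # bits in {2, 4, 8}: each u32 holds div(32, bits) whole lanes, so every lane
  # is one shift-and-mask of a single word.
  def unpack_integral(words, bits) do
    mask = (1 <<< bits) - 1
    for w <- words, lane <- 0..(div(32, bits) - 1), do: (w >>> (lane * bits)) &&& mask
  end

  # bits in {3, 6}: lane i may straddle words k and k + 1, so read them as a
  # u64 (low word first) and shift by rem(i * bits, 32).
  def unpack_pairs(words, bits) do
    mask = (1 <<< bits) - 1
    # Zero-pad so that k + 1 is always a valid index.
    word = List.to_tuple(words ++ [0])

    for i <- 0..(div(32 * length(words), bits) - 1) do
      k = div(i * bits, 32)
      u64 = elem(word, k) ||| (elem(word, k + 1) <<< 32)
      (u64 >>> rem(i * bits, 32)) &&& mask
    end
  end
end

lanes = for i <- 0..31, do: rem(i * 5, 8)
words = UnpackPathsSketch.pack(lanes, 3)
UnpackPathsSketch.unpack_pairs(words, 3) == lanes
# => true
```

The pair path is general (it also decodes 2-, 4- and 8-bit streams); the integral path exists because a single-word shift is cheaper when no lane can straddle.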
Supported modes: "affine", "mxfp4", "mxfp8", and "nvfp4"; every MLX QuantizationMode value runs through the defn-native path. Lane decode uses a 16-entry FP4-E2M1 LUT for mxfp4 and nvfp4, and a 256-entry FP8-E4M3 LUT (matching MLX's FromFP8 bit-trick) for mxfp8. Scale decode uses a 256-entry FP8-E8M0 LUT (2^(s - 127)) for mxfp4 and mxfp8, and the same FP8-E4M3 LUT for nvfp4's finer-grained per-group scales. Output dtype is bf16 to match QuantizedWeight.to_dense/1.
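For intuition about the mxfp4 tables, here is a sketch that derives the 16-entry FP4-E2M1 table from its bit fields (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit) and applies an E8M0 scale 2^(s - 127). It computes the values rather than reproducing MLX's actual LUT code, omits the FP8-E4M3 path used by mxfp8 and nvfp4, and MxDecodeSketch is a hypothetical name.

```elixir
defmodule MxDecodeSketch do
  import Bitwise

  # FP4-E2M1: bit 3 = sign, bits 2..1 = exponent (bias 1), bit 0 = mantissa.
  # Exponent 0 is subnormal (0 or 0.5), giving magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6.
  def e2m1(code) do
    sign = if (code &&& 0b1000) != 0, do: -1.0, else: 1.0
    e = (code >>> 1) &&& 0b11
    m = code &&& 0b1
    mag = if e == 0, do: 0.5 * m, else: (1.0 + 0.5 * m) * :math.pow(2, e - 1)
    sign * mag
  end

  def e2m1_lut, do: for(code <- 0..15, do: e2m1(code))

  # FP8-E8M0 scale: a bare biased exponent, value 2^(s - 127).
  # The OCP MX spec reserves s = 255 for NaN; this sketch does not model it.
  def e8m0(s) when s in 0..254, do: :math.pow(2, s - 127)

  # One mxfp4 group: decode each 4-bit code via the table, times the shared scale.
  def decode_mxfp4(codes, scale_byte) do
    lut = List.to_tuple(e2m1_lut())
    scale = e8m0(scale_byte)
    Enum.map(codes, fn c -> elem(lut, c) * scale end)
  end
end

MxDecodeSketch.decode_mxfp4([1, 2, 7, 15], 128)
# => [1.0, 2.0, 12.0, -12.0]
```

Because E8M0 carries only an exponent, every mxfp4/mxfp8 group scale is a power of two, which is why a 256-entry table covers all of them.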
Examples
iex> w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w, group_size: 64, bits: 4)
iex> dense = Emily.Quantization.dequantize_defn(qw)
iex> Nx.shape(dense)
{4, 64}
@spec quantized_matmul(Nx.Tensor.t(), Emily.QuantizedWeight.t()) :: Nx.Tensor.t()
Compute x @ W^T where W is represented as a QuantizedWeight.
With qw.transpose == true (the default from QuantizedWeight.from_dense/2), this matches Nx.dot(x, QuantizedWeight.to_dense(qw) |> Nx.transpose()), i.e. a dense-kernel dot with a pre-transposed, dequantized weight, within MLX's quantization tolerance. With transpose == false, MLX interprets the packed layout as already transposed (the AWQ convention) and computes x @ W instead.
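As a dense reference for the two transpose settings, the sketch below spells out the contraction on plain nested lists, treating w as the already-dequantized weight (what to_dense/1 would return). It is exact integer arithmetic, so it shows only the shape and index semantics, not the quantization tolerance; MatmulSemanticsSketch is a hypothetical name.

```elixir
defmodule MatmulSemanticsSketch do
  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  defp transpose(m), do: m |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  # transpose == true: w is {out, in}; y[i][j] = sum_k x[i][k] * w[j][k].
  def matmul_t(x, w), do: for(row <- x, do: for(w_row <- w, do: dot(row, w_row)))

  # transpose == false (AWQ layout): w is {in, out}; y = x @ w.
  def matmul(x, w), do: matmul_t(x, transpose(w))
end

x = [[1, 2], [3, 4], [5, 6]]
w = [[1, 0], [0, 1], [1, 1], [2, -1]]
MatmulSemanticsSketch.matmul_t(x, w)
# => [[1, 2, 3, 0], [3, 4, 7, 2], [5, 6, 11, 4]]
```

The shapes follow the example below in miniature: an x of shape {3, k} against a W of shape {4, k} yields {3, 4}.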
The computation runs on Emily.Backend; operands passed from Nx.BinaryBackend (tensors or scalars) are transferred to it first. The input dtype must match qw.scales.type (typically f16, bf16, or f32).
Examples
iex> w = Nx.iota({4, 128}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w)
iex> x = Nx.iota({3, 128}, backend: Emily.Backend, type: :f32)
iex> y = Emily.Quantization.quantized_matmul(x, qw)
iex> Nx.shape(y)
{3, 4}