Emily.Quantization (emily v0.3.5)


Quantized inference primitives.

Public API

  • quantized_matmul/2 — eager-mode fused kernel over a materialized %Emily.QuantizedWeight{} and an Nx.Tensor. Calls the MLX quantized_matmul C++ kernel directly; the result matches Nx.dot(x, to_dense(qw) |> Nx.transpose()) within quantization tolerance.
  • dequantize_defn/1 — defn-native analogue of Emily.QuantizedWeight.to_dense/1, composed from Nx.right_shift / Nx.bitwise_and / Nx.multiply / Nx.add. Use it inside Nx.Defn.jit-traced Axon forward passes where a fused quantized_matmul node isn't available; Nx.dot(x, dequantize_defn(qw) |> Nx.transpose()) runs in two kernels (dequantize, then dense matmul) instead of one.
  • defn_supported_bits/0 — enumerates the bit widths the defn-native path supports ([2, 4, 8]).

bits ∈ {3, 6} use cross-u32 lane packing; only Emily.QuantizedWeight.to_dense/1 (the native-NIF path) handles them.

See Emily.QuantizedWeight for the container struct and Emily.QuantizedWeight.from_dense/2 for building one.
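To see how the two paths line up, here is a minimal sketch (shapes, group_size, and bits chosen only for illustration): the fused eager call versus the two-kernel dequantize-then-dot form that also composes inside defn. The two results agree within quantization tolerance; only the shapes are asserted here.

iex> w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w, group_size: 64, bits: 4)
iex> x = Nx.iota({3, 64}, backend: Emily.Backend, type: :f32)
iex> fused = Emily.Quantization.quantized_matmul(x, qw)
iex> unfused = Nx.dot(x, Emily.Quantization.dequantize_defn(qw) |> Nx.transpose())
iex> Nx.shape(fused) == Nx.shape(unfused)
true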

Summary

Functions

defn_supported_bits()

Bit widths supported by dequantize_defn/1 (and therefore by Emily.Quantization.Layers.quantized_dense/4 and any Axon graph rewrite that wires it in).

dequantize_defn(qw)

Reconstruct a dense tensor from a QuantizedWeight, built entirely from Nx primitives so it composes inside defn traces.

quantized_matmul(x, qw)

Compute x @ W^T where W is represented as a QuantizedWeight.

Functions

defn_supported_bits()

@spec defn_supported_bits() :: [pos_integer()]

Bit widths supported by dequantize_defn/1 (and therefore by Emily.Quantization.Layers.quantized_dense/4 and any Axon graph rewrite that wires it in).

bits ∈ {3, 6} use cross-u32 lane packing and aren't supported by the defn-native path; QuantizedWeight.to_dense/1 (the native-NIF path) still handles them.

Examples

iex> Emily.Quantization.defn_supported_bits()
[2, 4, 8]
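
One way to use this (a sketch under assumptions, not part of the library: the module name is hypothetical and it assumes the struct exposes its bit width as qw.bits): pick the dequantization path per weight at graph-construction time, falling back to the native to_dense/1 for widths the defn path can't handle.

defmodule MyApp.QuantizedLayers do
  # Hypothetical helper; assumes %Emily.QuantizedWeight{} carries a `bits` field.
  def dense_weight(%Emily.QuantizedWeight{} = qw) do
    if qw.bits in Emily.Quantization.defn_supported_bits() do
      # Defn-composable path: pure Nx ops, safe inside a traced forward pass.
      Emily.Quantization.dequantize_defn(qw)
    else
      # 3- and 6-bit packings: only the native-NIF to_dense/1 handles them.
      Emily.QuantizedWeight.to_dense(qw)
    end
  end
end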

dequantize_defn(qw)

@spec dequantize_defn(Emily.QuantizedWeight.t()) :: Nx.Tensor.t()

Reconstruct a dense tensor from a QuantizedWeight, built entirely from Nx primitives so it composes inside defn traces.

This is the defn-compatible analogue of QuantizedWeight.to_dense/1. The math is identical to MLX's dequantize:

w[i] = ((w_q_packed[div(i, lpu)] >> ((i mod lpu) * bits)) & mask) * scales[g] + biases[g]

where lpu = div(32, bits) (lanes per u32), mask = (1 <<< bits) - 1, and g = div(i, group_size) is the group index along the last axis.

Supported: bits ∈ [2, 4, 8]. bits ∈ {3, 6} pack across u32 boundaries and are out of scope here.

Examples

iex> w = Nx.iota({4, 64}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w, group_size: 64, bits: 4)
iex> dense = Emily.Quantization.dequantize_defn(qw)
iex> Nx.shape(dense)
{4, 64}
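
To make the shift-and-mask step of the formula above concrete, here is a pure-Nx sketch for bits = 4 (so lpu = 8 and mask = 0xF) applied to one arbitrary packed word; the per-group * scales[g] + biases[g] step then scales the unpacked lanes.

iex> packed = Nx.tensor(0x87654321, type: :u32)
iex> shifts = Nx.multiply(Nx.iota({8}, type: :u32), 4)
iex> Nx.right_shift(packed, shifts) |> Nx.bitwise_and(0xF) |> Nx.to_flat_list()
[1, 2, 3, 4, 5, 6, 7, 8]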

quantized_matmul(x, qw)

@spec quantized_matmul(Nx.Tensor.t(), Emily.QuantizedWeight.t()) :: Nx.Tensor.t()

Compute x @ W^T where W is represented as a QuantizedWeight.

With qw.transpose == true (the default from QuantizedWeight.from_dense/2) this matches Nx.dot(x, QuantizedWeight.to_dense(qw) |> Nx.transpose()) — i.e. a dense-kernel dot with a pre-transposed, dequantized weight — within MLX's quantization tolerance. With transpose == false, MLX interprets the packed layout as already transposed (the AWQ convention).

Both operands must live on Emily.Backend; scalars and tensors coming from Nx.BinaryBackend are transferred automatically. The input dtype must match qw.scales.type (typically f16, bf16, or f32).

Examples

iex> w = Nx.iota({4, 128}, backend: Emily.Backend, type: :f32)
iex> qw = Emily.QuantizedWeight.from_dense(w)
iex> x = Nx.iota({3, 128}, backend: Emily.Backend, type: :f32)
iex> y = Emily.Quantization.quantized_matmul(x, qw)
iex> Nx.shape(y)
{3, 4}
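
If the activations arrive in a different dtype than the weight's scales (say, f32 activations against a weight quantized from f16), cast them first. A sketch, assuming from_dense/2 keeps the scales in the dense weight's dtype:

iex> w16 = Nx.iota({4, 128}, backend: Emily.Backend, type: :f16)
iex> qw = Emily.QuantizedWeight.from_dense(w16)
iex> x = Nx.iota({3, 128}, backend: Emily.Backend, type: :f32) |> Nx.as_type(qw.scales.type)
iex> Nx.shape(Emily.Quantization.quantized_matmul(x, qw))
{3, 4}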