
Add float zero point support for 2-bit LUT GEMM in MatMulNBits #28354

Draft
vraspar wants to merge 1 commit into main from vraspar/matmulnbits-float-zp
Conversation

Contributor

@vraspar commented on May 4, 2026

Description

Adds support for float/float16 zero points in the 2-bit MatMulNBits LUT GEMM path, enabling AMD QAD/Quark 2-bit quantization which requires a fractional zero point of 1.5.

Addresses #28162

Problem

QAD 2-bit quantization uses non-uniform levels [-1, -1/3, 1/3, 1], expressed via dequant = (q - 1.5) * scale. The zero point 1.5 cannot be represented as a packed uint8 value. The existing LUT GEMM packing API only accepted uint8_t* zero points, and the fallback dequant path crashed with ORT_ENFORCE(nbits_ == 4) when encountering 2-bit + float ZP.

Changes

MLAS layer — Widened MlasLutGemmPack() to accept const void* QuantBZeroPoint + bool IsFloatZeroPoint, following the existing MlasQNBitGemmPackQuantBData convention. The AVX2 packer reads float ZP values directly per quantization group when IsFloatZeroPoint is set, computing the same (zp - midpoint) * scale correction stored in the packed buffer. The compute kernel (TMACComputeGemm_avx2) is unchanged — it already consumes ZP as a float correction during accumulation.

MatMulNBits CPU kernel — Relaxed the PrePack early-exit guard to allow float ZP into the LUT GEMM path (not non-LUT paths). Added fp16→fp32 conversion for ZP tensors, matching how scales are already handled. Fixed the Compute() path to null out prepacked zero_points to avoid a null dereference in CheckInputs. Fixed the 2-bit fallback dequant path: relaxed the nbits_==4 enforce, added inline 2-bit scalar dequant for float and MLFloat16 ZP with correct packed-B indexing for padded K shapes.

Tests — Added MLAS-level float ZP tests across block lengths 32/64/128 with ZP values {0, 1.5, 2, 3}. Added provider-level directed QAD tests (zp=1.5) verifying end-to-end correctness through the LUT GEMM path.

Testing

  • 72 MLAS LUT GEMM tests pass (including 36 new float ZP tests)
  • 13 provider-level 2-bit tests pass (including new QAD float ZP tests)
  • No regressions in existing uint8 ZP tests
  • lintrunner clean

Files changed

  • core/mlas/inc/mlas_qnbit.h - API: void* ZP + IsFloatZeroPoint flag
  • core/mlas/lib/qlutgemm.h - Dispatch typedef update
  • core/mlas/lib/qlutgemm.cpp - Pass-through plumbing
  • core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp - Float ZP packing branch
  • contrib_ops/cpu/quantization/matmul_nbits.cc - PrePack guard, fallback fix, ZP validation
  • test/mlas/unittest/test_sqlutgemm.cpp - Float ZP MLAS tests
  • test/mlas/bench/bench_lutgemm.cpp - Updated call signature
  • test/contrib_ops/matmul_2bits_test.cc - Float ZP provider tests

Copilot AI left a comment


Pull request overview

Adds float/float16 zero-point handling for the CPU 2-bit MatMulNBits path, primarily by widening the MLAS LUT pack API and teaching the CPU kernel to route/convert float zero points for LUT prepack and unpacked fallback dequantization.

Changes:

  • Extends MLAS LUT GEMM packing APIs and AVX2 packing logic to accept float zero-point data.
  • Updates the CPU MatMulNBits kernel to allow LUT prepack with unquantized zero points and adds 2-bit float/FP16 fallback dequantization.
  • Adds MLAS/unit, provider, and benchmark call-site updates for the new LUT pack signature and float-ZP scenarios.

Findings

  • onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: relaxing the early exit for unquantized zero points now lets the LUT prepack path run even when zero_points is dynamic. Because LUT packing only uses TryGetConstantInput and has no input_idx == zero_points pack step, dynamic float/float16 zero points will be ignored in the packed path.
  • onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: both new zero-point conversion blocks assume every non-float zero-point tensor is MLFloat16. The schema also allows bfloat16, so constant BF16 zero points would be reinterpreted as FP16 during prepack.
  • onnxruntime/test/contrib_ops/matmul_2bits_test.cc: the new provider test never checks MlasIsLutGemmAvailable(), so on platforms without LUT GEMM it can pass via the unpacked fallback and fail to validate the new LUT float-ZP path.
  • onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: the new float zero-point 2-bit fallback dequant branch is not covered by the added tests, because the new tests force LUT GEMM on LUT-compatible shapes.
  • onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: the new MLFloat16-specific 2-bit float zero-point fallback branch also lacks coverage; the added tests only exercise float inputs/zero points.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Summary per file:

  • onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp - Adds MLAS float-zero-point LUT GEMM unit coverage and usage of the new pack call signature.
  • onnxruntime/test/mlas/bench/bench_lutgemm.cpp - Updates benchmark call sites for the widened LUT pack API.
  • onnxruntime/test/contrib_ops/matmul_2bits_test.cc - Adds provider-level float zero-point 2-bit tests.
  • onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp - Implements AVX2 packing support for float zero points.
  • onnxruntime/core/mlas/lib/qlutgemm.h - Updates the LUT dispatch typedef to carry a generic ZP pointer + type flag.
  • onnxruntime/core/mlas/lib/qlutgemm.cpp - Threads the new zero-point arguments through the MLAS LUT pack plumbing.
  • onnxruntime/core/mlas/inc/mlas_qnbit.h - Updates the public MLAS LUT pack declaration and docs for float zero points.
  • onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc - Adjusts CPU prepack/compute logic for float zero points in the LUT and fallback paths.


Comment on the relaxed PrePack early-exit guard:

```cpp
if (has_g_idx_) {
  return Status::OK();
}
if (has_unquantized_zero_point_ && !prefer_lut_gemm_) {
```
Comment on lines +264 to +265:

```cpp
} else {
  MlasConvertHalfToFloatBuffer(zero_points->Data<MLFloat16>(), zp_fp32_buf.data(), zp_count);
```
Comment on lines +487 to +488:

```cpp
} else {
  MlasConvertHalfToFloatBuffer(zero_points->Data<MLFloat16>(), zp_fp32_buf.data(), zp_count);
```
Comment on lines +480 to +485:

```cpp
SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1"));

test.Config(so)
    .ConfigEp(DefaultCpuExecutionProvider())
    .RunWithConfig();
```
Comment on lines +925 to +945:

```cpp
if (nbits_ == 2) {
  ORT_ENFORCE(reorder_idx_data == nullptr,
              "g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
  // Simple 2-bit dequantization with float zero points
  const float* float_zp = static_cast<const float*>(zero_points_data);
  int32_t k_blocks = (static_cast<int32_t>(K_) + static_cast<int32_t>(block_size_) - 1) /
                     static_cast<int32_t>(block_size_);
  int32_t packed_k = k_blocks * static_cast<int32_t>(block_size_);
  int32_t bytes_per_col = packed_k / 4;
  for (int32_t n = 0; n < static_cast<int32_t>(N_); n++) {
    for (int32_t k = 0; k < static_cast<int32_t>(K_); k++) {
      int32_t block_idx = k / static_cast<int32_t>(block_size_);
      float scale = scales_data[n * k_blocks + block_idx];
      float zp = float_zp[n * k_blocks + block_idx];
      int32_t packed_idx = n * bytes_per_col + k / 4;
      int32_t bit_offset = (k % 4) * 2;
      uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
      tmp_b_data_ptr.get()[n * static_cast<int32_t>(K_) + k] =
          (static_cast<float>(q) - zp) * scale;
    }
  }
```
Comment on lines +1085 to +1105:

```cpp
if (nbits_ == 2) {
  ORT_ENFORCE(reorder_idx_data == nullptr,
              "g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
  // Simple 2-bit dequantization with MLFloat16 zero points
  const MLFloat16* fp16_zp = static_cast<const MLFloat16*>(zero_points_data);
  int32_t k_blocks = (static_cast<int32_t>(K_) + static_cast<int32_t>(block_size_) - 1) /
                     static_cast<int32_t>(block_size_);
  int32_t packed_k = k_blocks * static_cast<int32_t>(block_size_);
  int32_t bytes_per_col = packed_k / 4;
  for (int32_t n = 0; n < static_cast<int32_t>(N_); n++) {
    for (int32_t k = 0; k < static_cast<int32_t>(K_); k++) {
      int32_t block_idx = k / static_cast<int32_t>(block_size_);
      float scale = scales_ptr[n * k_blocks + block_idx];
      float zp = fp16_zp[n * k_blocks + block_idx].ToFloat();
      int32_t packed_idx = n * bytes_per_col + k / 4;
      int32_t bit_offset = (k % 4) * 2;
      uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
      tmp_b_data_ptr.get()[n * static_cast<int32_t>(K_) + k] =
          (static_cast<float>(q) - zp) * scale;
    }
  }
```
