Add float zero point support for 2-bit LUT GEMM in MatMulNBits #28354
Draft
Conversation
Contributor
Pull request overview
Adds float/float16 zero-point handling for the CPU 2-bit MatMulNBits path, primarily by widening the MLAS LUT pack API and teaching the CPU kernel to route/convert float zero points for LUT prepack and unpacked fallback dequantization.
Changes:
- Extends MLAS LUT GEMM packing APIs and AVX2 packing logic to accept float zero-point data.
- Updates the CPU MatMulNBits kernel to allow LUT prepack with unquantized zero points and adds 2-bit float/FP16 fallback dequantization.
- Adds MLAS unit, provider, and benchmark call-site updates for the new LUT pack signature and float-ZP scenarios.
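
For orientation, here is a hypothetical sketch of what the widened pack entry point could look like after this change. Only `MlasLutGemmPack`, `QuantBZeroPoint`, and `IsFloatZeroPoint` are named in the PR; the remaining parameter names and the ordering are illustrative, and the authoritative declaration lives in `onnxruntime/core/mlas/inc/mlas_qnbit.h`:

```cpp
// Hypothetical sketch only: parameter order and the other parameter names
// are illustrative, not the actual mlas_qnbit.h declaration.
void MlasLutGemmPack(
    size_t N,                     // columns of B
    size_t K,                     // inner dimension
    size_t BlkLen,                // quantization block length
    const void* QuantBData,       // packed 2-bit weight data
    const float* QuantBScale,     // per-block scales
    const void* QuantBZeroPoint,  // packed uint8 ZPs, or floats (see flag)
    bool IsFloatZeroPoint,        // true: read QuantBZeroPoint as float*
    void* PackedQuantBData);      // destination packed buffer
```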
Findings
- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: relaxing the early exit for unquantized zero points now lets the LUT prepack path run even when zero_points is dynamic. Because LUT packing only uses TryGetConstantInput and has no input_idx == zero_points pack step, dynamic float/float16 zero points will be ignored in the packed path.
- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: both new zero-point conversion blocks assume every non-float zero-point tensor is MLFloat16. The schema also allows bfloat16, so constant BF16 zero points would be reinterpreted as FP16 during prepack.
- onnxruntime/test/contrib_ops/matmul_2bits_test.cc: the new provider test never checks MlasIsLutGemmAvailable(), so on platforms without LUT GEMM it can pass via the unpacked fallback and fail to validate the new LUT float-ZP path.
- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: the new float zero-point 2-bit fallback dequant branch is not covered by the added tests, because the new tests force LUT GEMM on LUT-compatible shapes.
- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc: the new MLFloat16-specific 2-bit float zero-point fallback branch also lacks coverage; the added tests only exercise float inputs/zero points.
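
One way the MLFloat16/BF16 finding could be addressed is by dispatching on the tensor's element type rather than assuming non-float means FP16. A minimal sketch, reusing the `zero_points` / `zp_fp32_buf` / `zp_count` names from the diffs below; the BF16 branch is an illustrative assumption, not code from this PR:

```cpp
// Sketch of a type-dispatched ZP conversion; the BF16 branch is illustrative.
if (zero_points->IsDataType<float>()) {
  std::copy_n(zero_points->Data<float>(), zp_count, zp_fp32_buf.data());
} else if (zero_points->IsDataType<MLFloat16>()) {
  MlasConvertHalfToFloatBuffer(zero_points->Data<MLFloat16>(),
                               zp_fp32_buf.data(), zp_count);
} else if (zero_points->IsDataType<BFloat16>()) {
  const BFloat16* src = zero_points->Data<BFloat16>();
  for (size_t i = 0; i < zp_count; ++i) {
    zp_fp32_buf[i] = src[i].ToFloat();  // widen BF16 explicitly
  }
} else {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "unsupported zero-point element type");
}
```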
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp | Adds MLAS float-zero-point LUT GEMM unit coverage and new pack call signature usage. |
| onnxruntime/test/mlas/bench/bench_lutgemm.cpp | Updates benchmark call sites for the widened LUT pack API. |
| onnxruntime/test/contrib_ops/matmul_2bits_test.cc | Adds provider-level float zero-point 2-bit tests. |
| onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | Implements AVX2 packing support for float zero points. |
| onnxruntime/core/mlas/lib/qlutgemm.h | Updates LUT dispatch typedef to carry a generic ZP pointer + type flag. |
| onnxruntime/core/mlas/lib/qlutgemm.cpp | Threads the new zero-point arguments through MLAS LUT pack plumbing. |
| onnxruntime/core/mlas/inc/mlas_qnbit.h | Updates the public MLAS LUT pack declaration/docs for float zero points. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adjusts CPU prepack/compute logic for float zero points in LUT and fallback paths. |
```cpp
if (has_g_idx_) {
  return Status::OK();
}
if (has_unquantized_zero_point_ && !prefer_lut_gemm_) {
```
Comment on lines +264 to +265
```cpp
} else {
  MlasConvertHalfToFloatBuffer(zero_points->Data<MLFloat16>(), zp_fp32_buf.data(), zp_count);
```
Comment on lines +487 to +488
```cpp
} else {
  MlasConvertHalfToFloatBuffer(zero_points->Data<MLFloat16>(), zp_fp32_buf.data(), zp_count);
```
Comment on lines +480 to +485
```cpp
SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1"));

test.Config(so)
    .ConfigEp(DefaultCpuExecutionProvider())
    .RunWithConfig();
```
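
Note that the kOrtSessionOptionsMlasLutGemm config entry here forces the LUT GEMM path on, which is why these directed tests exercise the new packed LUT path; as the Findings above point out, the unpacked fallback branches are consequently left uncovered.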
Comment on lines +925 to +945
```cpp
if (nbits_ == 2) {
  ORT_ENFORCE(reorder_idx_data == nullptr,
              "g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
  // Simple 2-bit dequantization with float zero points
  const float* float_zp = static_cast<const float*>(zero_points_data);
  int32_t k_blocks = (static_cast<int32_t>(K_) + static_cast<int32_t>(block_size_) - 1) /
                     static_cast<int32_t>(block_size_);
  int32_t packed_k = k_blocks * static_cast<int32_t>(block_size_);
  int32_t bytes_per_col = packed_k / 4;
  for (int32_t n = 0; n < static_cast<int32_t>(N_); n++) {
    for (int32_t k = 0; k < static_cast<int32_t>(K_); k++) {
      int32_t block_idx = k / static_cast<int32_t>(block_size_);
      float scale = scales_data[n * k_blocks + block_idx];
      float zp = float_zp[n * k_blocks + block_idx];
      int32_t packed_idx = n * bytes_per_col + k / 4;
      int32_t bit_offset = (k % 4) * 2;
      uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
      tmp_b_data_ptr.get()[n * static_cast<int32_t>(K_) + k] =
          (static_cast<float>(q) - zp) * scale;
    }
  }
```
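
To sanity-check the packed-index arithmetic above (byte index k / 4, bit offset (k % 4) * 2), here is a self-contained round trip for a single packed byte using the QAD zero point of 1.5. This is a standalone illustration, not code from the PR:

```cpp
#include <cstdint>
#include <cstdio>

// Round-trips four 2-bit weights through one packed byte using the same
// layout as the fallback above: weight k lives at bit offset (k % 4) * 2.
int main() {
  const float scale = 2.0f / 3.0f;  // with zp = 1.5 this yields the QAD levels
  const float zp = 1.5f;
  const uint8_t q[4] = {0, 1, 2, 3};
  uint8_t packed = 0;
  for (int k = 0; k < 4; ++k) {
    packed |= static_cast<uint8_t>(q[k] << ((k % 4) * 2));
  }
  for (int k = 0; k < 4; ++k) {
    uint8_t unpacked = (packed >> ((k % 4) * 2)) & 0x3;
    // Prints -1.0000, -0.3333, 0.3333, 1.0000
    printf("q=%u -> %.4f\n", unpacked, (static_cast<float>(unpacked) - zp) * scale);
  }
  return 0;
}
```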
Comment on lines +1085 to +1105
```cpp
if (nbits_ == 2) {
  ORT_ENFORCE(reorder_idx_data == nullptr,
              "g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
  // Simple 2-bit dequantization with MLFloat16 zero points
  const MLFloat16* fp16_zp = static_cast<const MLFloat16*>(zero_points_data);
  int32_t k_blocks = (static_cast<int32_t>(K_) + static_cast<int32_t>(block_size_) - 1) /
                     static_cast<int32_t>(block_size_);
  int32_t packed_k = k_blocks * static_cast<int32_t>(block_size_);
  int32_t bytes_per_col = packed_k / 4;
  for (int32_t n = 0; n < static_cast<int32_t>(N_); n++) {
    for (int32_t k = 0; k < static_cast<int32_t>(K_); k++) {
      int32_t block_idx = k / static_cast<int32_t>(block_size_);
      float scale = scales_ptr[n * k_blocks + block_idx];
      float zp = fp16_zp[n * k_blocks + block_idx].ToFloat();
      int32_t packed_idx = n * bytes_per_col + k / 4;
      int32_t bit_offset = (k % 4) * 2;
      uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
      tmp_b_data_ptr.get()[n * static_cast<int32_t>(K_) + k] =
          (static_cast<float>(q) - zp) * scale;
    }
  }
```
Description
Adds support for float/float16 zero points in the 2-bit MatMulNBits LUT GEMM path, enabling AMD QAD/Quark 2-bit quantization, which requires a fractional zero point of 1.5.
Addresses #28162
Problem
QAD 2-bit quantization uses non-uniform levels [-1, -1/3, 1/3, 1], expressed via dequant = (q - 1.5) * scale. The zero point 1.5 cannot be represented as a packed uint8 value. The existing LUT GEMM packing API only accepted uint8_t* zero points, and the fallback dequant path crashed with ORT_ENFORCE(nbits_ == 4) when encountering 2-bit + float ZP.
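For example, with scale = 2/3, dequantizing q ∈ {0, 1, 2, 3} gives (q - 1.5) * (2/3) = {-1, -1/3, 1/3, 1}, exactly the four QAD levels.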
Changes

MLAS layer — Widened MlasLutGemmPack() to accept const void* QuantBZeroPoint + bool IsFloatZeroPoint, following the existing MlasQNBitGemmPackQuantBData convention. The AVX2 packer reads float ZP values directly per quantization group when IsFloatZeroPoint is set, computing the same (zp - midpoint) * scale correction stored in the packed buffer. The compute kernel (TMACComputeGemm_avx2) is unchanged — it already consumes ZP as a float correction during accumulation.
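
A scalar picture of that per-group correction, as a minimal sketch: the midpoint value (assumed here to be the default 2-bit zero point 2^(nbits-1) = 2) and the 4-per-byte uint8 ZP layout are assumptions, and the real packer is vectorized in sqnbitgemm_lut_kernel_avx2.cpp:

```cpp
#include <cstddef>
#include <cstdint>

// Sketch only: the midpoint value and the packed-uint8 layout are assumptions.
float LutZeroPointCorrection(const void* zp_data, bool is_float_zp,
                             size_t group_idx, float scale, float midpoint) {
  float zp;
  if (is_float_zp) {
    // New path: fractional zero points (e.g. QAD's 1.5) are read directly.
    zp = static_cast<const float*>(zp_data)[group_idx];
  } else {
    // Existing path: unpack a 2-bit zero point (four per byte).
    const uint8_t byte = static_cast<const uint8_t*>(zp_data)[group_idx / 4];
    zp = static_cast<float>((byte >> ((group_idx % 4) * 2)) & 0x3);
  }
  return (zp - midpoint) * scale;  // correction stored in the packed buffer
}
```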
nbits_==4enforce, added inline 2-bit scalar dequant for float and MLFloat16 ZP with correct packed-B indexing for padded K shapes.Tests — Added MLAS-level float ZP tests across block lengths 32/64/128 with ZP values {0, 1.5, 2, 3}. Added provider-level directed QAD tests (
zp=1.5) verifying end-to-end correctness through the LUT GEMM path.Testing
Files changed
core/mlas/inc/mlas_qnbit.hvoid*ZP +IsFloatZeroPointflagcore/mlas/lib/qlutgemm.hcore/mlas/lib/qlutgemm.cppcore/mlas/lib/sqnbitgemm_lut_kernel_avx2.cppcontrib_ops/cpu/quantization/matmul_nbits.cctest/mlas/unittest/test_sqlutgemm.cpptest/mlas/bench/bench_lutgemm.cpptest/contrib_ops/matmul_2bits_test.cc