All six new functions are verified against the gguf.quants.dequantize() reference implementation.
dequantize_blocks_IQ3_S(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs, qh, signs, scales = split_block_dims(blocks, 2, 64, 8, 32) + d = d.view(torch.float16).to(dtype) + + scales = scales.view(torch.uint8) + scales = torch.stack([scales & 0xF, scales >> 4], dim=-1).reshape((n_blocks, 8)) + db = d * (1 + 2 * scales.to(dtype)) + db = db.reshape((n_blocks, 8, 1, 1)) + + shifts = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8)) + signs = (signs.unsqueeze(-1) >> shifts) & 1 + signs = torch.where( + signs == 0, + torch.ones(1, dtype=dtype, device=d.device), + torch.full((1,), -1.0, dtype=dtype, device=d.device), + ) + signs = signs.reshape((n_blocks, 8, 8, 4)) + + qh_bits = (qh.unsqueeze(-1) >> shifts) & 1 + qh_bits = qh_bits.reshape((n_blocks, 64)) + qs = qs.to(torch.int16) | (qh_bits.to(torch.int16) << 8) + + grid = GRID_IQ3_S.to(dtype=dtype, device=d.device) + grid_val = grid[qs.to(torch.long)].reshape((n_blocks, 8, 8, 4)) + + return (db * grid_val * signs).reshape((n_blocks, QK_K)) + +def dequantize_blocks_IQ3_XXS(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs, scales, _ = split_block_dims(blocks, 2, 64, 32) + d = d.view(torch.float16).to(dtype) + + scales = scales.reshape((n_blocks, 8, 4)).to(torch.int32) + scales = scales[:, :, 0] | scales[:, :, 1] << 8 | scales[:, :, 2] << 16 | scales[:, :, 3] << 24 + + db = d * (0.5 + ((scales >> 28) & 0xF).to(dtype)) * 0.5 + db = db.reshape((n_blocks, 8, 1, 1)) + + shifts = torch.tensor([0, 7, 14, 21], device=d.device, dtype=torch.int32).reshape((1, 1, 4)) + sign_indices = (scales.reshape((n_blocks, 8, 1)) >> shifts) & 0x7F + + ksigns = KSIGNS_IQ2_XXS.to(d.device) + sign_bytes = ksigns[sign_indices.to(torch.long)] + + shifts_bits = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 1, 8)) + signs = (sign_bytes.unsqueeze(-1) >> shifts_bits) & 1 + signs = torch.where( + signs == 0, + torch.ones(1, dtype=dtype, 
device=d.device), + torch.full((1,), -1.0, dtype=dtype, device=d.device), + ) + signs = signs.reshape((n_blocks, 8, 4, 8)) + + grid = GRID_IQ3_XXS.to(dtype=dtype, device=d.device) + grid_val = grid[qs.to(torch.long)].reshape((n_blocks, 8, 4, 8)) + + return (db * grid_val * signs).reshape((n_blocks, QK_K)) + +def dequantize_blocks_IQ2_S(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs, signs, qh, scales = split_block_dims(blocks, 2, 32, 32, 8) + d = d.view(torch.float16).to(dtype) + + scales = scales.view(torch.uint8) + scales = torch.stack([scales & 0xF, scales >> 4], dim=-1).reshape((n_blocks, 16)) + db = d * (0.5 + scales.to(dtype)) * 0.25 + db = db.reshape((n_blocks, 16, 1, 1)) + + shifts = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8)) + signs = (signs.unsqueeze(-1) >> shifts) & 1 + signs = torch.where( + signs == 0, + torch.ones(1, dtype=dtype, device=d.device), + torch.full((1,), -1.0, dtype=dtype, device=d.device), + ) + signs = signs.reshape((n_blocks, 16, 2, 8)) + + qh_shifts = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4)) + qh_bits = (qh.view(torch.uint8).reshape((n_blocks, 8, 1)) >> qh_shifts) & 3 + qh_bits = qh_bits.reshape((n_blocks, 32)) + + qs = qs.view(torch.uint8).to(torch.int32) + indices = qs | (qh_bits.to(torch.int32) << 8) + + grid = GRID_IQ2_S.to(dtype=dtype, device=d.device) + grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 16, 2, 8)) + + return (db * grid_val * signs).reshape((n_blocks, QK_K)) + +def dequantize_blocks_IQ2_XXS(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + u32 = qs.reshape((n_blocks, 16, 4)).to(torch.int32) + u32 = u32[:, :, 0] | (u32[:, :, 1] << 8) | (u32[:, :, 2] << 16) | (u32[:, :, 3] << 24) + u32 = u32.reshape((n_blocks, 8, 2)) + + q0 = u32[:, :, 0] # grid indices + q1 = u32[:, :, 1] # scales and signs + + db = d 
* (0.5 + ((q1 >> 28) & 0xF).to(dtype)) * 0.25 + db = db.reshape((n_blocks, 8, 1, 1)) + + shifts = torch.tensor([0, 7, 14, 21], device=d.device, dtype=torch.int32).reshape((1, 1, 4)) + sign_indices = (q1.unsqueeze(-1) >> shifts) & 0x7F + + ksigns = KSIGNS_IQ2_XXS.to(d.device) + sign_bytes = ksigns[sign_indices.to(torch.long)] + + shifts_bits = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 1, 8)) + signs = (sign_bytes.unsqueeze(-1) >> shifts_bits) & 1 + signs = torch.where( + signs == 0, + torch.ones(1, dtype=dtype, device=d.device), + torch.full((1,), -1.0, dtype=dtype, device=d.device), + ) + signs = signs.reshape((n_blocks, 8, 4, 8)) + + indices = q0.contiguous().view(torch.uint8) + grid = GRID_IQ2_XXS.to(dtype=dtype, device=d.device) + grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 4, 8)) + + return (db * grid_val * signs).reshape((n_blocks, QK_K)) + +def dequantize_blocks_IQ1_M(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + qs, qh, scales = split_block_dims(blocks, 32, 16) + + scales_u16 = scales.reshape((n_blocks, 4, 2)).to(torch.int32) + scales_u16 = scales_u16[:, :, 0] | (scales_u16[:, :, 1] << 8) + + d_bits = ( + ((scales_u16[:, 0] & 0xF000) >> 12) + | ((scales_u16[:, 1] & 0xF000) >> 8) + | ((scales_u16[:, 2] & 0xF000) >> 4) + | (scales_u16[:, 3] & 0xF000) + ) + d = d_bits.to(torch.int16).view(torch.float16).to(dtype).reshape((n_blocks, 1)) + + sub_shifts = torch.tensor([0, 3, 6, 9], device=d.device, dtype=torch.int32).reshape((1, 1, 4)) + sub_scales = (scales_u16.reshape((n_blocks, 4, 1)) >> sub_shifts) & 7 + dl = d.reshape((n_blocks, 1, 1)) * (2 * sub_scales.to(dtype) + 1) + dl = dl.reshape((n_blocks, 8, 2, 1, 1)) + + qh_bytes = qh.to(torch.int32) + qh_shifts = torch.tensor([0, 4], device=d.device, dtype=torch.int32).reshape((1, 1, 2)) + qh_unpacked = (qh_bytes.reshape((n_blocks, 16, 1)) >> qh_shifts).reshape((n_blocks, 32)) + + delta = torch.where( + (qh_unpacked & 8) == 0, + 
torch.full((1,), 0.125, dtype=dtype, device=d.device), + torch.full((1,), -0.125, dtype=dtype, device=d.device), + ).reshape((n_blocks, 8, 2, 2, 1)) + + qh_bits = qh_unpacked & 7 + qs = qs.to(torch.int32) + indices = qs | (qh_bits << 8) + + grid = GRID_IQ1_S.to(dtype=dtype, device=d.device) + grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 2, 2, 8)) + + return (dl * (grid_val + delta)).reshape((n_blocks, QK_K)) + +def dequantize_blocks_IQ1_S(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs, qh = split_block_dims(blocks, 2, 32) + d = d.view(torch.float16).to(dtype) + + qh = qh.view(torch.int16).to(torch.int32) & 0xFFFF + + dl = d * (2 * ((qh >> 12) & 7).to(dtype) + 1) + delta = torch.where( + (qh & 0x8000) == 0, + torch.full((1,), 0.125, dtype=dtype, device=d.device), + torch.full((1,), -0.125, dtype=dtype, device=d.device), + ) + + shifts = torch.tensor([0, 3, 6, 9], device=d.device, dtype=torch.int32).reshape((1, 1, 4)) + qh_bits = (qh.reshape((n_blocks, 8, 1)) >> shifts) & 7 + + qs = qs.view(torch.uint8).to(torch.int32).reshape((n_blocks, 8, 4)) + indices = qs | (qh_bits << 8) + + grid = GRID_IQ1_S.to(dtype=dtype, device=d.device) + grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 4, 8)) + + dl = dl.reshape((n_blocks, 8, 1, 1)) + delta = delta.reshape((n_blocks, 8, 1, 1)) + + return (dl * (grid_val + delta)).reshape((n_blocks, QK_K)) + dequantize_functions = { gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, @@ -298,4 +516,10 @@ def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None): gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K, gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL, gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS, + gguf.GGMLQuantizationType.IQ3_S: dequantize_blocks_IQ3_S, + gguf.GGMLQuantizationType.IQ3_XXS: dequantize_blocks_IQ3_XXS, + gguf.GGMLQuantizationType.IQ2_S: 
dequantize_blocks_IQ2_S, + gguf.GGMLQuantizationType.IQ2_XXS: dequantize_blocks_IQ2_XXS, + gguf.GGMLQuantizationType.IQ1_M: dequantize_blocks_IQ1_M, + gguf.GGMLQuantizationType.IQ1_S: dequantize_blocks_IQ1_S, }