diff --git a/dequant.py b/dequant.py
index 78f5f26..025c682 100644
--- a/dequant.py
+++ b/dequant.py
@@ -1,5 +1,6 @@
 # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
 import gguf
+import numpy as np
 import torch
 from tqdm import tqdm
 
@@ -240,6 +241,24 @@ def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
 # IQ quants
 KVALUES = torch.tensor([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=torch.int8)
 
+def _get_iq_grid(iq_cls):
+    iq_cls.init_grid()
+    return torch.from_numpy(np.array(iq_cls.grid).squeeze().copy())
+
+def _get_iq_ksigns(iq_cls):
+    iq_cls.init_grid()
+    return torch.from_numpy(np.frombuffer(iq_cls.ksigns, dtype=np.uint8).copy())
+
+from gguf.quants import IQ1_M as _IQ1_M, IQ1_S as _IQ1_S, IQ2_S as _IQ2_S, IQ2_XXS as _IQ2_XXS, IQ3_S as _IQ3_S, IQ3_XXS as _IQ3_XXS
+
+GRID_IQ3_S = _get_iq_grid(_IQ3_S)
+GRID_IQ3_XXS = _get_iq_grid(_IQ3_XXS)
+GRID_IQ2_S = _get_iq_grid(_IQ2_S)
+GRID_IQ2_XXS = _get_iq_grid(_IQ2_XXS)
+GRID_IQ1_S = _get_iq_grid(_IQ1_S)
+_get_iq_grid(_IQ1_M) # IQ1_M uses the same grid as IQ1_S internally
+KSIGNS_IQ2_XXS = _get_iq_ksigns(_IQ2_XXS)
+
 def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
     n_blocks = blocks.shape[0]
 
@@ -284,6 +303,205 @@ def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
 
     return (dl * qs).reshape((n_blocks, -1))
 
+def dequantize_blocks_IQ3_S(blocks, block_size, type_size, dtype=None):
+    n_blocks = blocks.shape[0]
+
+    d, qs, qh, signs, scales = split_block_dims(blocks, 2, 64, 8, 32)
+    d = d.view(torch.float16).to(dtype)
+
+    scales = scales.view(torch.uint8)
+    scales = torch.stack([scales & 0xF, scales >> 4], dim=-1).reshape((n_blocks, 8))
+    db = d * (1 + 2 * scales.to(dtype))
+    db = db.reshape((n_blocks, 8, 1, 1))
+
+    shifts = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8))
+    signs = (signs.unsqueeze(-1) >> shifts) & 1
+    signs = torch.where(
+        signs == 0,
+        torch.ones(1, dtype=dtype, device=d.device),
+        torch.full((1,), -1.0, dtype=dtype, device=d.device),
+    )
+    signs = signs.reshape((n_blocks, 8, 8, 4))
+
+    qh_bits = (qh.unsqueeze(-1) >> shifts) & 1
+    qh_bits = qh_bits.reshape((n_blocks, 64))
+    qs = qs.to(torch.int16) | (qh_bits.to(torch.int16) << 8)
+
+    grid = GRID_IQ3_S.to(dtype=dtype, device=d.device)
+    grid_val = grid[qs.to(torch.long)].reshape((n_blocks, 8, 8, 4))
+
+    return (db * grid_val * signs).reshape((n_blocks, QK_K))
+
+def dequantize_blocks_IQ3_XXS(blocks, block_size, type_size, dtype=None):
+    n_blocks = blocks.shape[0]
+
+    d, qs, scales, _ = split_block_dims(blocks, 2, 64, 32)
+    d = d.view(torch.float16).to(dtype)
+
+    scales = scales.reshape((n_blocks, 8, 4)).to(torch.int32)
+    scales = scales[:, :, 0] | scales[:, :, 1] << 8 | scales[:, :, 2] << 16 | scales[:, :, 3] << 24
+
+    db = d * (0.5 + ((scales >> 28) & 0xF).to(dtype)) * 0.5
+    db = db.reshape((n_blocks, 8, 1, 1))
+
+    shifts = torch.tensor([0, 7, 14, 21], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
+    sign_indices = (scales.reshape((n_blocks, 8, 1)) >> shifts) & 0x7F
+
+    ksigns = KSIGNS_IQ2_XXS.to(d.device)
+    sign_bytes = ksigns[sign_indices.to(torch.long)]
+
+    shifts_bits = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 1, 8))
+    signs = (sign_bytes.unsqueeze(-1) >> shifts_bits) & 1
+    signs = torch.where(
+        signs == 0,
+        torch.ones(1, dtype=dtype, device=d.device),
+        torch.full((1,), -1.0, dtype=dtype, device=d.device),
+    )
+    signs = signs.reshape((n_blocks, 8, 4, 8))
+
+    grid = GRID_IQ3_XXS.to(dtype=dtype, device=d.device)
+    grid_val = grid[qs.to(torch.long)].reshape((n_blocks, 8, 4, 8))
+
+    return (db * grid_val * signs).reshape((n_blocks, QK_K))
+
+def dequantize_blocks_IQ2_S(blocks, block_size, type_size, dtype=None):
+    n_blocks = blocks.shape[0]
+
+    d, qs, signs, qh, scales = split_block_dims(blocks, 2, 32, 32, 8)
+    d = d.view(torch.float16).to(dtype)
+
+    scales = scales.view(torch.uint8)
+    scales = torch.stack([scales & 0xF, scales >> 4], dim=-1).reshape((n_blocks, 16))
+    db = d * (0.5 + scales.to(dtype)) * 0.25
+    db = db.reshape((n_blocks, 16, 1, 1))
+
+    shifts = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8))
+    signs = (signs.unsqueeze(-1) >> shifts) & 1
+    signs = torch.where(
+        signs == 0,
+        torch.ones(1, dtype=dtype, device=d.device),
+        torch.full((1,), -1.0, dtype=dtype, device=d.device),
+    )
+    signs = signs.reshape((n_blocks, 16, 2, 8))
+
+    qh_shifts = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4))
+    qh_bits = (qh.view(torch.uint8).reshape((n_blocks, 8, 1)) >> qh_shifts) & 3
+    qh_bits = qh_bits.reshape((n_blocks, 32))
+
+    qs = qs.view(torch.uint8).to(torch.int32)
+    indices = qs | (qh_bits.to(torch.int32) << 8)
+
+    grid = GRID_IQ2_S.to(dtype=dtype, device=d.device)
+    grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 16, 2, 8))
+
+    return (db * grid_val * signs).reshape((n_blocks, QK_K))
+
+def dequantize_blocks_IQ2_XXS(blocks, block_size, type_size, dtype=None):
+    n_blocks = blocks.shape[0]
+
+    d, qs = split_block_dims(blocks, 2)
+    d = d.view(torch.float16).to(dtype)
+
+    u32 = qs.reshape((n_blocks, 16, 4)).to(torch.int32)
+    u32 = u32[:, :, 0] | (u32[:, :, 1] << 8) | (u32[:, :, 2] << 16) | (u32[:, :, 3] << 24)
+    u32 = u32.reshape((n_blocks, 8, 2))
+
+    q0 = u32[:, :, 0] # grid indices
+    q1 = u32[:, :, 1] # scales and signs
+
+    db = d * (0.5 + ((q1 >> 28) & 0xF).to(dtype)) * 0.25
+    db = db.reshape((n_blocks, 8, 1, 1))
+
+    shifts = torch.tensor([0, 7, 14, 21], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
+    sign_indices = (q1.unsqueeze(-1) >> shifts) & 0x7F
+
+    ksigns = KSIGNS_IQ2_XXS.to(d.device)
+    sign_bytes = ksigns[sign_indices.to(torch.long)]
+
+    shifts_bits = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 1, 8))
+    signs = (sign_bytes.unsqueeze(-1) >> shifts_bits) & 1
+    signs = torch.where(
+        signs == 0,
+        torch.ones(1, dtype=dtype, device=d.device),
+        torch.full((1,), -1.0, dtype=dtype, device=d.device),
+    )
+    signs = signs.reshape((n_blocks, 8, 4, 8))
+
+    indices = q0.contiguous().view(torch.uint8)
+    grid = GRID_IQ2_XXS.to(dtype=dtype, device=d.device)
+    grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 4, 8))
+
+    return (db * grid_val * signs).reshape((n_blocks, QK_K))
+
+def dequantize_blocks_IQ1_M(blocks, block_size, type_size, dtype=None):
+    n_blocks = blocks.shape[0]
+
+    qs, qh, scales = split_block_dims(blocks, 32, 16)
+
+    scales_u16 = scales.reshape((n_blocks, 4, 2)).to(torch.int32)
+    scales_u16 = scales_u16[:, :, 0] | (scales_u16[:, :, 1] << 8)
+
+    d_bits = (
+        ((scales_u16[:, 0] & 0xF000) >> 12)
+        | ((scales_u16[:, 1] & 0xF000) >> 8)
+        | ((scales_u16[:, 2] & 0xF000) >> 4)
+        | (scales_u16[:, 3] & 0xF000)
+    )
+    d = d_bits.to(torch.int16).view(torch.float16).to(dtype).reshape((n_blocks, 1))
+
+    sub_shifts = torch.tensor([0, 3, 6, 9], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
+    sub_scales = (scales_u16.reshape((n_blocks, 4, 1)) >> sub_shifts) & 7
+    dl = d.reshape((n_blocks, 1, 1)) * (2 * sub_scales.to(dtype) + 1)
+    dl = dl.reshape((n_blocks, 8, 2, 1, 1))
+
+    qh_bytes = qh.to(torch.int32)
+    qh_shifts = torch.tensor([0, 4], device=d.device, dtype=torch.int32).reshape((1, 1, 2))
+    qh_unpacked = (qh_bytes.reshape((n_blocks, 16, 1)) >> qh_shifts).reshape((n_blocks, 32))
+
+    delta = torch.where(
+        (qh_unpacked & 8) == 0,
+        torch.full((1,), 0.125, dtype=dtype, device=d.device),
+        torch.full((1,), -0.125, dtype=dtype, device=d.device),
+    ).reshape((n_blocks, 8, 2, 2, 1))
+
+    qh_bits = qh_unpacked & 7
+    qs = qs.to(torch.int32)
+    indices = qs | (qh_bits << 8)
+
+    grid = GRID_IQ1_S.to(dtype=dtype, device=d.device)
+    grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 2, 2, 8))
+
+    return (dl * (grid_val + delta)).reshape((n_blocks, QK_K))
+
+def dequantize_blocks_IQ1_S(blocks, block_size, type_size, dtype=None):
+    n_blocks = blocks.shape[0]
+
+    d, qs, qh = split_block_dims(blocks, 2, 32)
+    d = d.view(torch.float16).to(dtype)
+
+    qh = qh.view(torch.int16).to(torch.int32) & 0xFFFF
+
+    dl = d * (2 * ((qh >> 12) & 7).to(dtype) + 1)
+    delta = torch.where(
+        (qh & 0x8000) == 0,
+        torch.full((1,), 0.125, dtype=dtype, device=d.device),
+        torch.full((1,), -0.125, dtype=dtype, device=d.device),
+    )
+
+    shifts = torch.tensor([0, 3, 6, 9], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
+    qh_bits = (qh.reshape((n_blocks, 8, 1)) >> shifts) & 7
+
+    qs = qs.view(torch.uint8).to(torch.int32).reshape((n_blocks, 8, 4))
+    indices = qs | (qh_bits << 8)
+
+    grid = GRID_IQ1_S.to(dtype=dtype, device=d.device)
+    grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 4, 8))
+
+    dl = dl.reshape((n_blocks, 8, 1, 1))
+    delta = delta.reshape((n_blocks, 8, 1, 1))
+
+    return (dl * (grid_val + delta)).reshape((n_blocks, QK_K))
+
 dequantize_functions = {
     gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
     gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
@@ -298,4 +516,10 @@ def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
     gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
     gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
     gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
+    gguf.GGMLQuantizationType.IQ3_S: dequantize_blocks_IQ3_S,
+    gguf.GGMLQuantizationType.IQ3_XXS: dequantize_blocks_IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ2_S: dequantize_blocks_IQ2_S,
+    gguf.GGMLQuantizationType.IQ2_XXS: dequantize_blocks_IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ1_M: dequantize_blocks_IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S: dequantize_blocks_IQ1_S,
 }