Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 224 additions & 0 deletions dequant.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import numpy as np
import torch
from tqdm import tqdm

Expand Down Expand Up @@ -240,6 +241,24 @@ def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
# IQ quants
KVALUES = torch.tensor([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=torch.int8)

def _get_iq_grid(iq_cls):
iq_cls.init_grid()
return torch.from_numpy(np.array(iq_cls.grid).squeeze().copy())

def _get_iq_ksigns(iq_cls):
iq_cls.init_grid()
return torch.from_numpy(np.frombuffer(iq_cls.ksigns, dtype=np.uint8).copy())

# Reuse the reference IQ implementations from gguf-py for their lazily-built
# codebook grids and sign tables instead of duplicating the constants here.
from gguf.quants import IQ1_M as _IQ1_M, IQ1_S as _IQ1_S, IQ2_S as _IQ2_S, IQ2_XXS as _IQ2_XXS, IQ3_S as _IQ3_S, IQ3_XXS as _IQ3_XXS

# Torch copies of each quant type's codebook grid, built once at import time
# and indexed by the dequantize_blocks_IQ* functions below.
GRID_IQ3_S = _get_iq_grid(_IQ3_S)
GRID_IQ3_XXS = _get_iq_grid(_IQ3_XXS)
GRID_IQ2_S = _get_iq_grid(_IQ2_S)
GRID_IQ2_XXS = _get_iq_grid(_IQ2_XXS)
GRID_IQ1_S = _get_iq_grid(_IQ1_S)
_get_iq_grid(_IQ1_M)  # IQ1_M uses the same grid as IQ1_S internally
# Packed 8-bit sign patterns (one byte = 8 signs), shared by the IQ2/IQ3
# XXS-style decoders.
KSIGNS_IQ2_XXS = _get_iq_ksigns(_IQ2_XXS)

def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]

Expand Down Expand Up @@ -284,6 +303,205 @@ def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):

return (dl * qs).reshape((n_blocks, -1))

def dequantize_blocks_IQ3_S(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ3_S blocks into a ``(n_blocks, QK_K)`` tensor.

    Per-block fields (see the split below): fp16 scale ``d``, 64 low
    index bytes, 8 high-bit bytes, 32 sign bytes, then nibble-packed
    sub-block scales.
    """
    n_blocks = blocks.shape[0]

    d, qs, qh, sign_bytes, scales = split_block_dims(blocks, 2, 64, 8, 32)
    d = d.view(torch.float16).to(dtype)

    # Two 4-bit scales per byte -> 8 sub-block scales, each mapped to 1 + 2*s.
    nibbles = scales.view(torch.uint8)
    nibbles = torch.stack([nibbles & 0xF, nibbles >> 4], dim=-1).reshape((n_blocks, 8))
    db = (d * (1 + 2 * nibbles.to(dtype))).reshape((n_blocks, 8, 1, 1))

    # Extend each qs byte to a 9-bit grid index with one bit from qh.
    bit_pos = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8))
    hi = ((qh.unsqueeze(-1) >> bit_pos) & 1).reshape((n_blocks, 64))
    idx = qs.to(torch.int16) | (hi.to(torch.int16) << 8)

    # Per-value sign bits: 0 -> +1, 1 -> -1.
    plus_one = torch.ones(1, dtype=dtype, device=d.device)
    sgn = (sign_bytes.unsqueeze(-1) >> bit_pos) & 1
    sgn = torch.where(sgn == 0, plus_one, -plus_one).reshape((n_blocks, 8, 8, 4))

    grid = GRID_IQ3_S.to(dtype=dtype, device=d.device)
    vals = grid[idx.to(torch.long)].reshape((n_blocks, 8, 8, 4))

    return (db * vals * sgn).reshape((n_blocks, QK_K))

def dequantize_blocks_IQ3_XXS(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ3_XXS blocks into a ``(n_blocks, QK_K)`` tensor.

    Each block stores an fp16 scale, 64 grid-index bytes and 32 bytes
    forming 8 packed uint32 words of per-sub-block sign indices (low 28
    bits) plus a 4-bit scale (top bits).
    """
    n_blocks = blocks.shape[0]

    d, qs, packed, _ = split_block_dims(blocks, 2, 64, 32)
    d = d.view(torch.float16).to(dtype)

    # Reassemble the 8 little-endian uint32 scale/sign words.
    b = packed.reshape((n_blocks, 8, 4)).to(torch.int32)
    word = b[:, :, 0] | (b[:, :, 1] << 8) | (b[:, :, 2] << 16) | (b[:, :, 3] << 24)

    # Mask after the shift: the int32 words can be negative, and torch's
    # right shift is arithmetic.
    db = (d * (0.5 + ((word >> 28) & 0xF).to(dtype)) * 0.5).reshape((n_blocks, 8, 1, 1))

    # Four 7-bit ksigns indices per word, each expanding to 8 sign bits.
    sign_shift = torch.tensor([0, 7, 14, 21], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
    sign_idx = (word.reshape((n_blocks, 8, 1)) >> sign_shift) & 0x7F
    sign_bytes = KSIGNS_IQ2_XXS.to(d.device)[sign_idx.to(torch.long)]

    bit_pos = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 1, 8))
    plus_one = torch.ones(1, dtype=dtype, device=d.device)
    sgn = (sign_bytes.unsqueeze(-1) >> bit_pos) & 1
    sgn = torch.where(sgn == 0, plus_one, -plus_one).reshape((n_blocks, 8, 4, 8))

    grid = GRID_IQ3_XXS.to(dtype=dtype, device=d.device)
    vals = grid[qs.to(torch.long)].reshape((n_blocks, 8, 4, 8))

    return (db * vals * sgn).reshape((n_blocks, QK_K))

def dequantize_blocks_IQ2_S(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ2_S blocks into a ``(n_blocks, QK_K)`` tensor.

    Per-block fields: fp16 scale, 32 low index bytes, 32 sign bytes,
    8 high-bit bytes and 8 bytes of nibble-packed sub-block scales.
    """
    n_blocks = blocks.shape[0]

    d, qs, sign_bytes, qh, scales = split_block_dims(blocks, 2, 32, 32, 8)
    d = d.view(torch.float16).to(dtype)

    # 16 sub-block scales (two nibbles per byte), mapped to d * (0.5 + s) / 4.
    nibbles = scales.view(torch.uint8)
    nibbles = torch.stack([nibbles & 0xF, nibbles >> 4], dim=-1).reshape((n_blocks, 16))
    db = (d * (0.5 + nibbles.to(dtype)) * 0.25).reshape((n_blocks, 16, 1, 1))

    # Per-value sign bits: 0 -> +1, 1 -> -1.
    bit_pos = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8))
    plus_one = torch.ones(1, dtype=dtype, device=d.device)
    sgn = (sign_bytes.unsqueeze(-1) >> bit_pos) & 1
    sgn = torch.where(sgn == 0, plus_one, -plus_one).reshape((n_blocks, 16, 2, 8))

    # Each qh byte carries four 2-bit high parts for four consecutive indices.
    two_bit = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4))
    hi = ((qh.view(torch.uint8).reshape((n_blocks, 8, 1)) >> two_bit) & 3).reshape((n_blocks, 32))

    # 10-bit grid index = low byte | (2 high bits << 8).
    idx = qs.view(torch.uint8).to(torch.int32) | (hi.to(torch.int32) << 8)

    grid = GRID_IQ2_S.to(dtype=dtype, device=d.device)
    vals = grid[idx.to(torch.long)].reshape((n_blocks, 16, 2, 8))

    return (db * vals * sgn).reshape((n_blocks, QK_K))

def dequantize_blocks_IQ2_XXS(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ2_XXS blocks into a ``(n_blocks, QK_K)`` tensor.

    The payload after the fp16 scale is a sequence of uint32 pairs: the
    first word of each pair packs four grid-index bytes, the second packs
    four 7-bit ksigns indices plus a 4-bit scale in the top bits.
    """
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    # Little-endian reassembly of 16 uint32 words -> 8 (index, aux) pairs.
    b = qs.reshape((n_blocks, 16, 4)).to(torch.int32)
    words = (b[:, :, 0] | (b[:, :, 1] << 8) | (b[:, :, 2] << 16) | (b[:, :, 3] << 24)).reshape((n_blocks, 8, 2))
    idx_word = words[:, :, 0]  # grid indices
    aux = words[:, :, 1]       # scales and signs

    # Mask after the shift: int32 words can be negative and >> is arithmetic.
    db = (d * (0.5 + ((aux >> 28) & 0xF).to(dtype)) * 0.25).reshape((n_blocks, 8, 1, 1))

    # Four 7-bit ksigns indices per aux word, each expanding to 8 sign bits.
    sign_shift = torch.tensor([0, 7, 14, 21], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
    sign_idx = (aux.unsqueeze(-1) >> sign_shift) & 0x7F
    sign_bytes = KSIGNS_IQ2_XXS.to(d.device)[sign_idx.to(torch.long)]

    bit_pos = torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 1, 8))
    plus_one = torch.ones(1, dtype=dtype, device=d.device)
    sgn = (sign_bytes.unsqueeze(-1) >> bit_pos) & 1
    sgn = torch.where(sgn == 0, plus_one, -plus_one).reshape((n_blocks, 8, 4, 8))

    # Reinterpret each index word back into its four original bytes.
    idx = idx_word.contiguous().view(torch.uint8)
    grid = GRID_IQ2_XXS.to(dtype=dtype, device=d.device)
    vals = grid[idx.to(torch.long)].reshape((n_blocks, 8, 4, 8))

    return (db * vals * sgn).reshape((n_blocks, QK_K))

def dequantize_blocks_IQ1_M(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ1_M blocks into a (n_blocks, QK_K) tensor.

    Unlike the other IQ types there is no leading fp16 scale field; the
    block scale is reassembled below from the top nibble of each of the
    four packed 16-bit scale words.
    """
    n_blocks = blocks.shape[0]

    # Per-block layout: 32 low-index bytes, 16 nibble-packed bytes (high
    # index bits + delta sign), then the packed scale words.
    qs, qh, scales = split_block_dims(blocks, 32, 16)

    # Reassemble the four little-endian uint16 scale words.
    scales_u16 = scales.reshape((n_blocks, 4, 2)).to(torch.int32)
    scales_u16 = scales_u16[:, :, 0] | (scales_u16[:, :, 1] << 8)

    # The top nibble of each word contributes 4 bits of the fp16 block
    # scale's bit pattern (word 0 -> lowest nibble).
    d_bits = (
        ((scales_u16[:, 0] & 0xF000) >> 12)
        | ((scales_u16[:, 1] & 0xF000) >> 8)
        | ((scales_u16[:, 2] & 0xF000) >> 4)
        | (scales_u16[:, 3] & 0xF000)
    )
    # Narrow to 16 bits, then reinterpret those bits as fp16.
    d = d_bits.to(torch.int16).view(torch.float16).to(dtype).reshape((n_blocks, 1))

    # Each scale word also holds four 3-bit sub-scales at shifts 0/3/6/9,
    # mapped to 2*s + 1; 16 sub-scales = 8 sub-blocks x 2 halves.
    sub_shifts = torch.tensor([0, 3, 6, 9], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
    sub_scales = (scales_u16.reshape((n_blocks, 4, 1)) >> sub_shifts) & 7
    dl = d.reshape((n_blocks, 1, 1)) * (2 * sub_scales.to(dtype) + 1)
    dl = dl.reshape((n_blocks, 8, 2, 1, 1))

    # qh holds two nibbles per byte (low nibble first): bits 0-2 extend the
    # grid index, bit 3 selects the sign of the 0.125 delta.
    qh_bytes = qh.to(torch.int32)
    qh_shifts = torch.tensor([0, 4], device=d.device, dtype=torch.int32).reshape((1, 1, 2))
    qh_unpacked = (qh_bytes.reshape((n_blocks, 16, 1)) >> qh_shifts).reshape((n_blocks, 32))

    delta = torch.where(
        (qh_unpacked & 8) == 0,
        torch.full((1,), 0.125, dtype=dtype, device=d.device),
        torch.full((1,), -0.125, dtype=dtype, device=d.device),
    ).reshape((n_blocks, 8, 2, 2, 1))

    # 11-bit grid index = low byte | (3 high bits << 8); the lookup table is
    # shared with IQ1_S (see the module-level comment next to GRID_IQ1_S).
    qh_bits = qh_unpacked & 7
    qs = qs.to(torch.int32)
    indices = qs | (qh_bits << 8)

    grid = GRID_IQ1_S.to(dtype=dtype, device=d.device)
    grid_val = grid[indices.to(torch.long)].reshape((n_blocks, 8, 2, 2, 8))

    return (dl * (grid_val + delta)).reshape((n_blocks, QK_K))

def dequantize_blocks_IQ1_S(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ1_S blocks into a ``(n_blocks, QK_K)`` tensor.

    Layout: fp16 scale, 32 low index bytes, then 8 uint16 words each
    carrying four 3-bit high index parts (shifts 0/3/6/9), a 3-bit
    sub-block scale (bits 12-14) and a delta sign bit (bit 15).
    """
    n_blocks = blocks.shape[0]

    d, qs, qh = split_block_dims(blocks, 2, 32)
    d = d.view(torch.float16).to(dtype)

    # Treat the per-sub-block words as unsigned 16-bit values.
    words = qh.view(torch.int16).to(torch.int32) & 0xFFFF

    # Sub-block scale d * (2s + 1) plus a +-0.125 offset picked by bit 15.
    dl = (d * (2 * ((words >> 12) & 7).to(dtype) + 1)).reshape((n_blocks, 8, 1, 1))
    pos = torch.full((1,), 0.125, dtype=dtype, device=d.device)
    neg = torch.full((1,), -0.125, dtype=dtype, device=d.device)
    delta = torch.where((words & 0x8000) == 0, pos, neg).reshape((n_blocks, 8, 1, 1))

    # 11-bit grid index = low byte | (3 high bits << 8).
    hi_shift = torch.tensor([0, 3, 6, 9], device=d.device, dtype=torch.int32).reshape((1, 1, 4))
    hi = (words.reshape((n_blocks, 8, 1)) >> hi_shift) & 7
    low = qs.view(torch.uint8).to(torch.int32).reshape((n_blocks, 8, 4))
    idx = low | (hi << 8)

    grid = GRID_IQ1_S.to(dtype=dtype, device=d.device)
    vals = grid[idx.to(torch.long)].reshape((n_blocks, 8, 4, 8))

    return (dl * (vals + delta)).reshape((n_blocks, QK_K))

dequantize_functions = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
Expand All @@ -298,4 +516,10 @@ def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
gguf.GGMLQuantizationType.IQ3_S: dequantize_blocks_IQ3_S,
gguf.GGMLQuantizationType.IQ3_XXS: dequantize_blocks_IQ3_XXS,
gguf.GGMLQuantizationType.IQ2_S: dequantize_blocks_IQ2_S,
gguf.GGMLQuantizationType.IQ2_XXS: dequantize_blocks_IQ2_XXS,
gguf.GGMLQuantizationType.IQ1_M: dequantize_blocks_IQ1_M,
gguf.GGMLQuantizationType.IQ1_S: dequantize_blocks_IQ1_S,
}