Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions squish/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,9 +1220,11 @@ def _hf_list_files(repo: str, token: str | None = None) -> list[str]: # pragma:

try:
return list(list_repo_files(repo, token=token))
except Exception:
except Exception as exc: # noqa: BLE001 — HF listing; [] is a documented non-fatal fallback
_LOG.debug("list_repo_files(%s) failed: %s", repo, exc)
return []
except Exception:
except Exception as exc: # noqa: BLE001 — optional HF dependency/import; [] is non-fatal
_LOG.debug("HF repo file listing unavailable for %s: %s", repo, exc)
return []


Expand Down
13 changes: 10 additions & 3 deletions squish/quant/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,25 @@ def decode(self, tensor: HQQTensor) -> np.ndarray:
else:
dim_size = rows
other_dim = cols
# encode() stores axis-1 codes transposed back to the original shape
# (rows, cols); undo that so the group reshape below operates along
# the quantized axis as (other_dim, dim_size).
codes = codes.T

n_groups = tensor.scale.shape[-1]
group_size_actual = max(1, (dim_size + n_groups - 1) // n_groups)
# Use the group size encode actually used — recomputing it as
# ceil(dim_size / n_groups) is wrong whenever dim_size is not an exact
# multiple of group_size, misaligning every group against its scale/zero.
group_size = cfg.group_size if cfg.group_size != -1 else dim_size

padded = n_groups * group_size_actual
padded = n_groups * group_size
if codes.shape[-1] < padded:
codes_pad = np.zeros((other_dim, padded), dtype=np.float32)
codes_pad[:, : codes.shape[-1]] = codes
else:
codes_pad = codes

codes_g = codes_pad.reshape(other_dim, n_groups, group_size_actual)
codes_g = codes_pad.reshape(other_dim, n_groups, group_size)
scales = tensor.scale[:, :, np.newaxis] # (O, G, 1)
zeros = tensor.zero[:, :, np.newaxis]
W_hat = codes_g * scales + zeros
Expand Down
13 changes: 8 additions & 5 deletions squish/streaming/streaming_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,23 @@ class SinkStats:
Attributes:
n_tokens_seen: Total tokens added since last :meth:`SinkKVCache.reset`.
n_evictions: Number of tokens evicted from the rolling window.
window_size: Rolling-window capacity, used as the util_fraction denominator.
"""

n_tokens_seen: int = 0
n_evictions: int = 0
window_size: int = 0

@property
def util_fraction(self) -> float:
"""Fraction of the rolling window currently occupied (0–1).

Returns 0.0 before any tokens are added. Value is based on
``n_tokens_seen`` relative to the window size, clamped to [0.0, 1.0].
This is a snapshot metric only — callers needing exact occupancy
should inspect :attr:`SinkKVCache.n_recent` directly.
Returns 0.0 before any tokens are added. Value is ``n_tokens_seen``
relative to the window size, clamped to [0.0, 1.0]. This is a snapshot
metric only — callers needing exact occupancy should inspect
:attr:`SinkKVCache.n_recent` directly.
"""
return min(1.0, float(self.n_tokens_seen) / max(1, self.n_tokens_seen))
return min(1.0, float(self.n_tokens_seen) / max(1, self.window_size))

@property
def total_tokens_held(self) -> int:
Expand Down Expand Up @@ -237,6 +239,7 @@ def get_stats(self) -> SinkStats:
return SinkStats(
n_tokens_seen=self._n_tokens_seen,
n_evictions=self._n_evictions,
window_size=self._config.window_size,
)

# ------------------------------------------------------------ convenience
Expand Down
57 changes: 57 additions & 0 deletions tests/quant/test_hqq_decode_group_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Regression: HQQ decode must use the stored config group size.

decode() recomputed group_size as ceil(dim_size / n_groups), which differs from
the real group size whenever dim_size is not an exact multiple of it — every
group then misaligned against its scale/zero and reconstruction error blew up.
"""
from __future__ import annotations

import numpy as np
import pytest

from squish.quant.hqq import HQQConfig, HQQQuantizer


def _rel_err(a: np.ndarray, b: np.ndarray) -> float:
return float(np.linalg.norm(a - b) / np.linalg.norm(a))


@pytest.mark.parametrize("dim,group_size", [(100, 30), (130, 64), (96, 32), (100, 25)])
@pytest.mark.parametrize("axis", [0, 1])
def test_decode_roundtrip_non_divisible_group(dim, group_size, axis):
rng = np.random.default_rng(0)
w = (rng.standard_normal((dim, 4)) if axis == 1
else rng.standard_normal((4, dim))).astype(np.float32)
q = HQQQuantizer(HQQConfig(bits=4, group_size=group_size, axis=axis))
recon = q.decode(q.encode(w))
assert recon.shape == w.shape
# 4-bit HQQ on unit Gaussian keeps relative error well under 0.15.
assert _rel_err(w, recon) < 0.15


def test_axis1_roundtrip_was_broken_before_fix():
# decode() never transposed the stored axis-1 codes back, so axis=1 raised
# a broadcast error end-to-end (even on aligned dims).
rng = np.random.default_rng(5)
w = rng.standard_normal((96, 4)).astype(np.float32)
q = HQQQuantizer(HQQConfig(bits=4, group_size=32, axis=1))
recon = q.decode(q.encode(w))
assert recon.shape == w.shape
assert _rel_err(w, recon) < 0.15


def test_non_divisible_was_broken_before_fix():
# Sharpened guard: the non-aligned case used to be ~0.30 rel error.
rng = np.random.default_rng(1)
w = rng.standard_normal((4, 100)).astype(np.float32)
q = HQQQuantizer(HQQConfig(bits=4, group_size=30, axis=0))
assert _rel_err(w, q.decode(q.encode(w))) < 0.15


def test_full_row_group_size_minus_one():
rng = np.random.default_rng(2)
w = rng.standard_normal((4, 100)).astype(np.float32)
q = HQQQuantizer(HQQConfig(bits=4, group_size=-1, axis=0))
recon = q.decode(q.encode(w))
assert recon.shape == w.shape
assert np.isfinite(recon).all()
29 changes: 29 additions & 0 deletions tests/streaming/test_sink_util_fraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Regression: SinkStats.util_fraction must divide by window size, not itself.

The formula was n_tokens_seen / max(1, n_tokens_seen), which is 1.0 for any
nonzero count regardless of window size — so a barely-used cache reported 100%
utilization.
"""
from __future__ import annotations

from squish.streaming.streaming_sink import SinkConfig, SinkKVCache, SinkStats


def test_partial_window_reports_real_fraction():
assert SinkStats(n_tokens_seen=5, window_size=256).util_fraction == 5 / 256


def test_zero_tokens_is_zero():
assert SinkStats(n_tokens_seen=0, window_size=256).util_fraction == 0.0


def test_overfull_is_clamped_to_one():
assert SinkStats(n_tokens_seen=512, window_size=256).util_fraction == 1.0


def test_get_stats_populates_window_size():
cache = SinkKVCache(SinkConfig(n_sink_tokens=4, window_size=128), n_heads=2, head_dim=8)
stats = cache.get_stats()
assert stats.window_size == 128
# A fresh cache has seen no tokens → 0 utilization (not 1.0).
assert stats.util_fraction == 0.0
Loading