From 18655af21bf4df7f785c79b9cd3bd768f45b7025 Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Mon, 25 May 2026 14:47:16 -0700 Subject: [PATCH 1/8] Implement P-SIFT: precondition-and-sketch via modified projections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per Louis Jaburi's review on #275: compose the random projection with H^{-1/2} into a single saved matrix M = R · cov^{-1/2}, then use M wherever the gradient collector used to sample R. Query and training gradients both go through M, so ⟨ĝ_z, ĝ_q⟩ ≈ ⟨g_z, H⁻¹ g_q⟩ without ever applying H⁻¹ in low-dim space. - eigenvectors.py: persist per-side eigenvalues to eigval_{activation, gradient}_sharded/ alongside the existing eigenvectors. - apply_hessian.py: build_kfac_projections builds M = R · Q · diag((E + λ·mean(E))^{-1/2}) · Qᵀ per (layer, side) and writes projection_{left,right}_sharded/. _apply_compressed applies ĝ_q = M_l · G_q · M_rᵀ. Legacy rotate-divide-rotate path is unchanged. Early guard rejects ev_correction + compression (the joint S⊗A eigenvalue correction breaks the Kronecker structure). - config.py + gradient_collectors.py: new IndexConfig.kfac_projection_path pre-populates processor._projection_matrices from disk in GradientCollector.setup(), so HookCollectorBase.projection() returns the saved M instead of sampling a fresh R at hook time. - pipeline.py: visible Step 3.5 builds M and projects the query; Step 4 sets kfac_projection_path so the training-side collector uses the same M. Top-level guard fails fast on projection_dim > 0 + ev_correction=True before any work runs. Untested on GPU (Mac, no CUDA); smoke yaml at runs/p_sift_smoke.yaml. 
Co-Authored-By: Claude Opus 4.7 --- bergson/collector/gradient_collectors.py | 16 ++ bergson/config.py | 8 + bergson/hessians/apply_hessian.py | 258 +++++++++++++++++++++-- bergson/hessians/eigenvectors.py | 7 + bergson/hessians/pipeline.py | 89 ++++++-- 5 files changed, 341 insertions(+), 37 deletions(-) diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index a1c817e9..557eecbc 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -72,6 +72,22 @@ def setup(self) -> None: self.lo = torch.finfo(self.save_dtype).min self.hi = torch.finfo(self.save_dtype).max + # When a K-FAC compressed-projection path is provided, pre-populate + # the per-layer projection cache with ``M = R · cov^{-1/2}``. The + # backward hook's ``self.projection(...)`` call then hits the cache + # and uses M in place of a freshly sampled random matrix — so the + # gradient comes out preconditioned and sketched in one step. + if self.cfg.kfac_projection_path: + from bergson.hessians.apply_hessian import load_kfac_projections + + load_kfac_projections( + self.cfg.kfac_projection_path, + self.processor._projection_matrices, + self.target_info.keys(), + device=self.model.device, + dtype=self.save_dtype, + ) + self.per_doc_losses = torch.full( (len(self.data),), device=self.model.device, diff --git a/bergson/config.py b/bergson/config.py index 4f5a0fae..96e3df7a 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -397,6 +397,14 @@ class IndexConfig(AttributionConfig, Serializable): processor_path: str = "" """Path to a precomputed processor.""" + kfac_projection_path: str = "" + """Path to a K-FAC hessian directory containing precomputed modified + projection matrices (``projection_left_sharded/``, + ``projection_right_sharded/``). 
When set, the gradient collector loads + these instead of sampling fresh random projections — the saved matrices + are ``M = R · cov^{-1/2}``, so applying them to gradients yields + preconditioned-and-sketched gradients in a single matmul (P-SIFT).""" + optimizer_state: str = "" """Source for optimizer second moments used to normalize gradients. Either a local path (a checkpoint directory containing ``optimizer.pt``, diff --git a/bergson/hessians/apply_hessian.py b/bergson/hessians/apply_hessian.py index ea226851..8acd5706 100644 --- a/bergson/hessians/apply_hessian.py +++ b/bergson/hessians/apply_hessian.py @@ -8,13 +8,18 @@ import numpy as np import torch import torch.distributed as dist -from safetensors.torch import load_file +from safetensors import safe_open +from safetensors.torch import load_file, save_file from simple_parsing import ArgumentParser from torch import Tensor from bergson.collector.collector import create_projection_matrix from bergson.data import create_index, load_gradients from bergson.distributed import init_dist +from bergson.hessians.eigenvectors import ( + _compute_full_matrix, + fair_distribute_by_cost, +) from bergson.hessians.sharded_computation import ShardedMul from bergson.utils.logger import get_logger from bergson.utils.utils import get_device @@ -35,6 +40,154 @@ class EkfacConfig: projection_type: Literal["normal", "rademacher"] = "rademacher" +# Map projection "side" (used by the gradient collector) to the K-FAC +# covariance whose inverse-square-root we fold into the random projection. 
+# - "left" → applied to the output-gradient g, so it consumes S (gradient cov) +# - "right" → applied to the input activation a, so it consumes A (activation cov) +SIDE_TO_COV = {"left": "gradient", "right": "activation"} + + +def build_kfac_projections( + hessian_method_path: str, + projection_dim: int, + projection_type: Literal["normal", "rademacher"], + lambda_damp_factor: float, + dtype: torch.dtype, + device: torch.device, +) -> None: + """Build and save the precondition+sketch matrices ``M = R · cov^{-1/2}``. + + For each module and each side, this composes the random projection ``R`` + with the damped inverse square-root of the K-FAC factor: + + M = R · Q · diag((E + λ·mean(E))^{-1/2}) · Qᵀ [p, d] + + The result is saved to ``hessian_method_path/projection_{side}_sharded/``. + Loaded later by the gradient collector — when present, gradients come + out of collection already preconditioned and sketched in one matmul, so + the apply step and the score step both share the same projection. + + Layers are distributed across ranks via the same fair-by-cost scheme as + the eigendecomposition; each rank writes its own shard. + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Discover layer names and per-side dimensions from one shard of each + # eigenvector file. The eigendecomp stores Q with shape [d/W, d] per + # shard, so the column dim is the full layer dimension. + side_dims: dict[str, dict[str, int]] = {"left": {}, "right": {}} + for side, cov in SIDE_TO_COV.items(): + with safe_open( + os.path.join( + hessian_method_path, f"eigen_{cov}_sharded/shard_0.safetensors" + ), + framework="pt", + ) as f: + for name in f.keys(): + side_dims[side][name] = f.get_tensor(name).shape[-1] + + names = list(side_dims["left"].keys()) + # fair_distribute_by_cost treats its values as matrix dimensions and + # ranks by d**3. 
Pass the per-layer max of the two sides — that bounds + # the dominant matmul cost (M = R · Q · diag(D) · Qᵀ, O(d**3) on the + # larger side). + per_layer_dim = {n: max(side_dims["left"][n], side_dims["right"][n]) for n in names} + my_names = fair_distribute_by_cost(per_layer_dim, world_size)[rank] + + out_dirs = { + side: os.path.join(hessian_method_path, f"projection_{side}_sharded") + for side in ("left", "right") + } + if rank == 0: + for d in out_dirs.values(): + os.makedirs(d, exist_ok=True) + + if dist.is_initialized(): + dist.barrier() + + saved: dict[str, dict[str, Tensor]] = {"left": {}, "right": {}} + for name in my_names: + for side, cov in SIDE_TO_COV.items(): + d = side_dims[side][name] + Q = _compute_full_matrix( + name=name, + shard_path=os.path.join(hessian_method_path, f"eigen_{cov}_sharded"), + rank=rank, + world_size=world_size, + ).to(device=device, dtype=torch.float32) + E = _compute_full_matrix( + name=name, + shard_path=os.path.join(hessian_method_path, f"eigval_{cov}_sharded"), + rank=rank, + world_size=world_size, + ).to(device=device, dtype=torch.float32) + + # Per-side adaptive damping, mirroring ShardedMul._sharded_hadamard + # but applied to each factor independently. + damp = lambda_damp_factor * E.mean() + D = (E + damp).clamp_min(torch.finfo(torch.float32).tiny).rsqrt() # [d] + + # Same identifier convention as HookCollectorBase.projection() so + # the precompute matches the projection the collector would have + # sampled for this layer/side. + R = create_projection_matrix( + f"{name}/{side}", + projection_dim, + d, + torch.float32, + device, + projection_type, + ) + + # M = R · Q · diag(D) · Qᵀ [p, d] + # Q has eigenvectors as columns, so Q·diag(D) scales column k by D[k]; + # that is element-wise (Q * D) broadcasting D across rows. 
+ M = R @ (Q * D) @ Q.T + saved[side][name] = M.to(dtype=dtype).cpu().contiguous() + + for side in ("left", "right"): + save_file( + saved[side], + os.path.join(out_dirs[side], f"shard_{rank}.safetensors"), + ) + + get_logger().info( + f"Saved M_left/M_right to {out_dirs['left']} and {out_dirs['right']}" + ) + + if dist.is_initialized(): + dist.barrier() + + +def load_kfac_projections( + hessian_method_path: str | os.PathLike, + cache: dict, + target_names, + device: torch.device, + dtype: torch.dtype, +) -> None: + """Populate ``cache`` with the per-layer M matrices saved by + ``build_kfac_projections``. ``cache`` is the projection cache used by + ``HookCollectorBase.projection()`` (keyed by ``(name, side, device)``); + a cache hit short-circuits random projection at gradient collection. + + ``dtype`` matches what the collector will request at hook time (the + model's gradient dtype). Saved M lives in fp32 on disk for precision + and is cast on the way in. + """ + targets = set(target_names) + for side in ("left", "right"): + side_dir = Path(hessian_method_path) / f"projection_{side}_sharded" + if not side_dir.exists(): + return + for shard_file in sorted(side_dir.glob("shard_*.safetensors")): + with safe_open(str(shard_file), framework="pt", device=str(device)) as f: + for name in f.keys(): + if name in targets: + cache[(name, side, device)] = f.get_tensor(name).to(dtype=dtype) + + class EkfacApplicator: def __init__(self, cfg: EkfacConfig, apply_fn=None): self.cfg = cfg @@ -53,6 +206,72 @@ def __init__(self, cfg: EkfacConfig, apply_fn=None): self.sharded_computer = ShardedMul() def compute_ivhp_sharded(self): + if self.cfg.projection_dim > 0: + if self.cfg.ev_correction: + raise ValueError( + "K-FAC + random projection (compression) is incompatible " + "with `ev_correction=True`: eigenvalue correction acts on " + "the joint S⊗A spectrum and cannot be folded into a per-side " + "inverse-square-root. 
Use a Kronecker-factored method " + "(kfac, shampoo) without ev_correction." + ) + return self._apply_compressed() + return self._apply_legacy() + + def _apply_compressed(self): + """Apply M_left · G_q · M_rightᵀ to the saved query gradients. + + Assumes step 3.5 (``build_kfac_projections``) has already saved M + under ``hessian_method_path/projection_{side}_sharded/``. Output + shape per layer is [N, p, p]. + """ + p = self.cfg.projection_dim + + M_left: dict[str, Tensor] = {} + M_right: dict[str, Tensor] = {} + for side, store in (("left", M_left), ("right", M_right)): + side_dir = os.path.join(self.path, f"projection_{side}_sharded") + for shard_file in sorted(Path(side_dir).glob("shard_*.safetensors")): + shard = load_file(str(shard_file), device=str(self.device)) + for k, v in shard.items(): + store[k] = v.to(dtype=torch.float32) + + mmap = load_gradients(self.gradient_path) + with open(os.path.join(self.gradient_path, "info.json")) as f: + info = json.load(f) + + grad_sizes = {name: p * p for name in M_left} + grad_buffer = create_index( + Path(self.cfg.run_path), + num_grads=info["num_grads"], + grad_sizes=grad_sizes, + dtype=np.float32, + ) + + self.logger.info( + f"Loaded gradients for {len(mmap)} queries and applying M·G·Mᵀ..." 
+ ) + + for name, M_l in M_left.items(): + M_r = M_right[name] + d_S, d_A = M_l.shape[1], M_r.shape[1] + G = ( + torch.from_numpy(mmap[name][:]) + .to(device=self.device, dtype=torch.float32) + .view(-1, d_S, d_A) + ) + # ĝ_q = M_left · G · M_rightᵀ [N, p, p] + sketched = torch.einsum("ps,nsa,ra->npr", M_l, G, M_r) + grad_buffer[name][:] = ( + sketched.to(device="cpu", non_blocking=True).flatten(1).numpy() + ) + + torch.cuda.synchronize() if torch.cuda.is_available() else None + grad_buffer.flush() + self.logger.info(f"Saved sketched IVHP gradients to {self.cfg.run_path}") + + def _apply_legacy(self): + """Full-rank IVHP via the eigenbasis rotate-divide-rotate path.""" eigen_a = load_file( self.path + f"/eigen_activation_sharded/shard_{self.rank}.safetensors", device=self.device, @@ -76,10 +295,8 @@ def compute_ivhp_sharded(self): eigen_g[k] = eigen_g[k].to(dtype=torch.float32) lambda_factor[k] = v.to(dtype=torch.float32) - p = self.cfg.projection_dim grad_sizes = { - name: p * p if p > 0 else eigen_g[name].shape[1] * eigen_a[name].shape[1] - for name in eigen_a + name: eigen_g[name].shape[1] * eigen_a[name].shape[1] for name in eigen_a } mmap = load_gradients(self.gradient_path) @@ -157,19 +374,7 @@ def compute_ivhp_sharded(self): del eigen_a gc.collect() - if p > 0: - projection_type = self.cfg.projection_type - for k, v in transformed_gradients.items(): - d_S, d_A = v.shape[-2:] - P_l = create_projection_matrix( - f"{k}/left", p, d_S, v.dtype, v.device, projection_type - ) - P_r = create_projection_matrix( - f"{k}/right", p, d_A, v.dtype, v.device, projection_type - ) - transformed_gradients[k] = torch.einsum("ps,nsa,ra->npr", P_l, v, P_r) - - torch.cuda.synchronize() + torch.cuda.synchronize() if torch.cuda.is_available() else None for k, v in transformed_gradients.items(): grad_buffer[k][:] = v.to(device="cpu", non_blocking=True).flatten(1).numpy() @@ -191,6 +396,25 @@ def apply_worker( applicator.compute_ivhp_sharded() +def build_projections_worker( + rank: 
int, # global + local_rank: int, # local + world_size: int, + cfg: EkfacConfig, +): + """Worker for Step 3.5: build ``M = R · cov^{-1/2}`` and save to disk.""" + init_dist(rank, local_rank, world_size) + + build_kfac_projections( + cfg.hessian_method_path, + projection_dim=cfg.projection_dim, + projection_type=cfg.projection_type, + lambda_damp_factor=cfg.lambda_damp_factor, + dtype=torch.float32, + device=get_device(rank), + ) + + if __name__ == "__main__": from bergson.config import DistributedConfig from bergson.distributed import launch_distributed_run diff --git a/bergson/hessians/eigenvectors.py b/bergson/hessians/eigenvectors.py index 6531be5c..63988260 100644 --- a/bergson/hessians/eigenvectors.py +++ b/bergson/hessians/eigenvectors.py @@ -326,16 +326,23 @@ def compute_eigendecomposition( dirname = os.path.dirname(covariance_path) basename = os.path.basename(covariance_path) output_path = os.path.join(dirname, "eigen_" + basename) + eigvals_path = os.path.join(dirname, "eigval_" + basename) os.makedirs(output_path, exist_ok=True) + os.makedirs(eigvals_path, exist_ok=True) save_file( covariance_eigenvectors, os.path.join(output_path, f"shard_{rank}.safetensors"), ) + save_file( + covariance_eigenvalues, + os.path.join(eigvals_path, f"shard_{rank}.safetensors"), + ) gc.collect() get_logger().info(f"Saved eigenvectors to {output_path}") + get_logger().info(f"Saved eigenvalues to {eigvals_path}") return covariance_eigenvalues diff --git a/bergson/hessians/pipeline.py b/bergson/hessians/pipeline.py index f98c64ed..0fa6806e 100644 --- a/bergson/hessians/pipeline.py +++ b/bergson/hessians/pipeline.py @@ -14,7 +14,11 @@ from ..distributed import launch_distributed_run from ..score.score import score_dataset from ..utils.worker_utils import validate_run_path -from .apply_hessian import EkfacConfig, apply_worker +from .apply_hessian import ( + EkfacConfig, + apply_worker, + build_projections_worker, +) from .hessian_approximations import approximate_hessians @@ -51,7 
+55,10 @@ def hessian_pipeline( 1. Build mean query gradient. 2. Fit Hessian factors (kfac, tkfac, shampoo) on the training dataset. - 3. Apply the inverse Hessian to the mean query gradient. + 3. (legacy) Apply the inverse Hessian to the mean query gradient. + OR + 3.5 (compression) Compute R · cov^{-1/2} (precondition+sketch) and + project the query gradient with the resulting M. 4. Score each training example against the transformed query gradient. """ run_path = index_cfg.run_path @@ -62,6 +69,16 @@ def hessian_pipeline( scores_path = f"{run_path}/scores" resume = hessian_pipeline_cfg.resume + if index_cfg.projection_dim > 0 and hessian_cfg.ev_correction: + # Eigenvalue correction acts on the joint S⊗A spectrum and cannot be + # folded into a per-side inverse-square-root, so it is incompatible + # with the compressed M = R · cov^{-1/2} path. Fail before any work. + raise ValueError( + "Compression (projection_dim > 0) is incompatible with " + "HessianConfig.ev_correction=True. Use a Kronecker-factored " + "method (kfac, tkfac, shampoo) without ev_correction." 
+ ) + def _validate(cfg: IndexConfig): if resume and cfg.partial_run_path.exists(): return @@ -93,25 +110,51 @@ def _validate(cfg: IndexConfig): approximate_hessians(hessian_index_cfg, hessian_cfg) - # ── Step 3: Apply inverse Hessian to the mean query gradient ────────── - print(f"Step 3/4: Applying {method} inverse Hessian to mean query gradient...") - if not _step_complete(transformed_query_path, resume): - hessian_method_path = f"{hessian_path}/{method}" - ekfac_cfg = EkfacConfig( - hessian_method_path=hessian_method_path, - gradient_path=query_path, - run_path=transformed_query_path, - ev_correction=hessian_cfg.ev_correction, - lambda_damp_factor=hessian_pipeline_cfg.lambda_damp_factor, - projection_dim=index_cfg.projection_dim, - projection_type=index_cfg.projection_type, - ) - launch_distributed_run( - "apply_hessian", - apply_worker, - [ekfac_cfg], - index_cfg.distributed, + hessian_method_path = f"{hessian_path}/{method}" + projections_path = f"{hessian_method_path}/projection_left_sharded" + ekfac_cfg = EkfacConfig( + hessian_method_path=hessian_method_path, + gradient_path=query_path, + run_path=transformed_query_path, + ev_correction=hessian_cfg.ev_correction, + lambda_damp_factor=hessian_pipeline_cfg.lambda_damp_factor, + projection_dim=index_cfg.projection_dim, + projection_type=index_cfg.projection_type, + ) + + if index_cfg.projection_dim > 0: + # ── Step 3.5 (compression): Compute R · cov^{-1/2} ──────────────── + # Build M = R · cov^{-1/2} per (layer, side) and save to disk; then + # project the query gradient with the same M. The score step's + # gradient collector loads M so train and query share one sketch. 
+ print("Step 3.5/4: Computing R · cov^{-1/2} (precondition+sketch)...") + if not _step_complete(projections_path, resume): + with _timed("step3.5_build_projections", durations): + launch_distributed_run( + "build_projections", + build_projections_worker, + [ekfac_cfg], + index_cfg.distributed, + ) + if not _step_complete(transformed_query_path, resume): + launch_distributed_run( + "apply_hessian", + apply_worker, + [ekfac_cfg], + index_cfg.distributed, + ) + else: + # ── Step 3 (legacy): Apply inverse Hessian via rotate-divide-rotate + print( + f"Step 3/4: Applying {method} inverse Hessian to mean query " "gradient..." ) + if not _step_complete(transformed_query_path, resume): + launch_distributed_run( + "apply_hessian", + apply_worker, + [ekfac_cfg], + index_cfg.distributed, + ) # ── Step 4: Score training examples ─────────────────────────────────── print("Step 4/4: Scoring training data against transformed query...") @@ -119,6 +162,12 @@ def _validate(cfg: IndexConfig): score_index_cfg = deepcopy(index_cfg) score_index_cfg.run_path = scores_path score_index_cfg.skip_hessians = True + # When compression is on, step 3.5 has saved M = R · cov^{-1/2} + # under the hessian directory. Point the training-side gradient + # collector at it so it projects with M (not a fresh random R) — + # query and training gradients then share the same sketch. 
+ if index_cfg.projection_dim > 0: + score_index_cfg.kfac_projection_path = hessian_method_path score_cfg.query_path = transformed_query_path score_cfg.higher_is_better = True _validate(score_index_cfg) From 0fcd1d4b620b934d172c24a568749f2bf14bfdbb Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Mon, 25 May 2026 15:07:25 -0700 Subject: [PATCH 2/8] Deletes unnecessary comments --- bergson/collector/gradient_collectors.py | 5 ----- bergson/hessians/apply_hessian.py | 17 +---------------- bergson/hessians/pipeline.py | 11 ----------- 3 files changed, 1 insertion(+), 32 deletions(-) diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index 557eecbc..20d0aed1 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -72,11 +72,6 @@ def setup(self) -> None: self.lo = torch.finfo(self.save_dtype).min self.hi = torch.finfo(self.save_dtype).max - # When a K-FAC compressed-projection path is provided, pre-populate - # the per-layer projection cache with ``M = R · cov^{-1/2}``. The - # backward hook's ``self.projection(...)`` call then hits the cache - # and uses M in place of a freshly sampled random matrix — so the - # gradient comes out preconditioned and sketched in one step. if self.cfg.kfac_projection_path: from bergson.hessians.apply_hessian import load_kfac_projections diff --git a/bergson/hessians/apply_hessian.py b/bergson/hessians/apply_hessian.py index 8acd5706..92fe8c85 100644 --- a/bergson/hessians/apply_hessian.py +++ b/bergson/hessians/apply_hessian.py @@ -40,10 +40,6 @@ class EkfacConfig: projection_type: Literal["normal", "rademacher"] = "rademacher" -# Map projection "side" (used by the gradient collector) to the K-FAC -# covariance whose inverse-square-root we fold into the random projection. 
-# - "left" → applied to the output-gradient g, so it consumes S (gradient cov) -# - "right" → applied to the input activation a, so it consumes A (activation cov) SIDE_TO_COV = {"left": "gradient", "right": "activation"} @@ -88,10 +84,7 @@ def build_kfac_projections( side_dims[side][name] = f.get_tensor(name).shape[-1] names = list(side_dims["left"].keys()) - # fair_distribute_by_cost treats its values as matrix dimensions and - # ranks by d**3. Pass the per-layer max of the two sides — that bounds - # the dominant matmul cost (M = R · Q · diag(D) · Qᵀ, O(d**3) on the - # larger side). + per_layer_dim = {n: max(side_dims["left"][n], side_dims["right"][n]) for n in names} my_names = fair_distribute_by_cost(per_layer_dim, world_size)[rank] @@ -123,14 +116,9 @@ def build_kfac_projections( world_size=world_size, ).to(device=device, dtype=torch.float32) - # Per-side adaptive damping, mirroring ShardedMul._sharded_hadamard - # but applied to each factor independently. damp = lambda_damp_factor * E.mean() D = (E + damp).clamp_min(torch.finfo(torch.float32).tiny).rsqrt() # [d] - # Same identifier convention as HookCollectorBase.projection() so - # the precompute matches the projection the collector would have - # sampled for this layer/side. R = create_projection_matrix( f"{name}/{side}", projection_dim, @@ -140,9 +128,6 @@ def build_kfac_projections( projection_type, ) - # M = R · Q · diag(D) · Qᵀ [p, d] - # Q has eigenvectors as columns, so Q·diag(D) scales column k by D[k]; - # that is element-wise (Q * D) broadcasting D across rows. 
M = R @ (Q * D) @ Q.T saved[side][name] = M.to(dtype=dtype).cpu().contiguous() diff --git a/bergson/hessians/pipeline.py b/bergson/hessians/pipeline.py index 0fa6806e..80286d49 100644 --- a/bergson/hessians/pipeline.py +++ b/bergson/hessians/pipeline.py @@ -70,9 +70,6 @@ def hessian_pipeline( resume = hessian_pipeline_cfg.resume if index_cfg.projection_dim > 0 and hessian_cfg.ev_correction: - # Eigenvalue correction acts on the joint S⊗A spectrum and cannot be - # folded into a per-side inverse-square-root, so it is incompatible - # with the compressed M = R · cov^{-1/2} path. Fail before any work. raise ValueError( "Compression (projection_dim > 0) is incompatible with " "HessianConfig.ev_correction=True. Use a Kronecker-factored " @@ -123,10 +120,6 @@ def _validate(cfg: IndexConfig): ) if index_cfg.projection_dim > 0: - # ── Step 3.5 (compression): Compute R · cov^{-1/2} ──────────────── - # Build M = R · cov^{-1/2} per (layer, side) and save to disk; then - # project the query gradient with the same M. The score step's - # gradient collector loads M so train and query share one sketch. print("Step 3.5/4: Computing R · cov^{-1/2} (precondition+sketch)...") if not _step_complete(projections_path, resume): with _timed("step3.5_build_projections", durations): @@ -162,10 +155,6 @@ def _validate(cfg: IndexConfig): score_index_cfg = deepcopy(index_cfg) score_index_cfg.run_path = scores_path score_index_cfg.skip_hessians = True - # When compression is on, step 3.5 has saved M = R · cov^{-1/2} - # under the hessian directory. Point the training-side gradient - # collector at it so it projects with M (not a fresh random R) — - # query and training gradients then share the same sketch. 
if index_cfg.projection_dim > 0: score_index_cfg.kfac_projection_path = hessian_method_path score_cfg.query_path = transformed_query_path From ebd6169e3bd9df03b21e0beafbf2e229c943582a Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Mon, 25 May 2026 15:25:10 -0700 Subject: [PATCH 3/8] Removes unnecessary comment --- bergson/hessians/apply_hessian.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/bergson/hessians/apply_hessian.py b/bergson/hessians/apply_hessian.py index 92fe8c85..8165200b 100644 --- a/bergson/hessians/apply_hessian.py +++ b/bergson/hessians/apply_hessian.py @@ -69,9 +69,6 @@ def build_kfac_projections( rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 - # Discover layer names and per-side dimensions from one shard of each - # eigenvector file. The eigendecomp stores Q with shape [d/W, d] per - # shard, so the column dim is the full layer dimension. side_dims: dict[str, dict[str, int]] = {"left": {}, "right": {}} for side, cov in SIDE_TO_COV.items(): with safe_open( From d301985abdf4eab023955c90807ec6a71874d114 Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Mon, 25 May 2026 22:42:37 +0000 Subject: [PATCH 4/8] Fix P-SIFT projection scaling: unbias the per-side sketch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit create_projection_matrix returns unit-norm rows, so E[RᵀR] = (p/d)·I rather than the identity an unbiased JL sketch requires. For the single global projection this p/d factor is a harmless global constant, but P-SIFT applies the projection per-side and per-layer with differing d, so the score picked up a per-layer p²/(d_left·d_right) reweighting that varies ~4x across gpt2 layer shapes and corrupted ranking vs the exact projection_dim=0 path. Rescale R by √(d/p) in build_kfac_projections to restore E[RᵀR] = I. 
Verified by projection_dim sweep against the exact (legacy) scores: Spearman ρ vs legacy went from a declining 0.60/0.23/0.30/0.13 (bug) to a climbing 0.64/0.85/0.985/0.994 for p = 16/32/64/128, confirming the sketch now converges to H⁻¹ as p grows. Co-Authored-By: Claude Opus 4.7 (1M context) --- bergson/hessians/apply_hessian.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bergson/hessians/apply_hessian.py b/bergson/hessians/apply_hessian.py index 8165200b..811ac83e 100644 --- a/bergson/hessians/apply_hessian.py +++ b/bergson/hessians/apply_hessian.py @@ -124,6 +124,12 @@ def build_kfac_projections( device, projection_type, ) + # create_projection_matrix returns unit-norm rows, so E[RᵀR] = + # (p/d)·I. An unbiased sketch needs E[RᵀR] = I, so rescale by + # √(d/p); otherwise each side carries a p/d factor and the score + # picks up a per-layer p²/(d_left·d_right) reweighting that + # corrupts ranking relative to the exact (projection_dim=0) path. + R = R * (d / projection_dim) ** 0.5 M = R @ (Q * D) @ Q.T saved[side][name] = M.to(dtype=dtype).cpu().contiguous() From 901a03520e171c4d992b35f19e38ad9decdb4795 Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Thu, 28 May 2026 21:29:58 -0700 Subject: [PATCH 5/8] Move load_kfac_projections import to module top The deferred import inside GradientCollector.setup() wasn't avoiding a circular import (nothing imports gradient_collectors, and apply_hessian's import closure never reaches it), so hoist it alongside the other bergson imports per the project convention of imports-at-top. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- bergson/collector/gradient_collectors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index 20d0aed1..f3edaf26 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -13,6 +13,7 @@ from bergson.builder import Builder from bergson.collector.collector import HookCollectorBase from bergson.config import IndexConfig, PreprocessConfig +from bergson.hessians.apply_hessian import load_kfac_projections from bergson.process_autocorrelation import process_autocorrelation_matrices from bergson.score.scorer import Scorer from bergson.utils.projection import make_global_projector @@ -73,8 +74,6 @@ def setup(self) -> None: self.hi = torch.finfo(self.save_dtype).max if self.cfg.kfac_projection_path: - from bergson.hessians.apply_hessian import load_kfac_projections - load_kfac_projections( self.cfg.kfac_projection_path, self.processor._projection_matrices, From 0177dc6078ca531562aaeebee6b0e4497ddf1f0f Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Thu, 28 May 2026 21:31:32 -0700 Subject: [PATCH 6/8] Refactor --- bergson/hessians/pipeline.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/bergson/hessians/pipeline.py b/bergson/hessians/pipeline.py index 80286d49..0f56f5ee 100644 --- a/bergson/hessians/pipeline.py +++ b/bergson/hessians/pipeline.py @@ -120,6 +120,7 @@ def _validate(cfg: IndexConfig): ) if index_cfg.projection_dim > 0: + # ── Step 3.5 (compression): build M = R · cov^{-1/2} before applying ── print("Step 3.5/4: Computing R · cov^{-1/2} (precondition+sketch)...") if not _step_complete(projections_path, resume): with _timed("step3.5_build_projections", durations): @@ -129,25 +130,19 @@ def _validate(cfg: IndexConfig): [ekfac_cfg], index_cfg.distributed, ) - if not _step_complete(transformed_query_path, resume): - 
launch_distributed_run( - "apply_hessian", - apply_worker, - [ekfac_cfg], - index_cfg.distributed, - ) else: - # ── Step 3 (legacy): Apply inverse Hessian via rotate-divide-rotate + # ── Step 3 (legacy): Apply inverse Hessian via rotate-divide-rotate ── print( f"Step 3/4: Applying {method} inverse Hessian to mean query " "gradient..." ) - if not _step_complete(transformed_query_path, resume): - launch_distributed_run( - "apply_hessian", - apply_worker, - [ekfac_cfg], - index_cfg.distributed, - ) + + if not _step_complete(transformed_query_path, resume): + launch_distributed_run( + "apply_hessian", + apply_worker, + [ekfac_cfg], + index_cfg.distributed, + ) # ── Step 4: Score training examples ─────────────────────────────────── print("Step 4/4: Scoring training data against transformed query...") From dd86adbec27f68f31ddfc173cabdf76fe6cdb2c5 Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Thu, 28 May 2026 21:31:42 -0700 Subject: [PATCH 7/8] Repeated line --- bergson/hessians/eigenvectors.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/bergson/hessians/eigenvectors.py b/bergson/hessians/eigenvectors.py index 63988260..b420acc5 100644 --- a/bergson/hessians/eigenvectors.py +++ b/bergson/hessians/eigenvectors.py @@ -301,9 +301,6 @@ def compute_eigendecomposition( covariance_eigenvalues[key] = ( eigenvalues.to(original_dtype).to(device="cpu").contiguous() ) - covariance_eigenvalues[key] = ( - eigenvalues.to(original_dtype).to(device="cpu").contiguous() - ) covariance_eigenvectors = _gather_and_shard_along_dim_0( input_dict=covariance_eigenvectors, From 208616f48770d1ed5b5cd9468b63da17ca74e196 Mon Sep 17 00:00:00 2001 From: Girish Gupta Date: Thu, 28 May 2026 21:32:33 -0700 Subject: [PATCH 8/8] Adds ValueError --- bergson/hessians/apply_hessian.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/bergson/hessians/apply_hessian.py b/bergson/hessians/apply_hessian.py index 811ac83e..458ba0ff 100644 --- a/bergson/hessians/apply_hessian.py 
+++ b/bergson/hessians/apply_hessian.py @@ -66,6 +66,12 @@ def build_kfac_projections( Layers are distributed across ranks via the same fair-by-cost scheme as the eigendecomposition; each rank writes its own shard. """ + if projection_dim <= 0: + raise ValueError( + f"build_kfac_projections requires projection_dim > 0 (compression); " + f"got {projection_dim}." + ) + rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 @@ -201,7 +207,7 @@ def compute_ivhp_sharded(self): "with `ev_correction=True`: eigenvalue correction acts on " "the joint S⊗A spectrum and cannot be folded into a per-side " "inverse-square-root. Use a Kronecker-factored method " - "(kfac, shampoo) without ev_correction." + "(kfac, tkfac, shampoo) without ev_correction." ) return self._apply_compressed() return self._apply_legacy()