-
Notifications
You must be signed in to change notification settings - Fork 22
P-SIFT: precondition-and-sketch via modified projections (Louis's review of #275) #288
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feat/first-class-kfac-build-support
Are you sure you want to change the base?
Changes from all commits
18655af
0fcd1d4
ebd6169
d301985
901a035
0177dc6
dd86adb
208616f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,13 +8,18 @@ | |
| import numpy as np | ||
| import torch | ||
| import torch.distributed as dist | ||
| from safetensors.torch import load_file | ||
| from safetensors import safe_open | ||
| from safetensors.torch import load_file, save_file | ||
| from simple_parsing import ArgumentParser | ||
| from torch import Tensor | ||
|
|
||
| from bergson.collector.collector import create_projection_matrix | ||
| from bergson.data import create_index, load_gradients | ||
| from bergson.distributed import init_dist | ||
| from bergson.hessians.eigenvectors import ( | ||
| _compute_full_matrix, | ||
| fair_distribute_by_cost, | ||
| ) | ||
| from bergson.hessians.sharded_computation import ShardedMul | ||
| from bergson.utils.logger import get_logger | ||
| from bergson.utils.utils import get_device | ||
|
|
@@ -35,6 +40,148 @@ class EkfacConfig: | |
| projection_type: Literal["normal", "rademacher"] = "rademacher" | ||
|
|
||
|
|
||
| SIDE_TO_COV = {"left": "gradient", "right": "activation"} | ||
|
|
||
|
|
||
def build_kfac_projections(
    hessian_method_path: str,
    projection_dim: int,
    projection_type: Literal["normal", "rademacher"],
    lambda_damp_factor: float,
    dtype: torch.dtype,
    device: torch.device,
) -> None:
    """Build and save the precondition+sketch matrices ``M = R · cov^{-1/2}``.

    For each module and each side, this composes the random projection ``R``
    with the damped inverse square-root of the K-FAC factor:

        M = R · Q · diag((E + λ·mean(E))^{-1/2}) · Qᵀ        [p, d]

    The result is saved to ``hessian_method_path/projection_{side}_sharded/``.
    Loaded later by the gradient collector — when present, gradients come
    out of collection already preconditioned and sketched in one matmul, so
    the apply step and the score step both share the same projection.

    Layers are distributed across ranks with ``fair_distribute_by_cost``;
    each rank writes its own ``shard_{rank}.safetensors`` per side.

    Args:
        hessian_method_path: Directory holding the ``eigen_{cov}_sharded`` and
            ``eigval_{cov}_sharded`` inputs; the outputs are written below it.
        projection_dim: Sketch dimension ``p``; must be positive.
        projection_type: Distribution passed to ``create_projection_matrix``.
        lambda_damp_factor: ``λ``; the damping added to the eigenvalues is
            ``λ · mean(E)``.
        dtype: Dtype the saved ``M`` matrices are cast to.
        device: Device on which the matrices are assembled (in float32).

    Raises:
        ValueError: If ``projection_dim`` is not positive.
    """
    if projection_dim <= 0:
        raise ValueError(
            f"build_kfac_projections requires projection_dim > 0 (compression); "
            f"got {projection_dim}."
        )

    # NOTE(review): reviewers asked that the *global* rank / world size be used
    # here (important for multi-node runs). Called without a ``group``
    # argument, these refer to the default process group, i.e. global values.
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # Read only the tensor shapes from the safetensors header; the eigenvector
    # matrices themselves are loaded later, and only for this rank's layers.
    # Presumably the last axis of every tensor in shard 0 is the full factor
    # dimension d — TODO confirm against the eigendecomposition writer.
    side_dims: dict[str, dict[str, int]] = {"left": {}, "right": {}}
    for side, cov in SIDE_TO_COV.items():
        with safe_open(
            os.path.join(
                hessian_method_path, f"eigen_{cov}_sharded/shard_0.safetensors"
            ),
            framework="pt",
        ) as f:
            for name in f.keys():
                side_dims[side][name] = f.get_slice(name).get_shape()[-1]

    names = list(side_dims["left"].keys())

    # NOTE(review): a reviewer pointed out that balancing on
    # max(d_left, d_right) is not a faithful cost. Each side needs a [d, d]
    # load plus two [p, d] x [d, d] matmuls, so the work for one layer scales
    # with d_left² + d_right².
    per_layer_cost = {
        n: side_dims["left"][n] ** 2 + side_dims["right"][n] ** 2 for n in names
    }
    my_names = fair_distribute_by_cost(per_layer_cost, world_size)[rank]

    out_dirs = {
        side: os.path.join(hessian_method_path, f"projection_{side}_sharded")
        for side in ("left", "right")
    }
    if rank == 0:
        for d in out_dirs.values():
            os.makedirs(d, exist_ok=True)

    # Every rank waits until rank 0 has created the output directories.
    if dist.is_initialized():
        dist.barrier()

    # Floor for the damped eigenvalues so rsqrt never sees zero or a negative.
    eig_floor = torch.finfo(torch.float32).tiny

    saved: dict[str, dict[str, Tensor]] = {"left": {}, "right": {}}
    for name in my_names:
        for side, cov in SIDE_TO_COV.items():
            d = side_dims[side][name]
            # NOTE(review): these calls run only for this rank's layers, which
            # assumes _compute_full_matrix reads the shards from disk and does
            # not issue collectives — confirm.
            Q = _compute_full_matrix(
                name=name,
                shard_path=os.path.join(hessian_method_path, f"eigen_{cov}_sharded"),
                rank=rank,
                world_size=world_size,
            ).to(device=device, dtype=torch.float32)
            E = _compute_full_matrix(
                name=name,
                shard_path=os.path.join(hessian_method_path, f"eigval_{cov}_sharded"),
                rank=rank,
                world_size=world_size,
            ).to(device=device, dtype=torch.float32)

            damp = lambda_damp_factor * E.mean()
            D = (E + damp).clamp_min(eig_floor).rsqrt()  # [d]

            R = create_projection_matrix(
                f"{name}/{side}",
                projection_dim,
                d,
                torch.float32,
                device,
                projection_type,
            )
            # create_projection_matrix returns unit-norm rows, so E[RᵀR] =
            # (p/d)·I. An unbiased sketch needs E[RᵀR] = I, so rescale by
            # √(d/p); otherwise each side carries a p/d factor and the score
            # picks up a per-layer p²/(d_left·d_right) reweighting that
            # corrupts ranking relative to the exact (projection_dim=0) path.
            # NOTE(review): a reviewer questioned how relevant this is since
            # only relative rankings are used; the thread concluded that the
            # rescaling can stay.
            R = R * (d / projection_dim) ** 0.5

            # (Q * D) scales the columns of Q, i.e. Q · diag(D).
            M = R @ (Q * D) @ Q.T
            saved[side][name] = M.to(dtype=dtype).cpu().contiguous()

    for side in ("left", "right"):
        save_file(
            saved[side],
            os.path.join(out_dirs[side], f"shard_{rank}.safetensors"),
        )

    get_logger().info(
        f"Saved M_left/M_right to {out_dirs['left']} and {out_dirs['right']}"
    )

    # Do not return until every rank has written its shard.
    if dist.is_initialized():
        dist.barrier()
|
|
||
|
|
||
| def load_kfac_projections( | ||
| hessian_method_path: str | os.PathLike, | ||
| cache: dict, | ||
| target_names, | ||
| device: torch.device, | ||
| dtype: torch.dtype, | ||
| ) -> None: | ||
| """Populate ``cache`` with the per-layer M matrices saved by | ||
| ``build_kfac_projections``. ``cache`` is the projection cache used by | ||
| ``HookCollectorBase.projection()`` (keyed by ``(name, side, device)``); | ||
| a cache hit short-circuits random projection at gradient collection. | ||
|
|
||
| ``dtype`` matches what the collector will request at hook time (the | ||
| model's gradient dtype). Saved M lives in fp32 on disk for precision | ||
| and is cast on the way in. | ||
| """ | ||
| targets = set(target_names) | ||
| for side in ("left", "right"): | ||
| side_dir = Path(hessian_method_path) / f"projection_{side}_sharded" | ||
| if not side_dir.exists(): | ||
| return | ||
| for shard_file in sorted(side_dir.glob("shard_*.safetensors")): | ||
| with safe_open(str(shard_file), framework="pt", device=str(device)) as f: | ||
| for name in f.keys(): | ||
| if name in targets: | ||
| cache[(name, side, device)] = f.get_tensor(name).to(dtype=dtype) | ||
|
|
||
|
|
||
| class EkfacApplicator: | ||
| def __init__(self, cfg: EkfacConfig, apply_fn=None): | ||
| self.cfg = cfg | ||
|
|
@@ -53,6 +200,72 @@ def __init__(self, cfg: EkfacConfig, apply_fn=None): | |
| self.sharded_computer = ShardedMul() | ||
|
|
||
def compute_ivhp_sharded(self):
    """Run the IVHP step, choosing the code path from ``cfg.projection_dim``.

    A positive ``projection_dim`` selects the compressed path
    (``_apply_compressed``); anything else selects the full-rank path
    (``_apply_legacy``). The result of the chosen method is returned.

    Raises:
        ValueError: If compression is requested together with
            ``cfg.ev_correction``.
    """
    use_compression = self.cfg.projection_dim > 0
    if not use_compression:
        return self._apply_legacy()

    # Eigenvalue correction cannot be combined with the per-side sketch.
    if self.cfg.ev_correction:
        message = (
            "K-FAC + random projection (compression) is incompatible "
            "with `ev_correction=True`: eigenvalue correction acts on "
            "the joint S⊗A spectrum and cannot be folded into a per-side "
            "inverse-square-root. Use a Kronecker-factored method "
            "(kfac, tkfac, shampoo) without ev_correction."
        )
        raise ValueError(message)

    return self._apply_compressed()
|
|
||
def _apply_compressed(self):
    """Apply M_left · G_q · M_rightᵀ to the saved query gradients.

    Assumes step 3.5 (``build_kfac_projections``) has already saved M
    under ``hessian_method_path/projection_{side}_sharded/``. Output
    shape per layer is [N, p, p], written flattened to [N, p·p] into a new
    float32 index created under ``cfg.run_path``.

    Raises:
        FileNotFoundError: If no projection shards are found under
            ``self.path``.
    """
    p = self.cfg.projection_dim

    # Gather every layer's M from all shard files, one dict per side.
    M_left: dict[str, Tensor] = {}
    M_right: dict[str, Tensor] = {}
    for side, store in (("left", M_left), ("right", M_right)):
        side_dir = os.path.join(self.path, f"projection_{side}_sharded")
        for shard_file in sorted(Path(side_dir).glob("shard_*.safetensors")):
            shard = load_file(str(shard_file), device=str(self.device))
            for k, v in shard.items():
                store[k] = v.to(dtype=torch.float32)

    # Fail fast instead of silently writing an index with no layers in it.
    if not M_left:
        raise FileNotFoundError(
            f"No projection shards found under {self.path}; "
            "run build_kfac_projections before applying."
        )

    mmap = load_gradients(self.gradient_path)
    with open(os.path.join(self.gradient_path, "info.json")) as f:
        info = json.load(f)

    grad_sizes = {name: p * p for name in M_left}
    grad_buffer = create_index(
        Path(self.cfg.run_path),
        num_grads=info["num_grads"],
        grad_sizes=grad_sizes,
        dtype=np.float32,
    )

    self.logger.info(
        f"Loaded gradients for {len(mmap)} queries and applying M·G·Mᵀ..."
    )

    for name, M_l in M_left.items():
        M_r = M_right[name]
        d_S, d_A = M_l.shape[1], M_r.shape[1]
        G = (
            torch.from_numpy(mmap[name][:])
            .to(device=self.device, dtype=torch.float32)
            .view(-1, d_S, d_A)
        )
        # ĝ_q = M_left · G · M_rightᵀ   [N, p, p]
        sketched = torch.einsum("ps,nsa,ra->npr", M_l, G, M_r)
        # Blocking device-to-host copy: the host reads the result right away
        # (``.numpy()`` plus the write into the index), so a non_blocking
        # transfer could be consumed before it has finished.
        grad_buffer[name][:] = sketched.flatten(1).cpu().numpy()

    # No explicit CUDA synchronize is needed: every copy above is blocking.
    grad_buffer.flush()
    self.logger.info(f"Saved sketched IVHP gradients to {self.cfg.run_path}")
|
|
||
| def _apply_legacy(self): | ||
|
Contributor
Review comment: why is this code path called "legacy"? |
||
| """Full-rank IVHP via the eigenbasis rotate-divide-rotate path.""" | ||
| eigen_a = load_file( | ||
| self.path + f"/eigen_activation_sharded/shard_{self.rank}.safetensors", | ||
| device=self.device, | ||
|
|
@@ -76,10 +289,8 @@ def compute_ivhp_sharded(self): | |
| eigen_g[k] = eigen_g[k].to(dtype=torch.float32) | ||
| lambda_factor[k] = v.to(dtype=torch.float32) | ||
|
|
||
| p = self.cfg.projection_dim | ||
| grad_sizes = { | ||
| name: p * p if p > 0 else eigen_g[name].shape[1] * eigen_a[name].shape[1] | ||
| for name in eigen_a | ||
| name: eigen_g[name].shape[1] * eigen_a[name].shape[1] for name in eigen_a | ||
| } | ||
|
|
||
| mmap = load_gradients(self.gradient_path) | ||
|
|
@@ -157,19 +368,7 @@ def compute_ivhp_sharded(self): | |
| del eigen_a | ||
| gc.collect() | ||
|
|
||
| if p > 0: | ||
| projection_type = self.cfg.projection_type | ||
| for k, v in transformed_gradients.items(): | ||
| d_S, d_A = v.shape[-2:] | ||
| P_l = create_projection_matrix( | ||
| f"{k}/left", p, d_S, v.dtype, v.device, projection_type | ||
| ) | ||
| P_r = create_projection_matrix( | ||
| f"{k}/right", p, d_A, v.dtype, v.device, projection_type | ||
| ) | ||
| transformed_gradients[k] = torch.einsum("ps,nsa,ra->npr", P_l, v, P_r) | ||
|
|
||
| torch.cuda.synchronize() | ||
| torch.cuda.synchronize() if torch.cuda.is_available() else None | ||
| for k, v in transformed_gradients.items(): | ||
| grad_buffer[k][:] = v.to(device="cpu", non_blocking=True).flatten(1).numpy() | ||
|
|
||
|
|
@@ -191,6 +390,25 @@ def apply_worker( | |
| applicator.compute_ivhp_sharded() | ||
|
|
||
|
|
||
def build_projections_worker(
    rank: int,  # global
    local_rank: int,  # local
    world_size: int,
    cfg: EkfacConfig,
):
    """Worker for Step 3.5: build ``M = R · cov^{-1/2}`` and save to disk.

    Initialises the distributed context for this process, then calls
    ``build_kfac_projections`` with the settings from ``cfg``. The saved
    matrices are always requested in float32.
    """
    init_dist(rank, local_rank, world_size)

    # NOTE(review): the device is selected from the *global* rank while
    # ``local_rank`` is only passed to init_dist. If get_device maps its
    # argument directly to a CUDA ordinal, this would pick a non-existent
    # device on multi-node runs — confirm against get_device.
    target_device = get_device(rank)

    build_kfac_projections(
        hessian_method_path=cfg.hessian_method_path,
        projection_dim=cfg.projection_dim,
        projection_type=cfg.projection_type,
        lambda_damp_factor=cfg.lambda_damp_factor,
        dtype=torch.float32,
        device=target_device,
    )
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from bergson.config import DistributedConfig | ||
| from bergson.distributed import launch_distributed_run | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure we should even make this configurable; I can't imagine a situation where we would need to change it.