Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions bergson/collector/gradient_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from bergson.builder import Builder
from bergson.collector.collector import HookCollectorBase
from bergson.config import IndexConfig, PreprocessConfig
from bergson.hessians.apply_hessian import load_kfac_projections
from bergson.process_autocorrelation import process_autocorrelation_matrices
from bergson.score.scorer import Scorer
from bergson.utils.projection import make_global_projector
Expand Down Expand Up @@ -72,6 +73,15 @@ def setup(self) -> None:
self.lo = torch.finfo(self.save_dtype).min
self.hi = torch.finfo(self.save_dtype).max

if self.cfg.kfac_projection_path:
load_kfac_projections(
self.cfg.kfac_projection_path,
self.processor._projection_matrices,
self.target_info.keys(),
device=self.model.device,
dtype=self.save_dtype,
)

self.per_doc_losses = torch.full(
(len(self.data),),
device=self.model.device,
Expand Down
8 changes: 8 additions & 0 deletions bergson/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,14 @@ class IndexConfig(AttributionConfig, Serializable):
processor_path: str = ""
"""Path to a precomputed processor."""

kfac_projection_path: str = ""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure we should even make this configurable? Can't imagine a situation where we would need to change it

"""Path to a K-FAC hessian directory containing precomputed modified
projection matrices (``projection_left_sharded/``,
``projection_right_sharded/``). When set, the gradient collector loads
these instead of sampling fresh random projections — the saved matrices
are ``M = R · cov^{-1/2}``, so applying them to gradients yields
preconditioned-and-sketched gradients in a single matmul (P-SIFT)."""

optimizer_state: str = ""
"""Source for optimizer second moments used to normalize gradients.
Either a local path (a checkpoint directory containing ``optimizer.pt``,
Expand Down
252 changes: 235 additions & 17 deletions bergson/hessians/apply_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
import numpy as np
import torch
import torch.distributed as dist
from safetensors.torch import load_file
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from simple_parsing import ArgumentParser
from torch import Tensor

from bergson.collector.collector import create_projection_matrix
from bergson.data import create_index, load_gradients
from bergson.distributed import init_dist
from bergson.hessians.eigenvectors import (
_compute_full_matrix,
fair_distribute_by_cost,
)
from bergson.hessians.sharded_computation import ShardedMul
from bergson.utils.logger import get_logger
from bergson.utils.utils import get_device
Expand All @@ -35,6 +40,148 @@ class EkfacConfig:
projection_type: Literal["normal", "rademacher"] = "rademacher"


SIDE_TO_COV = {"left": "gradient", "right": "activation"}


def build_kfac_projections(
hessian_method_path: str,
projection_dim: int,
projection_type: Literal["normal", "rademacher"],
lambda_damp_factor: float,
dtype: torch.dtype,
device: torch.device,
) -> None:
"""Build and save the precondition+sketch matrices ``M = R · cov^{-1/2}``.

For each module and each side, this composes the random projection ``R``
with the damped inverse square-root of the K-FAC factor:

M = R · Q · diag((E + λ·mean(E))^{-1/2}) · Qᵀ [p, d]

The result is saved to ``hessian_method_path/projection_{side}_sharded/``.
Loaded later by the gradient collector — when present, gradients come
out of collection already preconditioned and sketched in one matmul, so
the apply step and the score step both share the same projection.

Layers are distributed across ranks via the same fair-by-cost scheme as
the eigendecomposition; each rank writes its own shard.
"""
if projection_dim <= 0:
raise ValueError(
f"build_kfac_projections requires projection_dim > 0 (compression); "
f"got {projection_dim}."
)

rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
Comment on lines +75 to +76

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure you use the correct rank and world size (important for multi node setups)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it seems like they should be the global ones?)


side_dims: dict[str, dict[str, int]] = {"left": {}, "right": {}}
for side, cov in SIDE_TO_COV.items():
with safe_open(
os.path.join(
hessian_method_path, f"eigen_{cov}_sharded/shard_0.safetensors"
),
framework="pt",
) as f:
for name in f.keys():
side_dims[side][name] = f.get_tensor(name).shape[-1]

names = list(side_dims["left"].keys())

per_layer_dim = {n: max(side_dims["left"][n], side_dims["right"][n]) for n in names}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name is somewhat confusing and I am not sure I understand what this is doing. If you want to fairly distribute by cost, then looking at the max is not necessarily correct?

my_names = fair_distribute_by_cost(per_layer_dim, world_size)[rank]

out_dirs = {
side: os.path.join(hessian_method_path, f"projection_{side}_sharded")
for side in ("left", "right")
}
if rank == 0:
for d in out_dirs.values():
os.makedirs(d, exist_ok=True)

if dist.is_initialized():
dist.barrier()

saved: dict[str, dict[str, Tensor]] = {"left": {}, "right": {}}
for name in my_names:
for side, cov in SIDE_TO_COV.items():
d = side_dims[side][name]
Q = _compute_full_matrix(
name=name,
shard_path=os.path.join(hessian_method_path, f"eigen_{cov}_sharded"),
rank=rank,
world_size=world_size,
).to(device=device, dtype=torch.float32)
E = _compute_full_matrix(
name=name,
shard_path=os.path.join(hessian_method_path, f"eigval_{cov}_sharded"),
rank=rank,
world_size=world_size,
).to(device=device, dtype=torch.float32)

damp = lambda_damp_factor * E.mean()
D = (E + damp).clamp_min(torch.finfo(torch.float32).tiny).rsqrt() # [d]

R = create_projection_matrix(
f"{name}/{side}",
projection_dim,
d,
torch.float32,
device,
projection_type,
)
# create_projection_matrix returns unit-norm rows, so E[RᵀR] =
# (p/d)·I. An unbiased sketch needs E[RᵀR] = I, so rescale by
# √(d/p); otherwise each side carries a p/d factor and the score
# picks up a per-layer p²/(d_left·d_right) reweighting that
# corrupts ranking relative to the exact (projection_dim=0) path.
R = R * (d / projection_dim) ** 0.5
Comment on lines +133 to +138

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure how relevant that is, we look at relative rankings only anyway

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we can leave it there


M = R @ (Q * D) @ Q.T
saved[side][name] = M.to(dtype=dtype).cpu().contiguous()

for side in ("left", "right"):
save_file(
saved[side],
os.path.join(out_dirs[side], f"shard_{rank}.safetensors"),
)

get_logger().info(
f"Saved M_left/M_right to {out_dirs['left']} and {out_dirs['right']}"
)

if dist.is_initialized():
dist.barrier()


def load_kfac_projections(
hessian_method_path: str | os.PathLike,
cache: dict,
target_names,
device: torch.device,
dtype: torch.dtype,
) -> None:
"""Populate ``cache`` with the per-layer M matrices saved by
``build_kfac_projections``. ``cache`` is the projection cache used by
``HookCollectorBase.projection()`` (keyed by ``(name, side, device)``);
a cache hit short-circuits random projection at gradient collection.

``dtype`` matches what the collector will request at hook time (the
model's gradient dtype). Saved M lives in fp32 on disk for precision
and is cast on the way in.
"""
targets = set(target_names)
for side in ("left", "right"):
side_dir = Path(hessian_method_path) / f"projection_{side}_sharded"
if not side_dir.exists():
return
for shard_file in sorted(side_dir.glob("shard_*.safetensors")):
with safe_open(str(shard_file), framework="pt", device=str(device)) as f:
for name in f.keys():
if name in targets:
cache[(name, side, device)] = f.get_tensor(name).to(dtype=dtype)


class EkfacApplicator:
def __init__(self, cfg: EkfacConfig, apply_fn=None):
self.cfg = cfg
Expand All @@ -53,6 +200,72 @@ def __init__(self, cfg: EkfacConfig, apply_fn=None):
self.sharded_computer = ShardedMul()

def compute_ivhp_sharded(self):
if self.cfg.projection_dim > 0:
if self.cfg.ev_correction:
raise ValueError(
"K-FAC + random projection (compression) is incompatible "
"with `ev_correction=True`: eigenvalue correction acts on "
"the joint S⊗A spectrum and cannot be folded into a per-side "
"inverse-square-root. Use a Kronecker-factored method "
"(kfac, tkfac, shampoo) without ev_correction."
)
return self._apply_compressed()
return self._apply_legacy()

def _apply_compressed(self):
"""Apply M_left · G_q · M_rightᵀ to the saved query gradients.

Assumes step 3.5 (``build_kfac_projections``) has already saved M
under ``hessian_method_path/projection_{side}_sharded/``. Output
shape per layer is [N, p, p].
"""
p = self.cfg.projection_dim

M_left: dict[str, Tensor] = {}
M_right: dict[str, Tensor] = {}
for side, store in (("left", M_left), ("right", M_right)):
side_dir = os.path.join(self.path, f"projection_{side}_sharded")
for shard_file in sorted(Path(side_dir).glob("shard_*.safetensors")):
shard = load_file(str(shard_file), device=str(self.device))
for k, v in shard.items():
store[k] = v.to(dtype=torch.float32)

mmap = load_gradients(self.gradient_path)
with open(os.path.join(self.gradient_path, "info.json")) as f:
info = json.load(f)

grad_sizes = {name: p * p for name in M_left}
grad_buffer = create_index(
Path(self.cfg.run_path),
num_grads=info["num_grads"],
grad_sizes=grad_sizes,
dtype=np.float32,
)

self.logger.info(
f"Loaded gradients for {len(mmap)} queries and applying M·G·Mᵀ..."
)

for name, M_l in M_left.items():
M_r = M_right[name]
d_S, d_A = M_l.shape[1], M_r.shape[1]
G = (
torch.from_numpy(mmap[name][:])
.to(device=self.device, dtype=torch.float32)
.view(-1, d_S, d_A)
)
# ĝ_q = M_left · G · M_rightᵀ [N, p, p]
sketched = torch.einsum("ps,nsa,ra->npr", M_l, G, M_r)
grad_buffer[name][:] = (
sketched.to(device="cpu", non_blocking=True).flatten(1).numpy()
)

torch.cuda.synchronize() if torch.cuda.is_available() else None
grad_buffer.flush()
self.logger.info(f"Saved sketched IVHP gradients to {self.cfg.run_path}")

def _apply_legacy(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this legacy?

"""Full-rank IVHP via the eigenbasis rotate-divide-rotate path."""
eigen_a = load_file(
self.path + f"/eigen_activation_sharded/shard_{self.rank}.safetensors",
device=self.device,
Expand All @@ -76,10 +289,8 @@ def compute_ivhp_sharded(self):
eigen_g[k] = eigen_g[k].to(dtype=torch.float32)
lambda_factor[k] = v.to(dtype=torch.float32)

p = self.cfg.projection_dim
grad_sizes = {
name: p * p if p > 0 else eigen_g[name].shape[1] * eigen_a[name].shape[1]
for name in eigen_a
name: eigen_g[name].shape[1] * eigen_a[name].shape[1] for name in eigen_a
}

mmap = load_gradients(self.gradient_path)
Expand Down Expand Up @@ -157,19 +368,7 @@ def compute_ivhp_sharded(self):
del eigen_a
gc.collect()

if p > 0:
projection_type = self.cfg.projection_type
for k, v in transformed_gradients.items():
d_S, d_A = v.shape[-2:]
P_l = create_projection_matrix(
f"{k}/left", p, d_S, v.dtype, v.device, projection_type
)
P_r = create_projection_matrix(
f"{k}/right", p, d_A, v.dtype, v.device, projection_type
)
transformed_gradients[k] = torch.einsum("ps,nsa,ra->npr", P_l, v, P_r)

torch.cuda.synchronize()
torch.cuda.synchronize() if torch.cuda.is_available() else None
for k, v in transformed_gradients.items():
grad_buffer[k][:] = v.to(device="cpu", non_blocking=True).flatten(1).numpy()

Expand All @@ -191,6 +390,25 @@ def apply_worker(
applicator.compute_ivhp_sharded()


def build_projections_worker(
rank: int, # global
local_rank: int, # local
world_size: int,
cfg: EkfacConfig,
):
"""Worker for Step 3.5: build ``M = R · cov^{-1/2}`` and save to disk."""
init_dist(rank, local_rank, world_size)

build_kfac_projections(
cfg.hessian_method_path,
projection_dim=cfg.projection_dim,
projection_type=cfg.projection_type,
lambda_damp_factor=cfg.lambda_damp_factor,
dtype=torch.float32,
device=get_device(rank),
)


if __name__ == "__main__":
from bergson.config import DistributedConfig
from bergson.distributed import launch_distributed_run
Expand Down
10 changes: 7 additions & 3 deletions bergson/hessians/eigenvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,6 @@ def compute_eigendecomposition(
covariance_eigenvalues[key] = (
eigenvalues.to(original_dtype).to(device="cpu").contiguous()
)
covariance_eigenvalues[key] = (
eigenvalues.to(original_dtype).to(device="cpu").contiguous()
)

covariance_eigenvectors = _gather_and_shard_along_dim_0(
input_dict=covariance_eigenvectors,
Expand All @@ -326,16 +323,23 @@ def compute_eigendecomposition(
dirname = os.path.dirname(covariance_path)
basename = os.path.basename(covariance_path)
output_path = os.path.join(dirname, "eigen_" + basename)
eigvals_path = os.path.join(dirname, "eigval_" + basename)

os.makedirs(output_path, exist_ok=True)
os.makedirs(eigvals_path, exist_ok=True)
save_file(
covariance_eigenvectors,
os.path.join(output_path, f"shard_{rank}.safetensors"),
)
save_file(
covariance_eigenvalues,
os.path.join(eigvals_path, f"shard_{rank}.safetensors"),
)

gc.collect()

get_logger().info(f"Saved eigenvectors to {output_path}")
get_logger().info(f"Saved eigenvalues to {eigvals_path}")

return covariance_eigenvalues

Expand Down
Loading