diff --git a/README.md b/README.md index 2e64d705..a3d4d45b 100644 --- a/README.md +++ b/README.md @@ -275,13 +275,13 @@ If you found Bergson useful in your research, please cite us: ```bibtex @misc{quirke2026bergsonopensourcelibrary, - title={Bergson: An Open Source Library for Data Attribution}, + title={Bergson: An Open Source Library for Data Attribution}, author={Lucia Quirke and Louis Jaburi and David Johnston and William Z. Li and Gonçalo Paulo and Guillaume Martres and Girish Gupta and Stella Biderman and Nora Belrose}, year={2026}, eprint={2606.11660}, archivePrefix={arXiv}, primaryClass={cs.LG}, - url={https://arxiv.org/abs/2606.11660}, + url={https://arxiv.org/abs/2606.11660}, } ``` diff --git a/bergson/approx_unrolling/pipeline.py b/bergson/approx_unrolling/pipeline.py index 8711dd67..25e1162c 100644 --- a/bergson/approx_unrolling/pipeline.py +++ b/bergson/approx_unrolling/pipeline.py @@ -186,7 +186,7 @@ def approx_unrolling_pipeline( query_cfg.projection_dim = 0 query_preprocess_cfg = PreprocessConfig(aggregation="mean") save_run_config( - Build(query_cfg, query_preprocess_cfg, None), + Build(query_cfg, query_preprocess_cfg), query_cfg.partial_run_path, ) build(query_cfg, query_preprocess_cfg) diff --git a/bergson/build.py b/bergson/build.py index f7f9c087..defee1eb 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -8,7 +8,7 @@ from tqdm.auto import tqdm from bergson.collection import collect_gradients -from bergson.config.config import HessianConfig, IndexConfig, PreprocessConfig +from bergson.config.config import IndexConfig, PreprocessConfig from bergson.data import allocate_batches from bergson.distributed import ( cap_world_size_to_dataset, @@ -35,7 +35,6 @@ def build_worker( world_size: int, index_cfg: IndexConfig, preprocess_cfg: PreprocessConfig, - hessian_cfg: HessianConfig | None, ds: Dataset | IterableDataset, ): """ @@ -75,12 +74,9 @@ def build_worker( ) model, target_modules = setup_model_and_peft(index_cfg) - skip_hessians = 
@dataclass
class Build(Serializable):
    """CLI command that builds a gradient index (index only, no Hessian)."""

    # Settings for the index being built; also supplies the run path.
    index_cfg: IndexConfig

    # Gradient preprocessing settings forwarded unchanged to ``build``.
    preprocess_cfg: PreprocessConfig

    def execute(self):
        """Validate the run path, save this command's config, then build."""
        index_cfg = self.index_cfg

        validate_run_path(index_cfg)
        # The config is written under the partial run path before any
        # gradient work starts.
        save_run_config(self, index_cfg.partial_run_path)
        build(index_cfg, self.preprocess_cfg)
None: +def _limit_split_for_hess(cfg: IndexConfig, stats_sample_size: int | None) -> None: """Limit the data split to stats_sample_size for hessian-only steps.""" # TODO this code is hacky and bad - if cfg.stats_sample_size is not None: + if stats_sample_size is not None: split = cfg.data.split # Append HF slice notation if not already present if "[" not in split: - cfg.data.split = f"{split}[:{cfg.stats_sample_size}]" + cfg.data.split = f"{split}[:{stats_sample_size}]" else: base_split = split.split("[")[0] - cfg.data.split = f"{base_split}[:{cfg.stats_sample_size}]" + cfg.data.split = f"{base_split}[:{stats_sample_size}]" def _step_complete(path: str, resume: bool) -> bool: @@ -49,8 +49,6 @@ def trackstar(index_cfg: IndexConfig, trackstar_cfg: TrackstarConfig): scores_path = f"{run_path}/scores" resume = trackstar_cfg.resume - # Steps 1-2 only compute hessians, so don't preprocess grads. - hess_preprocess_cfg = PreprocessConfig() hess_cfg = HessianConfig(method="autocorrelation") def _validate(cfg: IndexConfig): @@ -64,15 +62,13 @@ def _validate(cfg: IndexConfig): if not _step_complete(value_hess_path, resume): value_hess_cfg = deepcopy(index_cfg) value_hess_cfg.run_path = value_hess_path - value_hess_cfg.skip_index = True - if trackstar_cfg.num_stats_sample_hessian: - _limit_split_for_hess(value_hess_cfg) + _limit_split_for_hess(value_hess_cfg, trackstar_cfg.stats_sample_size) _validate(value_hess_cfg) save_run_config( - Build(value_hess_cfg, hess_preprocess_cfg, hess_cfg), + Hessian(hessian_cfg=hess_cfg, index_cfg=value_hess_cfg), value_hess_cfg.partial_run_path, ) - build(value_hess_cfg, hess_preprocess_cfg, hess_cfg) + approximate_hessians(value_hess_cfg, hess_cfg) # Step 2: Compute hessians on query dataset print("Step 2/5: Computing hessians on query dataset...") @@ -80,15 +76,13 @@ def _validate(cfg: IndexConfig): query_hess_cfg = deepcopy(index_cfg) query_hess_cfg.run_path = query_hess_path query_hess_cfg.data = deepcopy(trackstar_cfg.query) - 
query_hess_cfg.skip_index = True - if trackstar_cfg.num_stats_sample_hessian: - _limit_split_for_hess(query_hess_cfg) + _limit_split_for_hess(query_hess_cfg, trackstar_cfg.stats_sample_size) _validate(query_hess_cfg) save_run_config( - Build(query_hess_cfg, hess_preprocess_cfg, hess_cfg), + Hessian(hessian_cfg=hess_cfg, index_cfg=query_hess_cfg), query_hess_cfg.partial_run_path, ) - build(query_hess_cfg, hess_preprocess_cfg, hess_cfg) + approximate_hessians(query_hess_cfg, hess_cfg) # Step 3: Mix query and value hessians print("Step 3/5: Mixing hessians...") @@ -138,10 +132,10 @@ def _validate(cfg: IndexConfig): _validate(query_cfg) save_run_config( - Build(query_cfg, trackstar_cfg.preprocess_cfg, None), + Build(query_cfg, trackstar_cfg.preprocess_cfg), query_cfg.partial_run_path, ) - build(query_cfg, trackstar_cfg.preprocess_cfg, None) + build(query_cfg, trackstar_cfg.preprocess_cfg) # Step 5: Score value dataset against query using mixed hessian print("Step 5/5: Scoring value dataset...") diff --git a/bergson/collection.py b/bergson/collection.py index 67d03257..0aaf712b 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -14,7 +14,6 @@ def collect_gradients( processor: GradientProcessor, cfg: IndexConfig, *, - skip_hessians: bool = True, batches: list[list[int]] | None = None, target_modules: set[str] | None = None, attention_cfgs: dict[str, AttentionConfig] | None = None, @@ -28,7 +27,6 @@ def collect_gradients( model=model.base_model, # type: ignore cfg=cfg, processor=processor, - skip_hessians=skip_hessians, target_modules=target_modules, data=data, scorer=scorer, diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index 65cca886..223bd84e 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -12,7 +12,6 @@ from bergson.builder import Builder from bergson.collector.collector import HookCollectorBase from bergson.config import IndexConfig, 
PreprocessConfig -from bergson.process_autocorrelation import process_autocorrelation_matrices from bergson.score.scorer import Scorer from bergson.utils.utils import get_gradient_dtype @@ -33,9 +32,6 @@ class GradientCollector(HookCollectorBase): cfg: IndexConfig """Configuration for gradient index.""" - skip_hessians: bool = True - """Whether to skip estimating autocorrelation hessian statistics.""" - mod_grads: dict = field(default_factory=dict) """Temporary storage for gradients during a batch, keyed by module name.""" @@ -48,6 +44,11 @@ class GradientCollector(HookCollectorBase): scorer: Scorer | None = None """Optional scorer for computing scores instead of building an index.""" + skip_index: bool = False + """Collect gradients into ``mod_grads`` without writing an on-disk index + (e.g. batch-size probing or gradient inspection). No effect when a + ``scorer`` is set, since scoring already skips the index.""" + def setup(self) -> None: """ Initialize collector state. @@ -77,7 +78,7 @@ def setup(self) -> None: ) # Compute whether we need to save the index - self.save_index = self.scorer is None and not self.cfg.skip_index + self.save_index = self.scorer is None and not self.skip_index if self.save_index: grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()} @@ -94,20 +95,12 @@ def setup(self) -> None: @HookCollectorBase.split_attention_heads def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]): - """Compute per-sample gradient, accumulate autocorrelation matrix, and store.""" + """Compute the per-sample gradient and store it for the index.""" name: str = module._name # type: ignore[assignment] P = self._compute_gradient(module, g) global_proj = self.processor.projection_target == "global" - # Collect per-module hessians when projection target is per_module - if not self.skip_hessians and not global_proj: - P = P.float() - if name in self.processor.hessians: - self.processor.hessians[name].addmm_(P.mT, P) - else: - 
self.processor.hessians[name] = P.mT @ P - if global_proj: assert self.processor.projection_dim is not None R = self.projection( @@ -159,16 +152,6 @@ def teardown(self): if dist.is_initialized(): dist.reduce(self.per_doc_losses, dst=0) - grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()} - if self.processor.hessians: - process_autocorrelation_matrices( - self.processor, - self.processor.hessians, - len(self.data), - grad_sizes, - self.rank, - ) - if self.builder: self.builder.teardown() diff --git a/bergson/config/config.py b/bergson/config/config.py index fceea112..f5897bc1 100644 --- a/bergson/config/config.py +++ b/bergson/config/config.py @@ -400,13 +400,6 @@ class IndexConfig(AttributionConfig, Serializable): or a path to an optimizer state file directly) or a Hugging Face URI ``hf://[@][/]``.""" - skip_index: bool = False - """Whether to skip building the gradient index.""" - - stats_sample_size: int | None = 10_000 - """Number of examples to use for estimating the autocorrelation Hessian. - This feature is experimental and may be removed.""" - loss_fn: Literal["ce", "kl"] = "ce" """Loss function to use.""" @@ -677,9 +670,10 @@ class TrackstarConfig: index hessians intersect at this component. Typical value is ~1000 out of ~65K total components.""" - num_stats_sample_hessian: bool = True - """Whether to use num_stats_sample items or the full dataset to - compute hessians.""" + stats_sample_size: int | None = 10_000 + """Number of examples to use for estimating the autocorrelation Hessian + in the trackstar pipeline's hessian-fitting steps. 
@dataclass(kw_only=True)
class AutocorrelationCollector(HookCollectorBase):
    """Fit a per-module autocorrelation Hessian approximation.

    For each target module this accumulates the per-example gradient Gram
    ``H = Σ_n vec(g_n)ᵀ vec(g_n)`` on ``self.processor.hessians`` and hands the
    result to :func:`process_autocorrelation_matrices` in ``teardown``.
    """

    data: Dataset
    """The dataset the Hessian is fit on (its length normalizes the Gram)."""

    path: str
    """Directory the fitted GradientProcessor is saved to."""

    def setup(self) -> None:
        """Reject configurations that have no per-module Hessian.

        Raises:
            ValueError: If the processor's ``projection_target`` is
                ``"global"``.
        """
        # Raise explicitly rather than ``assert``: this validates user-supplied
        # configuration and must still fire when Python runs with ``-O``.
        if self.processor.projection_target == "global":
            raise ValueError(
                "Autocorrelation Hessian fitting requires per-module projection; "
                "projection_target='global' sums all modules into a single key and "
                "has no per-module Hessian."
            )

    @HookCollectorBase.split_attention_heads
    def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
        """Accumulate the per-module per-example gradient Gram ``PᵀP``."""
        name: str = module._name  # type: ignore[assignment]
        # Accumulate in float32 regardless of the dtype the gradient came in.
        P = self._compute_gradient(module, g).float()
        if name in self.processor.hessians:
            # In-place ``H += PᵀP`` reuses the existing accumulator.
            self.processor.hessians[name].addmm_(P.mT, P)
        else:
            self.processor.hessians[name] = P.mT @ P

    def process_batch(self, indices: list[int], **kwargs):
        """No per-batch output; the Gram accumulates directly on the processor."""
        return

    def teardown(self):
        """Post-process the accumulated Grams and save the processor.

        Post-processing is delegated to ``process_autocorrelation_matrices``
        and is skipped when nothing was accumulated. Only rank 0 writes the
        processor to ``self.path``.
        """
        grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()}
        if self.processor.hessians:
            process_autocorrelation_matrices(
                self.processor,
                self.processor.hessians,
                len(self.data),
                grad_sizes,
                self.rank,
            )
        if self.rank == 0:
            self.processor.save(Path(self.path))
is a dense per-module gradient Gram so + # it computes in one pass and skips the factored eigendecomposition + if hessian_cfg.method == "autocorrelation": + processor = create_processor(model, index_cfg, target_modules) + collector = AutocorrelationCollector( + model=model.base_model, # type: ignore + data=ds, + path=str(index_cfg.partial_run_path), + processor=processor, + target_modules=target_modules, + attention_cfgs=attention_cfgs, + filter_modules=index_cfg.filter_modules, + ) + computer = CollectorComputer( + model=model, # type: ignore + data=ds, + collector=collector, + batches=batches, + cfg=index_cfg, + ) + computer.run_with_collector_hooks(desc="Approximating autocorrelation Hessian") + return + kwargs = { "model": model, "data": ds, "index_cfg": index_cfg, "hessian_cfg": hessian_cfg, "target_modules": target_modules, - "attention_cfgs": { - module: index_cfg.attention for module in index_cfg.split_attention_modules - }, - "batches": allocate_batches(ds["length"][:], index_cfg.token_batch_size), + "attention_cfgs": attention_cfgs, + "batches": batches, } collect_hessians(**kwargs) diff --git a/bergson/hessians/pipeline.py b/bergson/hessians/pipeline.py index da7e6f8d..86edcffe 100644 --- a/bergson/hessians/pipeline.py +++ b/bergson/hessians/pipeline.py @@ -83,10 +83,10 @@ def _validate(cfg: IndexConfig): query_preprocess_cfg = PreprocessConfig(aggregation="mean") save_run_config( - Build(query_cfg, query_preprocess_cfg, None), + Build(query_cfg, query_preprocess_cfg), query_cfg.partial_run_path, ) - build(query_cfg, query_preprocess_cfg, None) + build(query_cfg, query_preprocess_cfg) # ── Step 2: Fit Hessian factors on training data ────────────────────── print(f"Step 2/4: Fitting {method} factors on training data...") diff --git a/bergson/score/score.py b/bergson/score/score.py index 364eca1b..6e8738e2 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -282,7 +282,6 @@ def score_worker( "cfg": index_cfg, "target_modules": 
target_modules, "attention_cfgs": attention_cfgs, - "skip_hessians": True, } score_dtype = ( diff --git a/bergson/utils/batch_size.py b/bergson/utils/batch_size.py index 58fa48fa..063b1c82 100644 --- a/bergson/utils/batch_size.py +++ b/bergson/utils/batch_size.py @@ -1,6 +1,5 @@ import gc import json -from dataclasses import replace from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional @@ -74,7 +73,6 @@ def maybe_auto_batch_size( processor: GradientProcessor, target_modules: set[str] | None, rank: int = 0, - skip_hessians: bool = True, ) -> None: """Run auto batch size determination if enabled. @@ -97,22 +95,20 @@ def maybe_auto_batch_size( gloo_group = None if rank == 0: - # skip_index=True avoids creating a Builder, whose - # create_index() calls dist.barrier() on the default - # NCCL group — which would deadlock since only rank 0 - # creates this collector. - probe_cfg = replace(cfg, skip_index=True) cfg.token_batch_size = determine_batch_size( root=Path(".cache"), cfg=cfg, model=model, + # skip_index=True avoids creating a Builder, whose create_index() + # calls dist.barrier() on the default NCCL group — which would + # deadlock since only rank 0 creates this collector. collector=GradientCollector( model=model.base_model, - cfg=probe_cfg, + cfg=cfg, processor=processor, - skip_hessians=skip_hessians, target_modules=target_modules, data=ds, # type: ignore + skip_index=True, ), starting_batch_size=cfg.token_batch_size, ) diff --git a/docs/pipeline.rst b/docs/pipeline.rst index 98729eb0..5cd54330 100644 --- a/docs/pipeline.rst +++ b/docs/pipeline.rst @@ -127,8 +127,7 @@ A directory at ``run_path`` containing: --truncation \ --aggregation mean \ --unit_normalize \ - --projection_dim 0 \ - --skip_hessians + --projection_dim 0 .. note:: @@ -196,8 +195,7 @@ A directory at ``run_path`` containing: --query_path runs/my-query \ --score individual \ --unit_normalize \ - --projection_dim 0 \ - --skip_hessians + --projection_dim 0 .. 
_hessian-command: diff --git a/examples/magic/compare/q3_trackstar.yaml b/examples/magic/compare/q3_trackstar.yaml index 5af8d6d1..4dce4a0c 100644 --- a/examples/magic/compare/q3_trackstar.yaml +++ b/examples/magic/compare/q3_trackstar.yaml @@ -12,7 +12,6 @@ steps: run_path: runs/compare/q3_trackstar model: runs/compare/q3_random/hf_model overwrite: true - stats_sample_size: 10000 # GPT-2's max context length is 1024 token_batch_size: 1024 # GPT-2 uses an nn.Linear for its LM head. @@ -28,6 +27,7 @@ steps: truncation: true trackstar_cfg: + stats_sample_size: 10000 query: dataset: Salesforce/wikitext subset: wikitext-2-raw-v1 diff --git a/examples/pipelines/trackstar_wmdp.yaml b/examples/pipelines/trackstar_wmdp.yaml index 4344ec5b..ca7a0cb1 100644 --- a/examples/pipelines/trackstar_wmdp.yaml +++ b/examples/pipelines/trackstar_wmdp.yaml @@ -16,13 +16,13 @@ steps: index_cfg: run_path: runs/trackstar_wmdp model: EleutherAI/pythia-160m - stats_sample_size: 10000 overwrite: true data: dataset: NeelNanda/pile-10k split: train truncation: true trackstar_cfg: + stats_sample_size: 10000 query: dataset: cais/wmdp split: test diff --git a/examples/pipelines/trackstar_wmdp_steps.yaml b/examples/pipelines/trackstar_wmdp_steps.yaml index 6476e9d5..4a9d3b6a 100644 --- a/examples/pipelines/trackstar_wmdp_steps.yaml +++ b/examples/pipelines/trackstar_wmdp_steps.yaml @@ -5,27 +5,24 @@ run_path: runs/trackstar_wmdp_pieces steps: # Step 1/5: Compute autocorrelation hessian on the value (training) dataset. - - build: + - hessian: index_cfg: run_path: runs/trackstar_wmdp_pieces/value_hessian overwrite: true model: EleutherAI/pythia-160m - skip_index: true data: dataset: NeelNanda/pile-10k split: train[:10000] truncation: true hessian_cfg: method: autocorrelation - preprocess_cfg: {} # Step 2/5: Compute autocorrelation hessian on the query dataset. 
- - build: + - hessian: index_cfg: run_path: runs/trackstar_wmdp_pieces/query_hessian overwrite: true model: EleutherAI/pythia-160m - skip_index: true data: dataset: cais/wmdp split: test[:10000] @@ -34,7 +31,6 @@ steps: truncation: true hessian_cfg: method: autocorrelation - preprocess_cfg: {} # Step 3/5: Mix the query and value hessians. - mix: diff --git a/tests/test_attribute_tokens.py b/tests/test_attribute_tokens.py index 2b0c0b1f..32b4ddac 100644 --- a/tests/test_attribute_tokens.py +++ b/tests/test_attribute_tokens.py @@ -361,7 +361,6 @@ def test_token_score_e2e(tmp_path: Path, model, dataset): run_path=str(tmp_path / "run"), token_batch_size=1024, attribute_tokens=True, - skip_index=True, ) collect_gradients( @@ -473,7 +472,6 @@ def _collect_in_memory( token_batch_size=1024, attribute_tokens=attribute_tokens, loss_reduction="sum", - skip_index=True, include_bias=include_bias, ) cfg.partial_run_path.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_build.py b/tests/test_build.py index 6a586554..64c520fc 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,4 +1,3 @@ -import subprocess from pathlib import Path import numpy as np @@ -11,47 +10,6 @@ from bergson.data import load_gradients -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_build_e2e(tmp_path: Path): - result = subprocess.run( - [ - "python", - "-m", - "bergson", - "build", - "test_e2e", - "--model", - "EleutherAI/pythia-14m", - "--dataset", - "NeelNanda/pile-10k", - "--split", - "train[:100]", - "--truncation", - "--projection_dim", - "4", - "--token_batch_size", - "1024", - "--precision", - "bf16", - "--method", - "autocorrelation", - ], - cwd=tmp_path, - capture_output=True, # Add this - text=True, # Add this to get strings instead of bytes - ) - - assert "Error" not in result.stderr, f"Error found in stderr:\n{result.stderr}" - - processor = GradientProcessor.load(tmp_path / "test_e2e") - - assert processor.hessians is not None - 
assert processor.hessians_eigen is not None - - assert len(processor.hessians) > 0 - assert len(processor.hessians_eigen) > 0 - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_build_consistency(tmp_path: Path, model, dataset): model = model.float() diff --git a/tests/test_config_runner.py b/tests/test_config_runner.py index ccc8ffd3..33aef4f5 100644 --- a/tests/test_config_runner.py +++ b/tests/test_config_runner.py @@ -186,8 +186,8 @@ def test_omitted_fields_use_dataclass_defaults(tmp_path, registry): assert cmd.hessian_cfg.ev_correction == HessianConfig(method="kfac").ev_correction -def test_build_without_method_skips_hessian(tmp_path, registry): - """No hessian_cfg in a build step means index-only (no Hessian).""" +def test_build_step_is_index_only(tmp_path, registry): + """A build step builds the index only; Hessians are fit via the hessian step.""" yaml_path = write( tmp_path, """ @@ -200,7 +200,6 @@ def test_build_without_method_skips_hessian(tmp_path, registry): steps = parse(yaml_path, registry) _, cmd = steps[0] assert isinstance(cmd, Build) - assert cmd.hessian_cfg is None def make_steps() -> list[tuple[str, object]]: @@ -255,20 +254,18 @@ def test_load_subconfig_searches_all_steps(tmp_path): assert load_subconfig(tmp_path / "missing", "index_cfg", IndexConfig) is None -def test_build_with_method_computes_hessian(tmp_path, registry): - """An explicit method on a build step requests a Hessian approximation.""" +def test_hessian_step_fits_autocorrelation(tmp_path, registry): + """Autocorrelation is fit via the hessian step, like every other method.""" yaml_path = write( tmp_path, """ steps: - - build: + - hessian: index_cfg: {run_path: runs/test} - preprocess_cfg: {} hessian_cfg: {method: autocorrelation} """, ) steps = parse(yaml_path, registry) _, cmd = steps[0] - assert isinstance(cmd, Build) - assert cmd.hessian_cfg is not None + assert isinstance(cmd, Hessian) assert cmd.hessian_cfg.method == "autocorrelation" 
diff --git a/tests/test_global_projection.py b/tests/test_global_projection.py index ab17445f..bb24e6fa 100644 --- a/tests/test_global_projection.py +++ b/tests/test_global_projection.py @@ -185,7 +185,7 @@ def test_global_project_values_cpu(tmp_path: Path, model, dataset): """ proj_dim = 16 tokens = torch.tensor([dataset[0]["input_ids"]]) - cfg = IndexConfig(run_path=str(tmp_path), skip_index=True) + cfg = IndexConfig(run_path=str(tmp_path)) # First pass: capture raw per-module gradients (no projection) raw_collector = GradientCollector( @@ -193,6 +193,7 @@ def test_global_project_values_cpu(tmp_path: Path, model, dataset): cfg=cfg, data=dataset, processor=GradientProcessor(projection_dim=None), + skip_index=True, ) with raw_collector: model.zero_grad() @@ -208,6 +209,7 @@ def test_global_project_values_cpu(tmp_path: Path, model, dataset): cfg=cfg, data=dataset, processor=global_processor, + skip_index=True, ) with global_collector: model.zero_grad() diff --git a/tests/test_gradients.py b/tests/test_gradients.py index c3572640..6339f425 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -164,7 +164,6 @@ def test_gradient_collector_proj_norm(): for p in (16, None): cfg = IndexConfig( run_path=str(temp_dir / "run"), - skip_index=True, ) processor = GradientProcessor(projection_dim=p) collector = GradientCollector( @@ -172,7 +171,7 @@ def test_gradient_collector_proj_norm(): cfg=cfg, data=data, processor=processor, - skip_hessians=p is None, + skip_index=True, ) with collector: model.zero_grad() @@ -285,7 +284,6 @@ def test_gradient_collector_batched( # Create config for GradientCollector cfg = IndexConfig( run_path=str(temp_dir / "run"), - skip_index=True, ) processor = GradientProcessor( @@ -295,6 +293,7 @@ def test_gradient_collector_batched( model=model, cfg=cfg, data=dummy_data, + skip_index=True, processor=processor, target_modules={"fc1", "fc2"}, ) @@ -399,7 +398,6 @@ def compute_ground_truth(model) -> torch.Tensor: # Create config for 
GradientCollector cfg = IndexConfig( run_path=str(temp_dir / "run"), - skip_index=True, ) processor = GradientProcessor(include_bias=True, projection_dim=None) @@ -407,6 +405,7 @@ def compute_ground_truth(model) -> torch.Tensor: model=model, cfg=cfg, data=dummy_data, + skip_index=True, processor=processor, target_modules={"fc"}, ) @@ -464,7 +463,6 @@ def test_gradient_collector_with_projection( # Create config for GradientCollector cfg = IndexConfig( run_path=str(temp_dir / "run"), - skip_index=True, ) processor = GradientProcessor( @@ -474,6 +472,7 @@ def test_gradient_collector_with_projection( model=model, cfg=cfg, data=dummy_data, + skip_index=True, processor=processor, target_modules={"fc1", "fc2"}, ) @@ -535,7 +534,6 @@ def test_adafactor_normalization_ground_truth( dummy_data = Dataset.from_dict({"input_ids": [[1] * 10] * N}) cfg = IndexConfig( run_path=str(temp_dir / "run"), - skip_index=True, ) processor = GradientProcessor( @@ -547,6 +545,7 @@ def test_adafactor_normalization_ground_truth( model=model, cfg=cfg, data=dummy_data, + skip_index=True, processor=processor, target_modules={"fc1", "fc2"}, ) @@ -617,7 +616,6 @@ def test_in_features_restored_after_collector(test_params, simple_model_class): dummy_data = Dataset.from_dict({"input_ids": [[1] * 10] * N}) cfg = IndexConfig( run_path=str(temp_dir / "run"), - skip_index=True, ) # Record original in_features for all layers @@ -632,6 +630,7 @@ def test_in_features_restored_after_collector(test_params, simple_model_class): model=model, cfg=cfg, data=dummy_data, + skip_index=True, processor=processor, target_modules=set(original_in_features.keys()), ) @@ -683,7 +682,7 @@ def test_projected_bias_gradients_match_full_projection( bias_avg_sq=torch.rand(layer.out_features) + 0.1, ) - cfg = IndexConfig(run_path=str(temp_dir / "run"), skip_index=True) + cfg = IndexConfig(run_path=str(temp_dir / "run")) processor = GradientProcessor( normalizers=normalizers, projection_dim=P, include_bias=True ) @@ -692,6 +691,7 @@ 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_autocorrelation_hessian_e2e(tmp_path: Path):
    """Run ``bergson hessian --method autocorrelation`` end to end.

    The CLI is run in a subprocess with ``tmp_path`` as the working directory,
    then the saved processor is loaded back and checked for non-empty Hessians
    and eigendecompositions.
    """
    # Local import keeps this block self-contained; ``sys.executable`` makes
    # the child use the same interpreter (and virtualenv) as the test run,
    # instead of whichever ``python`` happens to be first on PATH.
    import sys

    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "bergson",
            "hessian",
            "test_e2e",
            "--model",
            "EleutherAI/pythia-14m",
            "--dataset",
            "NeelNanda/pile-10k",
            "--split",
            "train[:100]",
            "--truncation",
            "--projection_dim",
            "4",
            "--token_batch_size",
            "1024",
            "--precision",
            "bf16",
            "--method",
            "autocorrelation",
        ],
        cwd=tmp_path,
        capture_output=True,
        text=True,
    )

    # Fail fast on a non-zero exit; the stderr substring check below would
    # miss a failure whose message does not contain the word "Error".
    assert result.returncode == 0, (
        f"bergson hessian exited with {result.returncode}:\n{result.stderr}"
    )
    assert "Error" not in result.stderr, f"Error found in stderr:\n{result.stderr}"

    processor = GradientProcessor.load(tmp_path / "test_e2e")

    assert processor.hessians is not None
    assert processor.hessians_eigen is not None

    assert len(processor.hessians) > 0
    assert len(processor.hessians_eigen) > 0
IndexConfig(run_path=str(tmp_path / "build"), token_batch_size=1024) + # Step 1: fit an autocorrelation Hessian on the data + fit_cfg = IndexConfig(run_path=str(tmp_path / "hessian"), token_batch_size=1024) + fit_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) - collect_gradients( - model=model, + hess_collector = AutocorrelationCollector( + model=model.base_model, data=dataset, + path=str(fit_cfg.partial_run_path), processor=GradientProcessor(), - cfg=build_cfg, - skip_hessians=False, + attention_cfgs={}, ) + CollectorComputer( + model=model, + data=dataset, + collector=hess_collector, + cfg=fit_cfg, + ).run_with_collector_hooks(desc="Fit autocorrelation Hessian") - # Step 2: reduce with preconditioning pointing at the built index + # Step 2: reduce with preconditioning pointing at the fitted Hessian preprocess_cfg = PreprocessConfig( - aggregation="mean", hessian_path=str(build_cfg.partial_run_path) + aggregation="mean", hessian_path=str(fit_cfg.partial_run_path) ) reduce_index_cfg = IndexConfig( run_path=str(tmp_path / "reduce_hess"),