Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ If you found Bergson useful in your research, please cite us:

```bibtex
@misc{quirke2026bergsonopensourcelibrary,
title={Bergson: An Open Source Library for Data Attribution},
title={Bergson: An Open Source Library for Data Attribution},
author={Lucia Quirke and Louis Jaburi and David Johnston and William Z. Li and Gonçalo Paulo and Guillaume Martres and Girish Gupta and Stella Biderman and Nora Belrose},
year={2026},
eprint={2606.11660},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2606.11660},
url={https://arxiv.org/abs/2606.11660},
}
```

Expand Down
2 changes: 1 addition & 1 deletion bergson/approx_unrolling/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def approx_unrolling_pipeline(
query_cfg.projection_dim = 0
query_preprocess_cfg = PreprocessConfig(aggregation="mean")
save_run_config(
Build(query_cfg, query_preprocess_cfg, None),
Build(query_cfg, query_preprocess_cfg),
query_cfg.partial_run_path,
)
build(query_cfg, query_preprocess_cfg)
Expand Down
12 changes: 3 additions & 9 deletions bergson/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm.auto import tqdm

from bergson.collection import collect_gradients
from bergson.config.config import HessianConfig, IndexConfig, PreprocessConfig
from bergson.config.config import IndexConfig, PreprocessConfig
from bergson.data import allocate_batches
from bergson.distributed import (
cap_world_size_to_dataset,
Expand All @@ -35,7 +35,6 @@ def build_worker(
world_size: int,
index_cfg: IndexConfig,
preprocess_cfg: PreprocessConfig,
hessian_cfg: HessianConfig | None,
ds: Dataset | IterableDataset,
):
"""
Expand Down Expand Up @@ -75,12 +74,9 @@ def build_worker(
)

model, target_modules = setup_model_and_peft(index_cfg)
skip_hessians = hessian_cfg is None
processor = create_processor(model, index_cfg, target_modules)

maybe_auto_batch_size(
index_cfg, model, ds, processor, target_modules, rank, skip_hessians
)
maybe_auto_batch_size(index_cfg, model, ds, processor, target_modules, rank)

attention_cfgs = {
module: index_cfg.attention for module in index_cfg.split_attention_modules
Expand All @@ -94,7 +90,6 @@ def build_worker(
"target_modules": target_modules,
"attention_cfgs": attention_cfgs,
"preprocess_cfg": preprocess_cfg,
"skip_hessians": skip_hessians,
}

if isinstance(ds, Dataset):
Expand Down Expand Up @@ -139,7 +134,6 @@ def flush(kwargs):
def build(
index_cfg: IndexConfig,
preprocess_cfg: PreprocessConfig,
hessian_cfg: HessianConfig | None = None,
):
"""
Convert a dataset to an on-disk index.
Expand Down Expand Up @@ -170,7 +164,7 @@ def build(
launch_distributed_run(
"build",
build_worker,
[index_cfg, preprocess_cfg, hessian_cfg, ds],
[index_cfg, preprocess_cfg, ds],
dist_cfg,
)

Expand Down
37 changes: 4 additions & 33 deletions bergson/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,40 +62,18 @@ def execute(self):

@dataclass
class Build(Serializable):
    """Build a gradient index."""

    # Configuration for the gradient index; also supplies the run path
    # that is validated and used to store the run config.
    index_cfg: IndexConfig

    # Preprocessing configuration forwarded unchanged to ``build``.
    preprocess_cfg: PreprocessConfig

    def execute(self):
        """Validate the run path, save this run's config, and build the index."""
        validate_run_path(self.index_cfg)

        # Persist the config under the partial run path before the build
        # starts.
        save_run_config(self, self.index_cfg.partial_run_path)
        build(self.index_cfg, self.preprocess_cfg)


@dataclass
Expand Down Expand Up @@ -134,16 +112,9 @@ class Hessian(Serializable):

def execute(self):
    """Compute Hessian approximation.

    Validates the run path, saves this command's config under the partial
    run path, then delegates to ``approximate_hessians``.
    """
    validate_run_path(self.index_cfg)

    # NOTE(review): every method, including "autocorrelation", now goes
    # through approximate_hessians rather than build — confirm that
    # approximate_hessians supports it.
    save_run_config(self, self.index_cfg.partial_run_path)
    approximate_hessians(self.index_cfg, self.hessian_cfg)


@dataclass
Expand Down
34 changes: 14 additions & 20 deletions bergson/cli/trackstar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@
from ..config.config import (
HessianConfig,
IndexConfig,
PreprocessConfig,
TrackstarConfig,
)
from ..config.config_io import save_run_config
from ..hessians.hessian_approximations import approximate_hessians
from ..process_grads import mix_autocorrelation_matrices
from ..score.score import score_dataset
from ..utils.worker_utils import validate_run_path
from .commands import Build, Mix, Score
from .commands import Build, Hessian, Mix, Score


def _limit_split_for_hess(cfg: IndexConfig) -> None:
def _limit_split_for_hess(cfg: IndexConfig, stats_sample_size: int | None) -> None:
"""Limit the data split to stats_sample_size for hessian-only steps."""
# TODO this code is hacky and bad

if cfg.stats_sample_size is not None:
if stats_sample_size is not None:
split = cfg.data.split
# Append HF slice notation if not already present
if "[" not in split:
cfg.data.split = f"{split}[:{cfg.stats_sample_size}]"
cfg.data.split = f"{split}[:{stats_sample_size}]"
else:
base_split = split.split("[")[0]
cfg.data.split = f"{base_split}[:{cfg.stats_sample_size}]"
cfg.data.split = f"{base_split}[:{stats_sample_size}]"


def _step_complete(path: str, resume: bool) -> bool:
Expand All @@ -49,8 +49,6 @@ def trackstar(index_cfg: IndexConfig, trackstar_cfg: TrackstarConfig):
scores_path = f"{run_path}/scores"
resume = trackstar_cfg.resume

# Steps 1-2 only compute hessians, so don't preprocess grads.
hess_preprocess_cfg = PreprocessConfig()
hess_cfg = HessianConfig(method="autocorrelation")

def _validate(cfg: IndexConfig):
Expand All @@ -64,31 +62,27 @@ def _validate(cfg: IndexConfig):
if not _step_complete(value_hess_path, resume):
value_hess_cfg = deepcopy(index_cfg)
value_hess_cfg.run_path = value_hess_path
value_hess_cfg.skip_index = True
if trackstar_cfg.num_stats_sample_hessian:
_limit_split_for_hess(value_hess_cfg)
_limit_split_for_hess(value_hess_cfg, trackstar_cfg.stats_sample_size)
_validate(value_hess_cfg)
save_run_config(
Build(value_hess_cfg, hess_preprocess_cfg, hess_cfg),
Hessian(hessian_cfg=hess_cfg, index_cfg=value_hess_cfg),
value_hess_cfg.partial_run_path,
)
build(value_hess_cfg, hess_preprocess_cfg, hess_cfg)
approximate_hessians(value_hess_cfg, hess_cfg)

# Step 2: Compute hessians on query dataset
print("Step 2/5: Computing hessians on query dataset...")
if not _step_complete(query_hess_path, resume):
query_hess_cfg = deepcopy(index_cfg)
query_hess_cfg.run_path = query_hess_path
query_hess_cfg.data = deepcopy(trackstar_cfg.query)
query_hess_cfg.skip_index = True
if trackstar_cfg.num_stats_sample_hessian:
_limit_split_for_hess(query_hess_cfg)
_limit_split_for_hess(query_hess_cfg, trackstar_cfg.stats_sample_size)
_validate(query_hess_cfg)
save_run_config(
Build(query_hess_cfg, hess_preprocess_cfg, hess_cfg),
Hessian(hessian_cfg=hess_cfg, index_cfg=query_hess_cfg),
query_hess_cfg.partial_run_path,
)
build(query_hess_cfg, hess_preprocess_cfg, hess_cfg)
approximate_hessians(query_hess_cfg, hess_cfg)

# Step 3: Mix query and value hessians
print("Step 3/5: Mixing hessians...")
Expand Down Expand Up @@ -138,10 +132,10 @@ def _validate(cfg: IndexConfig):

_validate(query_cfg)
save_run_config(
Build(query_cfg, trackstar_cfg.preprocess_cfg, None),
Build(query_cfg, trackstar_cfg.preprocess_cfg),
query_cfg.partial_run_path,
)
build(query_cfg, trackstar_cfg.preprocess_cfg, None)
build(query_cfg, trackstar_cfg.preprocess_cfg)

# Step 5: Score value dataset against query using mixed hessian
print("Step 5/5: Scoring value dataset...")
Expand Down
2 changes: 0 additions & 2 deletions bergson/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def collect_gradients(
processor: GradientProcessor,
cfg: IndexConfig,
*,
skip_hessians: bool = True,
batches: list[list[int]] | None = None,
target_modules: set[str] | None = None,
attention_cfgs: dict[str, AttentionConfig] | None = None,
Expand All @@ -28,7 +27,6 @@ def collect_gradients(
model=model.base_model, # type: ignore
cfg=cfg,
processor=processor,
skip_hessians=skip_hessians,
target_modules=target_modules,
data=data,
scorer=scorer,
Expand Down
31 changes: 7 additions & 24 deletions bergson/collector/gradient_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from bergson.builder import Builder
from bergson.collector.collector import HookCollectorBase
from bergson.config import IndexConfig, PreprocessConfig
from bergson.process_autocorrelation import process_autocorrelation_matrices
from bergson.score.scorer import Scorer
from bergson.utils.utils import get_gradient_dtype

Expand All @@ -33,9 +32,6 @@ class GradientCollector(HookCollectorBase):
cfg: IndexConfig
"""Configuration for gradient index."""

skip_hessians: bool = True
"""Whether to skip estimating autocorrelation hessian statistics."""

mod_grads: dict = field(default_factory=dict)
"""Temporary storage for gradients during a batch, keyed by module name."""

Expand All @@ -48,6 +44,11 @@ class GradientCollector(HookCollectorBase):
scorer: Scorer | None = None
"""Optional scorer for computing scores instead of building an index."""

skip_index: bool = False
"""Collect gradients into ``mod_grads`` without writing an on-disk index
(e.g. batch-size probing or gradient inspection). No effect when a
``scorer`` is set, since scoring already skips the index."""

def setup(self) -> None:
"""
Initialize collector state.
Expand Down Expand Up @@ -77,7 +78,7 @@ def setup(self) -> None:
)

# Compute whether we need to save the index
self.save_index = self.scorer is None and not self.cfg.skip_index
self.save_index = self.scorer is None and not self.skip_index

if self.save_index:
grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()}
Expand All @@ -94,20 +95,12 @@ def setup(self) -> None:

@HookCollectorBase.split_attention_heads
def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
"""Compute per-sample gradient, accumulate autocorrelation matrix, and store."""
"""Compute the per-sample gradient and store it for the index."""
name: str = module._name # type: ignore[assignment]
P = self._compute_gradient(module, g)

global_proj = self.processor.projection_target == "global"

# Collect per-module hessians when projection target is per_module
if not self.skip_hessians and not global_proj:
P = P.float()
if name in self.processor.hessians:
self.processor.hessians[name].addmm_(P.mT, P)
else:
self.processor.hessians[name] = P.mT @ P

if global_proj:
assert self.processor.projection_dim is not None
R = self.projection(
Expand Down Expand Up @@ -159,16 +152,6 @@ def teardown(self):
if dist.is_initialized():
dist.reduce(self.per_doc_losses, dst=0)

grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()}
if self.processor.hessians:
process_autocorrelation_matrices(
self.processor,
self.processor.hessians,
len(self.data),
grad_sizes,
self.rank,
)

if self.builder:
self.builder.teardown()

Expand Down
14 changes: 4 additions & 10 deletions bergson/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,6 @@ class IndexConfig(AttributionConfig, Serializable):
or a path to an optimizer state file directly) or a Hugging Face URI
``hf://<repo>[@<revision>][/<path>]``."""

skip_index: bool = False
"""Whether to skip building the gradient index."""

stats_sample_size: int | None = 10_000
"""Number of examples to use for estimating the autocorrelation Hessian.
This feature is experimental and may be removed."""

loss_fn: Literal["ce", "kl"] = "ce"
"""Loss function to use."""

Expand Down Expand Up @@ -677,9 +670,10 @@ class TrackstarConfig:
index hessians intersect at this component. Typical value is
~1000 out of ~65K total components."""

num_stats_sample_hessian: bool = True
"""Whether to use num_stats_sample items or the full dataset to
compute hessians."""
stats_sample_size: int | None = 10_000
"""Number of examples to use for estimating the autocorrelation Hessian
in the trackstar pipeline's hessian-fitting steps. Set to None to use
the full dataset."""

resume: bool = False
"""Skip pipeline steps whose output directory already exists."""
Loading
Loading