From cce255a428db7ef493b545e27090c792d954bd71 Mon Sep 17 00:00:00 2001 From: Kartik Ahluwalia Date: Sat, 27 Dec 2025 01:04:07 +0530 Subject: [PATCH 1/2] docs: clarify that micro_acc_steps does not reduce memory usage --- README.md | 2 +- sparsify/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a2916ce7..d169a95a 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,7 @@ This is simple, but very memory inefficient. If you want to train SAEs for many torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2 ``` -The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node. +The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and sets `micro_acc_steps` to 2, which acts as a multiplier in the gradient accumulation calculation (note: this does not currently split minibatches or reduce memory usage). It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node. 
## TODO diff --git a/sparsify/config.py b/sparsify/config.py index 6854d042..215450f1 100644 --- a/sparsify/config.py +++ b/sparsify/config.py @@ -52,7 +52,7 @@ class TrainConfig(Serializable): """Number of steps over which to accumulate gradients.""" micro_acc_steps: int = 1 - """Chunk the activations into this number of microbatches for training.""" + """Multiplier for gradient accumulation (Note: does not currently split data/save memory)""" loss_fn: Literal["ce", "fvu", "kl"] = "fvu" """Loss function to use for training the sparse coders. From 2623a4e34d38541fbb6edcafd76c8b404cf9c49c Mon Sep 17 00:00:00 2001 From: Kartik Ahluwalia Date: Sat, 27 Dec 2025 01:16:51 +0530 Subject: [PATCH 2/2] style: fix line length for linter --- sparsify/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sparsify/config.py b/sparsify/config.py index 215450f1..fd5f9d6e 100644 --- a/sparsify/config.py +++ b/sparsify/config.py @@ -52,7 +52,9 @@ class TrainConfig(Serializable): """Number of steps over which to accumulate gradients.""" micro_acc_steps: int = 1 - """Multiplier for gradient accumulation (Note: does not currently split data/save memory)""" + """Multiplier for gradient accumulation. + Note: does not currently split data or save memory. + """ loss_fn: Literal["ce", "fvu", "kl"] = "fvu" """Loss function to use for training the sparse coders.