This folder contains code and experiment artifacts for both one-shot and iterative magnitude pruning. We apply our regrowth strategy to several models (e.g., resnet20, vgg16, alexnet) and evaluate it on the CIFAR-10 dataset.
The two main regrowth implementations are:
- Reference-based regrowth (`rl_regrowth_nas.py`): RL-based NAS allocates regrowth per layer; within each layer, weights are selected using SSIM-based layer priority plus reference masks/weights.
- Saliency-based regrowth (`rl_saliency_regrowth.py`): RL-based NAS allocates regrowth per layer; within each layer, weights are selected via $s_i \propto \left(\frac{\partial \mathcal{L}}{\partial \theta_i}\right)^2 \, \theta_i^2$.
A timing benchmark comparing both approaches is provided in benchmark_regrowth_methods.py and quick_benchmark.py.
You start from a highly sparse (pruned) network with masks (PyTorch pruning reparameterization):
- Each prunable module has
  - `weight_orig` (trainable dense weights)
  - `weight_mask` (binary mask, 1 = active, 0 = pruned)
- The effective weight is `weight = weight_orig * weight_mask`
Regrowth updates the mask by turning some pruned connections back on:
- pick $K$ indices among currently pruned weights (`weight_mask == 0`)
- set `weight_mask[idx] = 1`
- initialize the newly regrown weights (copy from reference, zeros, Kaiming, etc.)
- optionally mini-finetune / finetune to recover accuracy
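The mask update described above can be sketched on plain NumPy arrays. This is only an illustration under stated assumptions: the arrays stand in for the PyTorch `weight_orig` / `weight_mask` tensors, the function name `regrow_random` and its `init` values are hypothetical, and the random choice of positions is a placeholder for the reference-based or saliency-based selection used in this repo.

```python
import numpy as np


def regrow_random(weight_orig, weight_mask, k, rng, init="zeros"):
    """Turn k pruned positions back on and initialize them (illustrative only)."""
    weights = weight_orig.copy()
    mask = weight_mask.copy()
    pruned_idx = np.flatnonzero(mask == 0)       # candidates: currently pruned
    k = min(k, pruned_idx.size)                  # cannot regrow more than is pruned
    chosen = rng.choice(pruned_idx, size=k, replace=False)  # placeholder selection
    mask.flat[chosen] = 1                        # regrow: flip mask entries to 1
    if init == "zeros":
        weights.flat[chosen] = 0.0
    elif init == "kaiming":
        fan_in = weights[0].size                 # assumes shape (out, in, ...)
        weights.flat[chosen] = rng.normal(0.0, np.sqrt(2.0 / fan_in), size=k)
    else:
        raise ValueError(f"unknown init: {init}")
    return weights, mask


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 16))
m = (rng.random(w.shape) > 0.9).astype(w.dtype)  # roughly 90% of entries pruned
w2, m2 = regrow_random(w, m, k=5, rng=rng)
effective = w2 * m2                              # same rule as weight_orig * weight_mask
print(int(m.sum()), "->", int(m2.sum()), "active weights")
```

Previously active positions are never touched, so the active count can only grow.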
Regrowth is usually done to move from 98% sparsity → 97% sparsity (i.e., regrow 1% of all weights).
- `main.py`: CIFAR-10 pretraining + iterative pruning loop (saves pruned checkpoints)
- `utils/model_loader.py`: constructs models (`resnet20`, `vgg16`, `alexnet`, …)
- `utils/data_loader.py`: CIFAR-10 dataloaders
- `utils/analysis_utils.py`: pruning reparam helpers, SSIM feature extraction, mask stats

- `rl_regrowth_nas.py`: RL allocation + reference-mask selection
- `rl_saliency_regrowth.py`: RL allocation + saliency-based selection

- `utils/analysis_utils.py`:
  - `count_pruned_params(model)`: counts total vs. surviving parameters
  - SSIM utilities: `BlockwiseFeatureExtractor`, `compute_block_ssim`
- `utils/saliency_analysis.py`: FairPrune-style saliency computation + plots
- `single_layer_regrowth_analysis.py`: runs "all budget on one layer" experiments (tests SSIM ↔ improvement correlation)

- `benchmark_regrowth_methods.py`: end-to-end timing benchmark (SSIM vs. saliency preprocessing, selection, mask update, mini-finetune)
- `quick_benchmark.py`: one-episode benchmark wrapper

- `inspect_checkpoint.py`: inspection + finetune flows for saved checkpoints
Please check `requirements.txt` for details; you can install all required packages with `pip install -r requirements.txt`.
`utils/data_loader.py` uses `data_dir='./data'` by default and downloads the required dataset if it is not already present.
Most experiments follow this pipeline:
- Pretrain a dense model (or resume from `./{model}/checkpoint/ckpt.pth`)
- Prune + finetune to a target sparsity (e.g., 0.99)
- Regrow some weights (e.g., 2% of total, adjustable) to a less sparse target (0.97)
- Finetune selectively or fully to recover accuracy
Artifacts are usually saved under:
- `./{model_name}/checkpoint/ckpt.pth` (dense / baseline)
- `./{model_name}/ckpt_after_prune/pruned_finetuned_mask_{sparsity}.pth` (sparse checkpoint)
- `./rl_regrow_savedir/...` or `./rl_saliency_regrow_savedir/...` (RL training outputs)
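As a convenience, the two checkpoint locations above can be assembled programmatically. The helper below is a hypothetical sketch (`artifact_paths` is not part of this repo) and assumes the sparsity is formatted as a plain decimal such as `0.99`.

```python
from pathlib import Path


def artifact_paths(model_name, sparsity, root="."):
    """Build the dense and sparse checkpoint paths (hypothetical helper)."""
    base = Path(root) / model_name
    return {
        "dense": base / "checkpoint" / "ckpt.pth",
        "pruned": base / "ckpt_after_prune" / f"pruned_finetuned_mask_{sparsity}.pth",
    }


paths = artifact_paths("resnet20", 0.99)
for kind, path in paths.items():
    # Report whether each expected artifact is already on disk.
    status = "found" if path.exists() else "missing"
    print(f"{kind}: {path.as_posix()} ({status})")
```

Checking for the sparse checkpoint up front is useful because most regrowth and benchmark scripts expect it to exist.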
Pretraining and pruning are combined in main.py.
Key args (from main.py):
- `--m_name`: `resnet20`, `vgg16`, `alexnet`, ...
- `--pruner`: pruning method (passed into `weight_pruner_loader(args.pruner)`)
- `--iter_start`, `--iter_end`: pruning iterations
- `--max_epochs`, `--patience`: early stopping
Example (prune to 99% sparsity checkpoint expected by benchmarks):

    python main.py --m_name resnet20 --pruner magnitude --iter_end 1

If you run pruning via the ICLR2021 implementation under `iclr2021_solution/` (used for LAMP-style layerwise sparsity), the prune rate is not controlled by `main.py`.
Instead, the per-iteration prune amount is returned by:
- `iclr2021_solution/tools/modelloaders.py` → `model_and_opt_loader(...)` → `amount`

In `iclr2021_solution/iterate.py`, this value is loaded as:

    _, amount_per_it, batch_size, opt_pre, opt_post = model_and_opt_loader(args.model, DEVICE)

and then applied each iteration as:

    pruner(model, amount_per_it)
Illustration (what to edit):

    # iclr2021_solution/tools/modelloaders.py
    elif model_string == 'resnet20':
        model = ResNet20().to(DEVICE)
        amount = 0.985  # <-- adjust this prune rate

Practical guidance:

- Larger `amount` means more weights pruned per iteration.
- Total pruning depends on both `amount` and the number of prune iterations you run in `iterate.py` via `--iter_end`.
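How `amount` and the iteration count combine can be sketched with a short calculation. It assumes each iteration removes a fraction `amount` of the weights that are still unpruned, which is the usual behavior of iterative magnitude pruning but has not been checked against `iterate.py`. Both function names are hypothetical.

```python
def sparsity_after(amount, n_iters):
    """Sparsity after n_iters rounds, each pruning `amount` of the surviving weights."""
    if not 0.0 <= amount < 1.0:
        raise ValueError("amount must be in [0, 1)")
    return 1.0 - (1.0 - amount) ** n_iters


def iters_needed(amount, target_sparsity):
    """Smallest number of rounds that reaches at least target_sparsity."""
    if not 0.0 < amount < 1.0 or not 0.0 <= target_sparsity < 1.0:
        raise ValueError("amount must be in (0, 1) and target_sparsity in [0, 1)")
    n = 0
    while sparsity_after(amount, n) < target_sparsity:
        n += 1
    return n


for n in (1, 2, 3, 10):
    print(n, round(100 * sparsity_after(0.4, n), 2))
print("rounds to reach 99%:", iters_needed(0.4, 0.99))
```

Under this assumption, `amount = 0.4` over 10 iterations matches the sparsity column of the iterative-pruning results table (40.00, 64.00, 78.40, ..., 99.40).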
Example run:

    python iclr2021_solution/iterate.py --model resnet20 --pruner lamp --iter_end 10

This should produce:

- `./resnet20/ckpt_after_prune/pruned_finetuned_mask_0.99.pth`
Implemented in rl_regrowth_nas.py.
Core ideas:
- Layer priority is computed once using SSIM between feature maps from:
  - pretrained model (`./{m_name}/checkpoint/ckpt.pth`)
  - pruned model (typically `pruned_finetuned_mask_0.99.pth`)

  Lower SSIM → more feature drift → higher priority for regrowth.
- Allocation: an LSTM controller samples an allocation ratio per target layer.
- Selection: within each layer, select weights to regrow using reference masks/weights (e.g., from a 0.95-sparsity checkpoint).
- (Mini-)finetune to evaluate the reward.
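The selection step can be illustrated with a small NumPy sketch. It assumes that candidates are positions pruned in the current mask but active in the reference mask, and that ties on budget are broken by the magnitude of the reference weight; both the function name `select_from_reference` and that tie-break rule are assumptions, not necessarily what `rl_regrowth_nas.py` does.

```python
import numpy as np


def select_from_reference(cur_mask, ref_mask, ref_weight, k):
    """Flat indices of up to k positions that are pruned now but active in the reference."""
    candidates = np.flatnonzero((cur_mask == 0) & (ref_mask == 1))
    if candidates.size <= k:
        return candidates
    # Assumed tie-break: prefer positions with larger reference magnitude.
    scores = np.abs(ref_weight.ravel()[candidates])
    top = np.argsort(scores)[::-1][:k]
    return candidates[top]


cur_mask = np.array([[1, 0, 0], [0, 0, 1]])
ref_mask = np.array([[1, 1, 0], [1, 1, 1]])
ref_weight = np.array([[0.5, -0.9, 0.0], [0.1, 0.4, -0.3]])

idx = select_from_reference(cur_mask, ref_mask, ref_weight, k=2)
new_mask = cur_mask.copy()
new_mask.flat[idx] = 1          # regrow the selected positions
print(sorted(idx.tolist()), new_mask.tolist())
```

The regrown positions could then be initialized by copying the matching entries of `ref_weight`, which corresponds to the "copy from reference" initialization option.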
Implemented in rl_saliency_regrowth.py.
Core ideas:
- RL controller allocates the regrowth budget across layers.
- A `SaliencyComputer` estimates per-weight importance using $s_i \propto \left(\frac{\partial \mathcal{L}}{\partial \theta_i}\right)^2 \, \theta_i^2$. This is a Fisher/Hessian-diagonal approximation plus magnitude scaling.
- `SaliencyBasedRegrowth.apply_regrowth(...)` selects the top-K pruned weights by saliency and updates the mask.
- Newly regrown weights can be initialized via `--init_strategy`.
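A minimal NumPy sketch of this selection rule follows. It assumes `grad` holds gradients for every position, including pruned ones (a dense gradient with respect to the effective weight), and that `weight` is the dense `weight_orig`; otherwise pruned positions would all score zero. The function name `saliency_topk` is hypothetical and is not the repo's `SaliencyComputer` API.

```python
import numpy as np


def saliency_topk(grad, weight, mask, k):
    """Flat indices of the k highest-saliency positions among pruned weights."""
    saliency = (grad ** 2) * (weight ** 2)              # s_i ~ (dL/dtheta_i)^2 * theta_i^2
    saliency = np.where(mask == 0, saliency, -np.inf)   # active weights are not candidates
    k = min(k, int((mask == 0).sum()))                  # respect layer capacity
    order = np.argsort(saliency, axis=None)[::-1]       # descending saliency
    return order[:k]


grad = np.array([0.1, 2.0, 0.5, 1.0])
weight = np.array([1.0, 0.5, 1.0, 3.0])
mask = np.array([1, 0, 0, 0])

idx = saliency_topk(grad, weight, mask, k=2)
new_mask = mask.copy()
new_mask.flat[idx] = 1
print(idx.tolist(), new_mask.tolist())
```

Masking active positions with `-inf` before sorting keeps the ranking restricted to currently pruned weights without building a separate candidate array.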
This repo contains several “metrics” used to analyze regrowth decisions and results.
utils/analysis_utils.py exposes count_pruned_params(model) which reports:
- total parameters in prunable layers
- surviving (unmasked) parameters
- pruned parameters
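The three counts can be reproduced from the masks alone. The sketch below uses NumPy arrays as stand-ins for the `weight_mask` buffers; `count_from_masks` is a hypothetical name and does not mirror the exact return format of `count_pruned_params(model)`.

```python
import numpy as np


def count_from_masks(masks):
    """Count total, surviving and pruned entries over a dict of binary masks."""
    total = sum(m.size for m in masks.values())
    surviving = sum(int(m.sum()) for m in masks.values())
    return {"total": total, "surviving": surviving, "pruned": total - surviving}


masks = {
    "conv1": np.array([[1, 0], [0, 0]]),
    "fc": np.array([1, 1, 0, 0, 0, 0]),
}
stats = count_from_masks(masks)
print(stats, "sparsity =", stats["pruned"] / stats["total"])
```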
Common derived metrics:
- Sparsity: $s = \frac{\text{pruned}}{\text{total}}$
- Regrowth budget (global):

  $$K = (s_{\text{start}} - s_{\text{target}})\cdot N_{\text{total}}$$

  (In `single_layer_regrowth_analysis.py` this is described as "global budget".)
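A short worked example of the budget formula, using a made-up parameter count of 1,000,000 purely for illustration. The function name is hypothetical, and rounding to the nearest integer is an assumption made here to absorb floating-point error.

```python
def regrowth_budget(s_start, s_target, n_total):
    """K = (s_start - s_target) * N_total, rounded to a whole number of weights."""
    if s_target > s_start:
        raise ValueError("regrowth lowers sparsity: s_target must not exceed s_start")
    return round((s_start - s_target) * n_total)


n_total = 1_000_000  # illustrative value, not a real model size
print(regrowth_budget(0.99, 0.97, n_total))  # regrow 2% of all weights
print(regrowth_budget(0.98, 0.97, n_total))  # regrow 1% of all weights
```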
When regrowing, each layer has a capacity:
- capacity(layer) = number of currently pruned weights = `sum(weight_mask == 0)`
This bounds how many weights you can regrow in that layer.
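One way to combine per-layer allocation ratios with this capacity bound is sketched below. It is an assumption-laden illustration: `allocate_budget` is a hypothetical helper, shares are floored, and any budget that does not fit is simply reported as leftover rather than redistributed, which may differ from the RL scripts.

```python
def allocate_budget(ratios, capacities, budget):
    """Split a global budget by ratio, clipping each layer to its capacity."""
    total_ratio = sum(ratios.values())
    alloc = {}
    for name, ratio in ratios.items():
        requested = int(budget * ratio / total_ratio)    # floor of this layer's share
        alloc[name] = min(requested, capacities[name])   # cannot exceed pruned count
    leftover = budget - sum(alloc.values())              # budget that did not fit
    return alloc, leftover


ratios = {"conv1": 0.5, "conv2": 0.25, "fc": 0.25}
capacities = {"conv1": 30, "conv2": 100, "fc": 100}      # pruned weights per layer
alloc, leftover = allocate_budget(ratios, capacities, budget=100)
print(alloc, "leftover:", leftover)
```

Here `conv1` asks for 50 weights but only has 30 pruned positions, so 20 units of budget remain unassigned.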
In the reference-based RL approach, each layer gets an SSIM score computed from features.
- Feature extraction: `BlockwiseFeatureExtractor` (in `utils/analysis_utils.py`)
- Similarity: `compute_block_ssim(features_pretrained, features_pruned)`
Interpretation:
- Lower SSIM ⇒ larger feature drift from baseline ⇒ layer is more damaged by pruning ⇒ regrowing there is more likely to help.
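To make the interpretation concrete, SSIM scores can be turned into normalized priorities. The mapping `1 - SSIM` used here is only one plausible choice and is an assumption, as are the function name and the made-up scores; the repo may weight layers differently.

```python
def ssim_to_priority(ssim_scores):
    """Map per-layer SSIM to priorities that sum to 1 (lower SSIM gives higher priority)."""
    drift = {name: max(0.0, 1.0 - s) for name, s in ssim_scores.items()}
    total = sum(drift.values())
    if total == 0.0:
        # No measurable drift anywhere: fall back to a uniform priority.
        return {name: 1.0 / len(drift) for name in drift}
    return {name: d / total for name, d in drift.items()}


scores = {"block1": 0.9, "block2": 0.6, "block3": 0.5}   # placeholder SSIM values
priority = ssim_to_priority(scores)
ranked = sorted(priority, key=priority.get, reverse=True)
print(ranked)
```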
single_layer_regrowth_analysis.py is designed to measure correlation between:
- SSIM(layer)
- accuracy improvement when regrowing only that layer
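The correlation itself can be computed with a plain Pearson coefficient, as in the sketch below. The numbers are placeholders chosen to be perfectly anti-correlated and are not experimental results; whether the analysis script uses Pearson, Spearman, or another statistic is not specified here.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)


ssim_per_layer = [0.9, 0.7, 0.5]   # placeholder values
acc_gain = [0.1, 0.3, 0.5]         # placeholder values
r = pearson(ssim_per_layer, acc_gain)
print(f"r = {r:.3f}")
```

A strongly negative `r` would support the reading that low-SSIM layers benefit most from regrowth.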
Two code paths exist:
- `rl_saliency_regrowth.py` uses `SaliencyComputer` (RigL-style accumulated gradients) with the FairPrune formula.
- `utils/saliency_analysis.py` provides a more general "per-class" saliency analyzer:
  - supports second-order approximation (`use_second_order=True`)
  - can compute per-class importance tensors and visualize distributions
Interpretation:
- higher saliency ⇒ parameter is important for loss/accuracy
- regrow highest-saliency weights among currently pruned positions (RigL-inspired)
Use the benchmark to compare the major components:
- preprocessing
  - reference method: SSIM feature extraction + SSIM scores
  - saliency method: gradient accumulation
- selection
- mask update
- mini-finetuning
Run the benchmarks with:

    python quick_benchmark.py --m_name resnet20
    python benchmark_regrowth_methods.py --m_name resnet20 --num_runs 3

Most regrowth and benchmark scripts assume the 99% sparse checkpoint exists:

    python main.py --m_name resnet20 --pruner magnitude --iter_end 1
    python quick_benchmark.py --m_name resnet20

`single_layer_regrowth_analysis.py` applies the full global regrowth budget to one layer at a time, then finetunes and measures the recovered accuracy.

    python single_layer_regrowth_analysis.py --m_name resnet20

Use `utils/saliency_analysis.py` if you want diagnostic plots (per-layer / per-class distributions). This is separate from the RL saliency regrowth code.
The tables below summarize the VGG16 results from our tracking sheet.
Setting A: Iterative magnitude pruning (10 iterations)
| Sparsity (%) | Test acc (mean±std) |
|---|---|
| 40.00 | 92.09 ± 0.28 |
| 64.00 | 92.43 ± 0.12 |
| 78.40 | 92.60 ± 0.15 |
| 87.04 | 92.68 ± 0.18 |
| 92.22 | 92.68 ± 0.19 |
| 95.33 | 92.49 ± 0.26 |
| 97.20 | 92.35 ± 0.02 |
| 98.32 | 91.92 ± 0.09 |
| 98.99 | 91.01 ± 0.05 |
| 99.40 | 90.31 ± 0.18 |
Setting B: One-shot magnitude pruning
| Sparsity (%) | Test acc (mean±std) |
|---|---|
| 95.0 | 92.43 ± 0.10 |
| 96.0 | 92.41 ± 0.22 |
| 97.0 | 92.27 ± 0.06 |
| 98.0 | 91.93 ± 0.17 |
| 99.0 | 91.03 ± 0.17 |
| 99.5 | 89.53 ± 0.10 |
| 99.9 | 81.08 ± 0.85 |
| Baseline | Step | Result | Improvement |
|---|---|---|---|
| 91.93 | step1 | 92.33 | +0.40 |
| 92.27 | step2 | 92.41 | +0.14 |
| 92.41 | step3 | 92.57 | +0.16 |
| 92.43 | step4 | 92.73 | +0.30 |
Below is a single-episode timing breakdown (VGG16 @ 99% sparsity, regrowing 147,148 weights):
| Method | Weight selection total (s) | Notes |
|---|---|---|
| Reference-based (SSIM + reference masks/weights) | 0.0178 | Includes loading reference masks/weights (0.0012s) and selecting 147,148 weights (0.0166s) |
| Saliency-based (gradient saliency) | 0.0041 | Selects 147,148 weights by direct saliency ranking |
In this run, saliency-based selection is roughly 4.3× faster than reference-based selection (0.0041 s vs. 0.0178 s).
If you use this codebase, results, or figures in your work, please cite:

    @article{liu2025beyond,
      title={Beyond One-Way Pruning: Bidirectional Pruning-Regrowth for Extreme Accuracy-Sparsity Tradeoff},
      author={Liu, Junchen and Sheng, Yi},
      journal={arXiv preprint arXiv:2511.11675},
      year={2025}
    }

We also build on the public magnitude-pruning implementation from LAMP (ICLR 2021):

    @article{lee2020layer,
      title={Layer-adaptive sparsity for the magnitude-based pruning},
      author={Lee, Jaeho and Park, Sejun and Mo, Sangwoo and Ahn, Sungsoo and Shin, Jinwoo},
      journal={arXiv preprint arXiv:2010.07611},
      year={2020}
    }