Figure 1: Schematic overview of our proposal: model sparsity-driven MU.
This is the official code repository for the NeurIPS 2023 Spotlight paper Model Sparsity Can Simplify Machine Unlearning.
In response to recent data regulation requirements, machine unlearning (MU) has emerged as a critical process to remove the influence of specific examples from a given model. Although exact unlearning can be achieved through complete model retraining using the remaining dataset, the associated computational costs have driven the development of efficient, approximate unlearning techniques. Moving beyond data-centric MU approaches, our study introduces a novel model-based perspective: model sparsification via weight pruning, which is capable of reducing the gap between exact unlearning and approximate unlearning. We show in both theory and practice that model sparsity can boost the multi-criteria unlearning performance of an approximate unlearner, closing the approximation gap, while continuing to be efficient. This leads to a new MU paradigm, termed prune first, then unlearn, which infuses a sparse model prior into the unlearning process. Building on this insight, we also develop a sparsity-aware unlearning method that utilizes sparsity regularization to enhance the training process of approximate unlearning. Extensive experiments show that our proposals consistently benefit MU in various unlearning scenarios. A notable highlight is the 77% unlearning efficacy gain of fine-tuning (one of the simplest unlearning methods) when using sparsity-aware unlearning. Furthermore, we demonstrate the practical impact of our proposed MU methods in addressing other machine learning challenges, such as defending against backdoor attacks and enhancing transfer learning.
```bash
conda env create -f environment.yml
```
- Run `train` first to produce a sparse checkpoint.
- Run `unlearn` with that checkpoint (via `--mask`) to remove target-data influence.
- Check accuracy + MIA results, and optionally run `backdoor` for security-focused evaluation.
```bash
DATA=./data
DATASET=cifar10
ARCH=resnet18
SAVE_DIR=./runs/${DATASET}_${ARCH}
MASK_PATH=${SAVE_DIR}/0model_SA_best.pth.tar

# OMP (one-shot magnitude pruning)
python -u main.py train --profile imp \
    --data ${DATA} --dataset ${DATASET} --arch ${ARCH} \
    --save_dir ${SAVE_DIR} --prune_type rewind_lt --rewind_epoch 8 \
    --rate 0.95 --pruning_times 1 --workers 8

# IMP (iterative magnitude pruning)
python -u main.py train --profile imp \
    --data ${DATA} --dataset ${DATASET} --arch ${ARCH} \
    --save_dir ${SAVE_DIR} --prune_type rewind_lt --rewind_epoch 8 \
    --rate 0.2 --pruning_times 5 --workers 8

# SynFlow
python -u main.py train --profile synflow \
    --data ${DATA} --dataset ${DATASET} --arch ${ARCH} \
    --save_dir ${SAVE_DIR}_synflow --prune_type rewind_lt --rewind_epoch 8 \
    --rate 0.95 --pruning_times 1 --workers 8
```

These commands train a model and then prune it into a sparse structure. This is the paper's "prune first" step, which creates a sparse prior before unlearning. `--pruning_times` controls whether pruning is one-shot or iterative, and the key output is a sparse checkpoint (for example `0model_SA_best.pth.tar`) for the next stage.
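The core operation behind these profiles can be illustrated independently of the training code. Below is a minimal sketch of one-shot magnitude pruning (keep the largest-magnitude weights, zero the rest); the helper name and the flat weight list are illustrative, not the repository's API:

```python
def magnitude_prune(weights, rate):
    """Zero out the `rate` fraction of weights with the smallest magnitude.

    Returns (pruned_weights, mask), where mask[i] == 1 keeps weight i.
    With rate=0.95 (as in the OMP command above), 95% of weights become 0.
    """
    n_prune = int(len(weights) * rate)
    # Indices sorted by |w|, smallest first; prune the first n_prune of them.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:n_prune]:
        mask[i] = 0
    pruned = [w * m for w, m in zip(weights, mask)]
    return pruned, mask

weights = [0.9, -0.05, 0.4, 0.01, -0.7]
pruned, mask = magnitude_prune(weights, rate=0.4)  # prune 2 of 5 weights
```

Iterative pruning (IMP) repeats this with a smaller per-round rate (here `--rate 0.2` over `--pruning_times 5` rounds), retraining in between, which typically preserves accuracy better at high overall sparsity.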
Other training profiles:
```bash
python -u main.py train --profile ls --data ${DATA} --dataset ${DATASET} --arch ${ARCH} --save_dir ${SAVE_DIR}_ls
python -u main.py train --profile sam --data ${DATA} --dataset ${DATASET} --arch ${ARCH} --save_dir ${SAVE_DIR}_sam
python -u main.py train --profile vit --data ${DATA} --dataset ${DATASET} --arch ${ARCH} --save_dir ${SAVE_DIR}_vit
```

Profile summary (high level):
- `imp`: standard SGD + cross-entropy baseline training, then iterative pruning.
- `ls`: same pipeline as the baseline, but uses a label-smoothing loss during training.
- `sam`: same pipeline shape, but uses the SAM optimizer to seek flatter minima before/through pruning.
- `vit`: keeps the same pruning framework, but uses a ViT-style optimizer/scheduler setup (Adam + cosine schedule).
- `synflow`: applies SynFlow global pruning first, then trains the pruned model.
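To make the `ls` profile concrete, here is the standard label-smoothing target construction (the usual formulation; the repository's criterion may differ in detail):

```python
def smooth_labels(num_classes, true_class, eps=0.1):
    """Replace a one-hot target with a smoothed distribution:
    eps is spread uniformly over all classes, and the remaining
    (1 - eps) mass goes to the true class."""
    base = eps / num_classes
    target = [base] * num_classes
    target[true_class] += 1.0 - eps
    return target

t = smooth_labels(num_classes=10, true_class=3, eps=0.1)
```

Training against these soft targets discourages over-confident logits, which interacts favorably with pruning.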
Forget-setting flags:
Use `--num_indexes_to_replace` to set how many samples to forget, `--class_to_replace` to choose the source class (default 0), and `--mask` to pass the sparse checkpoint from pruning.
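Conceptually, these flags partition the training set into a forget set and a retain set. The sketch below is a hypothetical illustration of one deterministic way to do that split (the repository's own selection logic may differ, e.g. by sampling randomly):

```python
def split_forget_retain(labels, class_to_replace, num_indexes_to_replace):
    """Pick the first `num_indexes_to_replace` samples of the target class as
    the forget set; everything else becomes the retain set."""
    forget, retain = [], []
    for idx, y in enumerate(labels):
        if y == class_to_replace and len(forget) < num_indexes_to_replace:
            forget.append(idx)
        else:
            retain.append(idx)
    return forget, retain

labels = [0, 1, 0, 2, 0, 1]
forget, retain = split_forget_retain(labels, class_to_replace=0,
                                     num_indexes_to_replace=2)
```

Every unlearning method below consumes this split: it must erase the forget indices' influence while preserving performance on the retain indices.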
```bash
# Retrain
python -u main.py unlearn --save_dir ${SAVE_DIR} --mask ${MASK_PATH} \
    --dataset ${DATASET} --arch ${ARCH} --unlearn retrain \
    --num_indexes_to_replace 4500 --unlearn_epochs 160 --unlearn_lr 0.1

# FT
python -u main.py unlearn --save_dir ${SAVE_DIR} --mask ${MASK_PATH} \
    --dataset ${DATASET} --arch ${ARCH} --unlearn FT \
    --num_indexes_to_replace 4500 --unlearn_epochs 10 --unlearn_lr 0.01

# GA
python -u main.py unlearn --save_dir ${SAVE_DIR} --mask ${MASK_PATH} \
    --dataset ${DATASET} --arch ${ARCH} --unlearn GA \
    --num_indexes_to_replace 4500 --unlearn_epochs 5 --unlearn_lr 0.0001

# FF (fisher_new)
python -u main.py unlearn --save_dir ${SAVE_DIR} --mask ${MASK_PATH} \
    --dataset ${DATASET} --arch ${ARCH} --unlearn fisher_new \
    --num_indexes_to_replace 4500 --alpha 0.2

# IU (wfisher)
python -u main.py unlearn --save_dir ${SAVE_DIR} --mask ${MASK_PATH} \
    --dataset ${DATASET} --arch ${ARCH} --unlearn wfisher \
    --num_indexes_to_replace 4500 --alpha 0.2

# l1-sparse (FT_prune)
python -u main.py unlearn --save_dir ${SAVE_DIR} --mask ${MASK_PATH} \
    --dataset ${DATASET} --arch ${ARCH} --unlearn FT_prune \
    --num_indexes_to_replace 4500 --alpha 0.2 --unlearn_lr 0.01 --unlearn_epochs 10
```
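The `FT_prune` (l1-sparse) objective adds an L1 penalty, weighted by `alpha` (matching the `--alpha` flag), to the fine-tuning loss so that weights are driven toward zero during unlearning. A toy sketch of the penalized objective (plain Python, not the repository's training loop):

```python
def l1_sparse_loss(task_loss, weights, alpha):
    """Sparsity-aware unlearning objective: the task loss on retain data
    plus an L1 penalty that pushes weights toward zero."""
    return task_loss + alpha * sum(abs(w) for w in weights)

# Hypothetical numbers: retain-data loss 0.5, three weights, alpha = 0.2.
loss = l1_sparse_loss(task_loss=0.5, weights=[0.5, -0.25, 0.0], alpha=0.2)
```

This is the "sparsity-aware unlearning" idea from the paper in miniature: the L1 term plays the role of a soft, differentiable version of the hard pruning mask.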
ImageNet unlearning:

```bash
python -u main.py unlearn \
    --dataset imagenet --arch resnet18 --imagenet_arch \
    --save_dir ./runs/imagenet_resnet18 \
    --mask ./runs/imagenet_resnet18/0model_SA_best.pth.tar \
    --class_to_replace 0 --unlearn FT --unlearn_epochs 5 --unlearn_lr 0.001
```

This command removes the influence of a target class (or of selected samples) from a trained model, and it runs on top of the sparse model produced by `train` (via `--mask`), which is the central idea of this work. Different `--unlearn` choices trade off quality and speed: for example, `retrain` is the strongest but slowest, `FT` is a fast fine-tuning baseline, and `GA` actively pushes the model away from the forgotten data. The outputs include both utility and privacy-style signals, such as accuracy and MIA.
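On the MIA readout: the repository uses an SVC-based attack, but the underlying idea can be conveyed with a simpler loss-threshold variant (illustrative only). A membership attack predicts "training member" when a sample's loss is low; the closer its accuracy is to 50% on the forget set, the better the unlearning:

```python
def mia_accuracy(member_losses, nonmember_losses, threshold):
    """Predict 'member' when the per-sample loss is below the threshold,
    then score how often that prediction is right on both groups."""
    correct = sum(l < threshold for l in member_losses)
    correct += sum(l >= threshold for l in nonmember_losses)
    return correct / (len(member_losses) + len(nonmember_losses))

# Hypothetical per-sample losses for members vs. non-members.
acc = mia_accuracy([0.1, 0.2, 0.9], [0.8, 1.2, 0.3], threshold=0.5)
```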
Available `--unlearn` methods (algorithm-level):
- `retrain`: retrains on retain data only, approximating full retraining without the forgotten data.
- `FT`: fine-tunes the model on retain data for a small number of epochs (efficient baseline).
- `GA`: performs gradient ascent on forget data (negative cross-entropy) to actively unfit the forgotten samples.
- `fisher_new` (FF): uses a Hessian/Fisher-style curvature estimate and samples new weights from a noise model.
- `wfisher` (IU): computes a WoodFisher-style second-order perturbation from retain/forget gradients, then updates the parameters once.
- `FT_prune` (l1-sparse): applies sparsity-aware fine-tuning with L1 regularization as an unlearning strategy.
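GA's "gradient ascent" is simply the usual descent update with the sign flipped on the forget loss. A toy scalar illustration (hypothetical, for intuition only, with a quadratic stand-in for the forget loss):

```python
def ga_step(w, grad_forget, lr):
    """Gradient *ascent* on the forget loss: move w so the loss increases,
    i.e. w + lr * grad instead of the usual w - lr * grad."""
    return w + lr * grad_forget

# Toy forget loss L(w) = (w - 2)^2 with minimum at w = 2; gradient 2*(w - 2).
w = 1.0
for _ in range(3):
    grad = 2 * (w - 2.0)
    w = ga_step(w, grad, lr=0.1)
# w moves away from the minimizer, so the forget loss grows each step.
```

This is why GA uses a small learning rate and few epochs in the commands above: ascending too aggressively also destroys retain-set accuracy.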
```bash
python -u main.py backdoor \
    --save_dir ${SAVE_DIR}_backdoor \
    --mask ${SAVE_DIR}_backdoor/mask.pt \
    --dataset ${DATASET} --arch ${ARCH} \
    --unlearn FT --num_indexes_to_replace 4500 --class_to_replace 0
```

This command evaluates unlearning in a security setting: it first simulates poisoned trigger behavior and then applies unlearning to clean it, with the key readout being whether the attack success rate drops while clean accuracy remains usable.
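The standard metric in this setting is the attack success rate (ASR): the fraction of trigger-stamped inputs the model classifies as the attacker's target class. A minimal sketch (function name and inputs are illustrative, not the repository's API):

```python
def attack_success_rate(preds_on_triggered, target_class):
    """Fraction of triggered inputs the model maps to the attacker's class.
    Successful backdoor cleansing should drive this toward chance level."""
    hits = sum(p == target_class for p in preds_on_triggered)
    return hits / len(preds_on_triggered)

# Hypothetical predictions on five triggered images; attacker targets class 0.
asr = attack_success_rate([0, 0, 3, 0, 1], target_class=0)
```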
```bash
# resume pruning
python -u main.py train --profile imp --resume --checkpoint ${SAVE_DIR}/0checkpoint.pth.tar --save_dir ${SAVE_DIR}

# resume unlearning
python -u main.py unlearn --resume --save_dir ${SAVE_DIR} --mask ${MASK_PATH} --unlearn FT
```

These commands resume interrupted long-running experiments, which is useful for reproducibility and for large sweeps where pruning or unlearning may take many hours.
The source code is organized as follows:
- **Entrypoint**
  - `main.py`: Unified CLI entrypoint.
  - Subcommands:
    - `train` -> `src/pipelines/train_pipeline.py`
    - `unlearn` -> `src/pipelines/forget_pipeline.py`
    - `backdoor` -> `src/pipelines/backdoor_pipeline.py`
- **Core CLI / Utilities**
  - `src/core/arg_parser.py`: Shared argument definitions used by all workflows.
  - `src/core/utils.py`: Setup, checkpointing, data conversion, and general helper utilities.
- **Pipelines**
  - `src/pipelines/train_pipeline.py`: Training + pruning implementations (IMP/LS/SAM/ViT profiles and SynFlow).
  - `src/pipelines/forget_pipeline.py`: Unlearning pipeline; loads sparse masks, runs unlearning methods, and evaluates accuracy/MIA.
  - `src/pipelines/backdoor_pipeline.py`: Backdoor cleansing pipeline; builds poisoned data loaders and applies unlearning.
- **Training / Optimization / Pruning**
  - `src/trainer/`: Train/validate loops (standard SGD loop and SAM training loop).
  - `src/optim/`: Optimization utilities (e.g., SAM and the label-smoothing criterion).
  - `src/optim/unlearn/`: Unlearning method implementations (e.g., `retrain`, `FT`, `GA`, and variants).
  - `src/pruner/`: Pruning utilities and methods (mask extraction/application, OMP-style pruning, SynFlow logic).
- **Data / Models / Evaluation**
  - `src/dataio/dataset.py`: Dataset loading/splitting and forget-sample marking utilities.
  - `src/dataio/imagenet.py` and `src/dataio/lmdb_dataset.py`: ImageNet / LMDB data utilities.
  - `src/models/`: Model definitions.
  - `src/evaluation/`: MIA evaluation code (including SVC-based MIA metrics).
If you find this repository or the ideas presented in our paper useful, please consider citing our paper:
```bibtex
@inproceedings{jia2023model,
  title={Model Sparsity Can Simplify Machine Unlearning},
  author={Jia, Jinghan and Liu, Jiancheng and Ram, Parikshit and Yao, Yuguang and Liu, Gaowen and Liu, Yang and Sharma, Pranay and Liu, Sijia},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}
```
