Skip to content

ok858ok/CP-ViT

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

46 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

CP-ViT

Code for "CP-ViT: Cascade Vision Transformer Pruning via Progressive Sparsity Prediction" on CIFAR-10/100.

CP-ViT: a cascade pruning framework named CP-ViT by predicting sparsity in ViT models progressively and dynamically to reduce computational redundancy while minimizing the accuracy loss. Specifically, we define the cumulative score to reserve the informative patches and heads across the ViT model for better accuracy. We also propose the dynamic pruning ratio adjustment technique based on layer-aware attention range. CP-ViT has great general applicability for practical deployment, which can be applied to a wide range of ViT models and can achieve superior accuracy with or without finetuning.

Usage

Prerequisite

We have tested our codes under the following environments:

python == 3.9.5
pytorch == 1.9.0
torchvision == 0.10.0
CUDA == 11.2

Pretrained Vision Transformer Models

To start with, you can first download pre-trained models from:

ViT-B_16/224 cifar-10

ViT-B_16/224 cifar-100

and place them under folder./CP-ViT/output/.

Of course you can download other pre-trained models from Google Cloud .

Pruning without Finetuning

We then prune ViT model without finetuning by:

python3 eval.py \
        --name="CP-ViT test" \
        --dataset="cifar10" \
        --model_type="ViT-B_16" \
        --pretrained_dir='output/cifar10_checkpoint.pth' \
        --eval_batch_size=64 

Pruning with Finetuning

We can finetune the CP-ViT model by:

python3 train.py \
        --name="CP-ViT finetune" \
        --dataset="cifar10" \
        --model_type="ViT-B_16" \
        --pretrained_dir='output/cifar10_checkpoint.pth' \
        --train_batch_size=64 \
        --eval_every=3125 \
        --learning_rate=3e-2 \
        --num_steps=10000 \
        --decay_type="cosine" 

Acknowledge Related Repos

Pytorch Image Models: https://github.com/rwightman/pytorch-image-models

ViT: https://github.com/google-research/vision_transformer

DeiT: https://github.com/facebookresearch/deit

About

Code for "CP-ViT: Cascade Vision Transformer Pruning via Progressive Sparsity Prediction" on CIFAR-10/100.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages