A comprehensive framework for training and quantizing Kolmogorov-Arnold Networks (KANs) with support for both Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
- Overview
- Features
- Project Structure
- Installation
- Quick Start
- Configuration
- Supported Models
- Quantization Methods
- Datasets
- Usage Examples
- Monitoring and Logging
## Overview

QuantKAN provides a unified framework for quantizing Kolmogorov-Arnold Networks, enabling efficient deployment on resource-constrained hardware. The framework supports multiple KAN variants and a wide range of quantization techniques for both weights and activations.
## Features

**Multiple KAN Implementations:**
- EfficientKAN (B-spline based)
- FastKAN (RBF-based)
- PyKAN (Original KAN implementation)
- KAGN (Gram-based KAN with convolutions)
**Quantization-Aware Training (QAT):**
- LSQ (Learned Step Size Quantization)
- LSQ+ (Enhanced LSQ)
- PACT (Parameterized Clipping Activation)
- DoReFa
- DSQ (Differentiable Soft Quantization)
- QIL (Quantization Interval Learning)
**Post-Training Quantization (PTQ):**
- Uniform Quantization
- GPTQ (builds on Optimal Brain Quantization)
- AdaRound (Adaptive Rounding)
- AWQ (Activation-aware Weight Quantization)
- HAWQ-v2 (Hessian-Aware Quantization)
- BRECQ (Block Reconstruction Quantization)
- SmoothQuant
- ZeroQ (Zero-shot Quantization)
**Training Features** (a minimal training-step sketch follows this list):
- Mixed precision training (AMP)
- Gradient accumulation
- Early stopping
- EMA (Exponential Moving Average)
- Multiple learning rate schedulers
- TensorBoard logging
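
The snippet below is a minimal, self-contained sketch of how mixed precision, gradient accumulation, and gradient clipping typically combine in a PyTorch training step. It is illustrative only; QuantKAN's actual loop lives in `process.py`, and the model, loader, and hyperparameters here are dummy stand-ins.

```python
import torch
import torch.nn as nn

# Dummy stand-ins so the sketch runs end to end; not QuantKAN's real training loop.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).cuda()
loader = [(torch.randn(32, 1, 28, 28).cuda(), torch.randint(0, 10, (32,)).cuda())
          for _ in range(8)]
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler()   # corresponds to amp.grad_scaler: true
accum_steps = 4                        # gradient accumulation

optimizer.zero_grad(set_to_none=True)
for step, (images, targets) in enumerate(loader):
    with torch.cuda.amp.autocast(dtype=torch.float16):    # amp.dtype: fp16
        loss = criterion(model(images), targets) / accum_steps
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)     # so clipping sees unscaled gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # grad_clip_norm: 1.0
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```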
## Project Structure

```
├── main.py                     # Main training entry point (QAT)
├── runner.py                   # PTQ runner for post-training quantization
├── ptq_eval.py                 # PTQ evaluation script
├── process.py                  # Training/validation loops
├── config.yaml                 # Default configuration
├── logging.conf                # Logging configuration
│
├── configs/                    # Experiment Configurations
│   ├── mnist_*.yaml            # MNIST experiment configs
│   ├── cifar10_*.yaml          # CIFAR-10 experiment configs
│   ├── cifar100_*.yaml         # CIFAR-100 experiment configs
│   ├── tinyimagenet_*.yaml     # TinyImageNet experiment configs
│   └── imagenet_*.yaml         # ImageNet experiment configs
│
├── datasets/                   # Dataset storage directory
│
├── models/                     # Model Definitions
│   ├── model.py                # Model factory
│   ├── kan_models.py           # KAN model architectures
│   ├── vgg_kan_cifar.py        # VGG-KAN for CIFAR
│   └── vgg_kan_imagenet.py     # VGG-KAN for ImageNet
│
├── kans/                       # KAN Layer Implementations
│   ├── efficient_kan.py        # EfficientKAN layers
│   ├── fastkan.py              # FastKAN layers
│   ├── KANLayer.py             # PyKAN layers
│   ├── kagn_kagn_conv.py       # KAGN convolutional layers
│   └── conv_kagn.py            # KAN convolution implementations
│
├── qat/                        # QAT Quantizers
│   ├── quantizers/
│   │   ├── quantizer.py        # Base quantizer class
│   │   ├── lsq.py              # Learned Step Size Quantization
│   │   ├── lsq_plus.py         # LSQ+ implementation
│   │   ├── pact.py             # PACT quantizer
│   │   ├── dorefa.py           # DoReFa quantizer
│   │   ├── dsq.py              # DSQ quantizer
│   │   └── qil.py              # QIL quantizer
│   ├── quant_nn.py             # Quantized standard NN layers
│   ├── quant_efficient_kan.py  # Quantized EfficientKAN
│   ├── quant_fastkan.py        # Quantized FastKAN
│   ├── quant_kagn.py           # Quantized KAGN
│   └── quant_pykan.py          # Quantized PyKAN
│
├── ptq/                        # PTQ Methods
│   ├── uniform.py              # Uniform quantization
│   ├── gptq.py                 # GPTQ implementation
│   ├── gptq_strict.py          # Strict GPTQ variant
│   ├── adaround.py             # AdaRound implementation
│   ├── awq.py                  # AWQ implementation
│   ├── hawq_v2.py              # HAWQ-v2 implementation
│   ├── brecq.py                # BRECQ implementation
│   ├── smoothquant.py          # SmoothQuant implementation
│   ├── zeroq.py                # ZeroQ implementation
│   └── actquant.py             # Activation quantization utilities
```
## Installation

```bash
# Core dependencies
pip install torch torchvision
pip install pyyaml munch scikit-learn tensorboard matplotlib tqdm pandas einops huggingface_hub

# Clone the repository
git clone <repository-url>
cd QuantKAN
```

## Quick Start

```bash
# Train a KAN model on CIFAR-10 with 4-bit quantization
python main.py configs/cifar10_simplekagn.yaml

# Run PTQ on a pretrained model
python runner.py ptq \
--method gptq \
--gptq_impl block \
--gptq_mode block \
--block_size 128 \
--config configs/mnist_eff_fc.yaml \
--ckpt out/MNIST_KAN_EFF_FC/MNIST_KAN_EFF_FC_best.pth.tar \
--output_ckpt runs/kan_unified/kan_gptq_block_block_w4a32.pt \
--nsamples 2048 \
--damping 1e-4

# Evaluate a quantized model
python main.py configs/cifar10_simplekagn.yaml --eval --resume.path path/to/checkpoint.pth.tar
```

## Configuration

Configuration is managed through YAML files. Create a custom config by copying and modifying an existing template.

```yaml
# Experiment name
name: CIFAR10_KAGN_W4A4

# Dataset configuration
dataloader:
  dataset: cifar10        # mnist, cifar10, cifar100, tinyimagenet, imagenet
  num_classes: 10
  path: datasets
  batch_size: 128
  val_split: 0.0

# Model architecture
arch: kagn_simple_cifar10

# Quantization settings
quan:
  quantization: true
  act:
    mode: lsq             # lsq, dorefa, pact, qil, lsq_plus, dsq
    bit: 4
    per_channel: false
    symmetric: false
    all_positive: true
  weight:
    mode: lsq
    bit: 4
    per_channel: false
    symmetric: true

# Training settings
epochs: 250
optimizer:
  name: adamw
  learning_rate: 0.0001
  weight_decay: 0.00001
lr_scheduler:
  name: exp
  gamma: 0.975

# PTQ settings (for runner.py)
ptq:
  w_bit: 4
  a_bit_default: 8
  calib_batches: 32
  per_channel: true
```

## Supported Models

| Model | Description | Dataset |
|---|---|---|
| `kan_mlp_mnist` | Simple KAN MLP | MNIST |
| `kan_mlp_mnist_fastkan` | FastKAN MLP | MNIST |
| `kagn_simple_cifar10` | KAGN for CIFAR-10 | CIFAR-10 |
| `kagn_v2` | VGG-KAGN v2 | CIFAR-100, ImageNet |
| `kagn_v4` | VGG-KAGN v4 | ImageNet |
Supported KAN layer types (an illustrative layer sketch follows this list):

- `KANLinear`: B-spline based KAN layer (EfficientKAN)
- `FastKANLayer`: RBF-based fast KAN layer
- `GRAMLayer`: Gram polynomial based KAN layer
- `KAGNConv2DLayer`: KAN convolutional layer
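
To make the construction concrete, here is a hedged, self-contained sketch of an RBF-based layer in the spirit of `FastKANLayer`: each input feature is expanded onto a fixed grid of Gaussian radial basis functions, and the expansions are mixed by a linear map. The class name, grid settings, and shapes below are illustrative assumptions, not the implementation in `kans/fastkan.py`.

```python
import torch
import torch.nn as nn

class RBFKANLayerSketch(nn.Module):
    """Illustrative RBF-based KAN layer (FastKAN-style); not QuantKAN's implementation."""
    def __init__(self, in_features: int, out_features: int,
                 num_grids: int = 8, grid_min: float = -2.0, grid_max: float = 2.0):
        super().__init__()
        # Fixed RBF centers spread over [grid_min, grid_max], shared by all input features.
        self.register_buffer("grid", torch.linspace(grid_min, grid_max, num_grids))
        self.inv_denom = num_grids / (grid_max - grid_min)   # controls RBF width
        # Linear mixing of the per-feature RBF expansion into output features.
        self.linear = nn.Linear(in_features * num_grids, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_features) -> RBF expansion: (batch, in_features, num_grids)
        phi = torch.exp(-((x[..., None] - self.grid) * self.inv_denom) ** 2)
        return self.linear(phi.flatten(start_dim=1))

# Quick shape check
layer = RBFKANLayerSketch(in_features=16, out_features=10)
print(layer(torch.randn(4, 16)).shape)   # torch.Size([4, 10])
```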
## Quantization Methods

### QAT Methods

| Method | Description |
|---|---|
| LSQ | Learned Step Size Quantization |
| LSQ+ | Enhanced LSQ with asymmetric quantization |
| PACT | Parameterized Clipping Activation |
| DoReFa | Low-bitwidth weight and activation quantization (DoReFa-Net) |
| DSQ | Differentiable Soft Quantization |
| QIL | Quantization Interval Learning |
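
As a flavor of how these QAT quantizers work, the sketch below implements a stripped-down LSQ-style weight quantizer: a learnable step size, clipping to the integer range, rounding with a straight-through estimator, and the gradient scaling from the LSQ paper. It is a simplified stand-in, not the quantizer in `qat/quantizers/lsq.py`.

```python
import torch
import torch.nn as nn

class LSQWeightQuantizerSketch(nn.Module):
    """Simplified LSQ-style symmetric weight quantizer (illustrative only)."""
    def __init__(self, bit: int = 4):
        super().__init__()
        self.qn, self.qp = -(2 ** (bit - 1)), 2 ** (bit - 1) - 1
        self.step = nn.Parameter(torch.tensor(1.0))   # learned step size s

    def init_step(self, w: torch.Tensor) -> None:
        # Common LSQ initialization: 2 * mean(|w|) / sqrt(qp)
        self.step.data = 2 * w.abs().mean() / (self.qp ** 0.5)

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        # Gradient scale keeps the step-size gradient well conditioned (LSQ paper).
        g = 1.0 / ((w.numel() * self.qp) ** 0.5)
        s = self.step * g + (self.step - self.step * g).detach()   # grad-scaled step
        q = torch.clamp(w / s, self.qn, self.qp)
        q = q + (q.round() - q).detach()   # straight-through estimator for round()
        return q * s                       # fake-quantized weights used in the forward pass

quant = LSQWeightQuantizerSketch(bit=4)
w = torch.randn(64, 32)
quant.init_step(w)
w_q = quant(w)   # substituted for w during QAT forward passes
```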
### PTQ Methods

| Method | Description |
|---|---|
| Uniform | Min-max uniform quantization |
| GPTQ | Hessian-based error compensation, building on Optimal Brain Quantization |
| AdaRound | Adaptive rounding optimization |
| AWQ | Activation-aware weight quantization |
| HAWQ-v2 | Hessian-aware mixed precision |
| BRECQ | Block reconstruction quantization |
| SmoothQuant | Migrates quantization difficulty from activations to weights via per-channel scaling |
| ZeroQ | Zero-shot quantization with synthetic (distilled) calibration data |
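
For orientation, the baseline that the other PTQ methods refine (with calibration data, Hessian information, or reconstruction objectives) is plain min-max uniform quantization. The function below is an illustrative per-tensor sketch, not the code in `ptq/uniform.py`.

```python
import torch

def uniform_quantize_sketch(w: torch.Tensor, bit: int = 4):
    """Per-tensor asymmetric min-max quantization (illustrative baseline)."""
    qmin, qmax = 0, 2 ** bit - 1
    w_min, w_max = w.min(), w.max()
    scale = (w_max - w_min).clamp(min=1e-8) / (qmax - qmin)
    zero_point = torch.round(-w_min / scale).clamp(qmin, qmax)
    q = torch.round(w / scale + zero_point).clamp(qmin, qmax)   # integer codes
    w_dq = (q - zero_point) * scale                             # dequantized weights
    return q.to(torch.uint8), w_dq, scale, zero_point

w = torch.randn(128, 64)
q, w_dq, scale, zp = uniform_quantize_sketch(w, bit=4)
print(f"max abs error: {(w - w_dq).abs().max().item():.4f}")
```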
Please use the config files in the `configs` folder. For the PTQ approaches, please refer to the README file in the `ptq` directory.
## Usage Examples

Train a model:

```bash
python main.py configs/cifar10_simplekagn.yaml
```

Resume training from a checkpoint:

```yaml
# In the config file
resume:
  path: out/CIFAR10_KAGN_W4A4/checkpoint.pth.tar
  lean: false   # Full resume (optimizer, scheduler, etc.)
```

```bash
python main.py configs/cifar10_simplekagn.yaml
```

You can evaluate a checkpoint by setting `eval` to `true` in the config file, or directly from the command line:

```bash
python main.py configs/cifar10_simplekagn.yaml \
--eval \
--resume.path out/CIFAR10_KAGN_W4A4/best.pth.tar
```

Load pretrained weights from the Hugging Face Hub:

```yaml
pretrained:
  load_from_hf: true
```

Override the bit width for specific layers via `excepts`:

```yaml
quan:
  excepts:
    conv1:
      weight:
        bit: 8   # First layer at 8-bit
    fc:
      weight:
        bit: 8   # Last layer at 8-bit
```

## Monitoring and Logging

Training logs are saved to `out/<experiment_name>/tb_runs/`. View with:

```bash
tensorboard --logdir out/<experiment_name>/tb_runs/
```

Logged metrics (a minimal logging sketch follows this list):

- Training/validation loss
- Top-1 and Top-5 accuracy
- Learning rate
- Quantizer statistics (scale, clip ratios, MAE, MSE)
- Gradient statistics
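
These scalars are ordinary TensorBoard summaries. As a generic illustration (not QuantKAN's actual logging code, and with made-up values), they could be written with `torch.utils.tensorboard` like this:

```python
from torch.utils.tensorboard import SummaryWriter

# Hypothetical values; in QuantKAN these come from the training/validation loops.
writer = SummaryWriter(log_dir="out/CIFAR10_KAGN_W4A4/tb_runs")
for epoch in range(3):
    writer.add_scalar("train/loss", 1.0 / (epoch + 1), epoch)
    writer.add_scalar("val/top1", 50.0 + 10 * epoch, epoch)
    writer.add_scalar("lr", 1e-4 * (0.975 ** epoch), epoch)
    writer.add_scalar("quant/weight_scale", 0.05, epoch)   # example quantizer statistic
writer.close()
```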
Checkpoints are saved to `out/<experiment_name>/`:

- `<name>_best.pth.tar`: Best validation accuracy
- `<name>_checkpoint.pth.tar`: Latest checkpoint
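
Checkpoints are standard `torch.save` archives; the exact keys are framework-specific. A hedged sketch for inspecting one (the filename follows the `<name>_best.pth.tar` pattern above):

```python
import torch

# Illustrative only: adjust the path and keys to what QuantKAN actually saves
# (e.g. model/optimizer state dicts, epoch, metrics).
ckpt = torch.load("out/CIFAR10_KAGN_W4A4/CIFAR10_KAGN_W4A4_best.pth.tar",
                  map_location="cpu", weights_only=False)
print(type(ckpt), list(ckpt.keys()) if isinstance(ckpt, dict) else None)
```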
Additional training features are configured in the config file.

Mixed precision (AMP):

```yaml
amp:
  enable: true
  dtype: fp16
  grad_scaler: true
```

Early stopping:

```yaml
early_stopping:
  enable: true
  monitor: val_top1
  mode: max
  patience: 15
  min_delta: 0.01
```

Gradient clipping and non-finite gradient handling:

```yaml
optimizer:
  grad_clip_norm: 1.0
  skip_nonfinite_grads: true
```

## License

Please refer to the repository license file for licensing information.

## Acknowledgments
- Initial QAT code structure adapted from lsq-net
- KAN implementations inspired by efficient-kan and pykan
- KAGN implementation from TorchKAN