OSU-STARLAB/QuantKAN
QuantKAN: A Unified Quantization Framework for Kolmogorov–Arnold Networks

A comprehensive framework for training and quantizing Kolmogorov-Arnold Networks (KANs) with support for both Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).


Overview

QuantKAN provides a unified framework for quantizing Kolmogorov-Arnold Networks, enabling efficient deployment on resource-constrained hardware. The framework supports multiple KAN variants and a wide range of quantization techniques for both weights and activations.


Features

  • Multiple KAN Implementations:

    • EfficientKAN (B-spline based)
    • FastKAN (RBF-based)
    • PyKAN (Original KAN implementation)
    • KAGN (Gram-based KAN with convolutions)
  • Quantization-Aware Training (QAT):

    • LSQ (Learned Step Size Quantization)
    • LSQ+ (Enhanced LSQ)
    • PACT (Parameterized Clipping Activation)
    • DoReFa
    • DSQ (Differentiable Soft Quantization)
    • QIL (Quantization Interval Learning)
  • Post-Training Quantization (PTQ):

    • Uniform Quantization
    • GPTQ (Hessian-based, building on Optimal Brain Quantization)
    • AdaRound (Adaptive Rounding)
    • AWQ (Activation-aware Weight Quantization)
    • HAWQ-v2 (Hessian-Aware Quantization)
    • BRECQ (Block Reconstruction Quantization)
    • SmoothQuant
    • ZeroQ (Zero-shot Quantization)
  • Training Features:

    • Mixed precision training (AMP)
    • Gradient accumulation
    • Early stopping
    • EMA (Exponential Moving Average)
    • Multiple learning rate schedulers
    • TensorBoard logging
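The EMA feature keeps a shadow copy of each parameter updated as `shadow = decay * shadow + (1 - decay) * param`; the shadow weights are then typically used for evaluation. A minimal pure-Python sketch of that update rule (illustrative only; the repo's actual EMA helper and its names may differ):

```python
class EMASketch:
    """Exponential moving average of parameters:
    shadow = decay * shadow + (1 - decay) * param.
    Illustrative sketch; not the repo's EMA implementation."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = dict(params)  # copy of the initial parameter values

    def update(self, params):
        d = self.decay
        for name, value in params.items():
            self.shadow[name] = d * self.shadow[name] + (1 - d) * value

ema = EMASketch({"w": 0.0}, decay=0.9)
for step_value in [1.0, 1.0, 1.0]:
    ema.update({"w": step_value})
print(ema.shadow["w"])  # rises toward 1.0: 0.1, 0.19, 0.271, ...
```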

Project Structure

├── main.py                 # Main training entry point (QAT)
├── runner.py               # PTQ runner for post-training quantization
├── ptq_eval.py             # PTQ evaluation script
├── process.py              # Training/validation loops
├── config.yaml             # Default configuration
├── logging.conf            # Logging configuration
│
├── configs/                # Experiment Configurations
│   ├── mnist_*.yaml            # MNIST experiment configs
│   ├── cifar10_*.yaml          # CIFAR-10 experiment configs
│   ├── cifar100_*.yaml         # CIFAR-100 experiment configs
│   ├── tinyimagenet_*.yaml     # TinyImageNet experiment configs
│   └── imagenet_*.yaml         # ImageNet experiment configs
│
├── datasets/               # Dataset storage directory
│
├── models/                 # Model Definitions
│   ├── model.py                # Model factory
│   ├── kan_models.py           # KAN model architectures
│   ├── vgg_kan_cifar.py        # VGG-KAN for CIFAR
│   └── vgg_kan_imagenet.py     # VGG-KAN for ImageNet
│
├── kans/                   # KAN Layer Implementations
│   ├── efficient_kan.py        # EfficientKAN layers
│   ├── fastkan.py              # FastKAN layers
│   ├── KANLayer.py             # PyKAN layers
│   ├── kagn_kagn_conv.py       # KAGN convolutional layers
│   └── conv_kagn.py            # KAN convolution implementations
│
├── qat/                    # QAT Quantizers
│   ├── quantizers/
│   │   ├── quantizer.py            # Base quantizer class
│   │   ├── lsq.py                  # Learned Step Size Quantization
│   │   ├── lsq_plus.py             # LSQ+ implementation
│   │   ├── pact.py                 # PACT quantizer
│   │   ├── dorefa.py               # DoReFa quantizer
│   │   ├── dsq.py                  # DSQ quantizer
│   │   └── qil.py                  # QIL quantizer
│   │
│   ├── quant_nn.py             # Quantized standard NN layers
│   ├── quant_efficient_kan.py  # Quantized EfficientKAN
│   ├── quant_fastkan.py        # Quantized FastKAN
│   ├── quant_kagn.py           # Quantized KAGN
│   └── quant_pykan.py          # Quantized PyKAN
│
├── ptq/                    # PTQ Methods
│   ├── uniform.py              # Uniform quantization
│   ├── gptq.py                 # GPTQ implementation
│   ├── gptq_strict.py          # Strict GPTQ variant
│   ├── adaround.py             # AdaRound implementation
│   ├── awq.py                  # AWQ implementation
│   ├── hawq_v2.py              # HAWQ-v2 implementation
│   ├── brecq.py                # BRECQ implementation
│   ├── smoothquant.py          # SmoothQuant implementation
│   ├── zeroq.py                # ZeroQ implementation
│   └── actquant.py             # Activation quantization utilities


Installation

Requirements

# Core dependencies
pip install torch torchvision
pip install pyyaml munch scikit-learn tensorboard matplotlib tqdm pandas einops huggingface_hub

Clone and Setup

git clone <repository-url>
cd QuantKAN

Quick Start

Training with QAT

# Train a KAN model on CIFAR-10 with 4-bit quantization
python main.py configs/cifar10_simplekagn.yaml

Post-Training Quantization (check ptq/README_PTQ.md)

# Run PTQ on a pretrained model
python runner.py ptq \
  --method gptq \
  --gptq_impl block \
  --gptq_mode block \
  --block_size 128 \
  --config configs/mnist_eff_fc.yaml \
  --ckpt out/MNIST_KAN_EFF_FC/MNIST_KAN_EFF_FC_best.pth.tar \
  --output_ckpt runs/kan_unified/kan_gptq_block_block_w4a32.pt \
  --nsamples 2048 \
  --damping 1e-4

Evaluation

# Evaluate a quantized model
python main.py configs/cifar10_simplekagn.yaml --eval --resume.path path/to/checkpoint.pth.tar

Configuration

Configuration is managed through YAML files. Create a custom config by copying and modifying an existing template.

Key Configuration Sections

# Experiment name
name: CIFAR10_KAGN_W4A4

# Dataset configuration
dataloader:
  dataset: cifar10          # mnist, cifar10, cifar100, tinyimagenet, imagenet
  num_classes: 10
  path: datasets
  batch_size: 128
  val_split: 0.0

# Model architecture
arch: kagn_simple_cifar10

# Quantization settings
quan:
  quantization: true
  act:
    mode: lsq               # lsq, dorefa, pact, qil, lsq_plus, dsq
    bit: 4
    per_channel: false
    symmetric: false
    all_positive: true
  weight:
    mode: lsq
    bit: 4
    per_channel: false
    symmetric: true

# Training settings
epochs: 250
optimizer:
  name: adamw
  learning_rate: 0.0001
  weight_decay: 0.00001

lr_scheduler:
  name: exp
  gamma: 0.975

# PTQ settings (for runner.py)
ptq:
  w_bit: 4
  a_bit_default: 8
  calib_batches: 32
  per_channel: true

Supported Models

KAN Architectures

Model                  Description        Dataset
kan_mlp_mnist          Simple KAN MLP     MNIST
kan_mlp_mnist_fastkan  FastKAN MLP        MNIST
kagn_simple_cifar10    KAGN for CIFAR-10  CIFAR-10
kagn_v2                VGG-KAGN v2        CIFAR-100, ImageNet
kagn_v4                VGG-KAGN v4        ImageNet

Layer Types

  • KANLinear: B-spline based KAN layer (EfficientKAN)
  • FastKANLayer: RBF-based fast KAN layer
  • GRAMLayer: Gram polynomial based KAN layer
  • KAGNConv2DLayer: KAN convolutional layer
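To make the layer types concrete: a FastKAN-style layer replaces B-spline bases with Gaussian radial basis functions on a fixed grid, then mixes the basis responses linearly. The NumPy sketch below illustrates the idea only; the class name, grid settings, and initialization are assumptions, not the repo's FastKANLayer API:

```python
import numpy as np

def rbf_basis(x, centers, h):
    """Gaussian RBF expansion: phi_i(x) = exp(-((x - c_i) / h)^2).

    x: (batch, in_features); centers: (num_grids,)
    returns: (batch, in_features, num_grids)
    """
    return np.exp(-(((x[..., None] - centers) / h) ** 2))

class FastKANLayerSketch:
    """Minimal FastKAN-style layer: RBF expansion followed by a linear mix."""

    def __init__(self, in_features, out_features, num_grids=8, grid_range=(-2.0, 2.0)):
        self.centers = np.linspace(grid_range[0], grid_range[1], num_grids)
        self.h = (grid_range[1] - grid_range[0]) / (num_grids - 1)  # RBF width
        rng = np.random.default_rng(0)
        # One learnable weight per (output, input, basis function)
        self.w = rng.standard_normal((out_features, in_features, num_grids)) * 0.1

    def __call__(self, x):
        phi = rbf_basis(x, self.centers, self.h)      # (B, in, G)
        return np.einsum("big,oig->bo", phi, self.w)  # (B, out)

layer = FastKANLayerSketch(4, 3)
y = layer(np.random.default_rng(1).standard_normal((2, 4)))
print(y.shape)  # (2, 3)
```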

Quantization Methods

QAT Methods

Method   Description
LSQ      Learned Step Size Quantization
LSQ+     Enhanced LSQ with asymmetric quantization
PACT     Parameterized Clipping Activation
DoReFa   Low-bitwidth weight, activation, and gradient quantization
DSQ      Differentiable Soft Quantization
QIL      Quantization Interval Learning
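The QAT methods above share the same forward-pass skeleton: scale, round, clamp, rescale, with a straight-through estimator for gradients. In LSQ specifically, the step size s is itself a learnable parameter. A minimal NumPy sketch of the LSQ forward computation (gradient machinery omitted; the function name and defaults are illustrative, not the repo's API):

```python
import numpy as np

def lsq_forward(w, s, bits=4, symmetric=True):
    """LSQ forward pass: q = clamp(round(w / s), Qn, Qp) * s.

    During training, s is a learnable parameter updated through a
    straight-through estimator; only the forward math is shown here.
    """
    if symmetric:
        qn, qp = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1  # e.g. [-8, 7] at 4 bits
    else:
        qn, qp = 0, 2 ** bits - 1                          # e.g. [0, 15] at 4 bits
    return np.clip(np.round(w / s), qn, qp) * s

w = np.array([-1.0, -0.3, 0.05, 0.4, 1.2])
# Values snap to multiples of s = 0.1 and clip to [-0.8, 0.7] at 4-bit symmetric
print(lsq_forward(w, s=0.1))
```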

PTQ Methods

Method       Description
Uniform      Min-max uniform quantization
GPTQ         Hessian-based weight quantization (builds on Optimal Brain Quantization)
AdaRound     Adaptive rounding optimization
AWQ          Activation-aware weight quantization
HAWQ-v2      Hessian-aware mixed precision
BRECQ        Block reconstruction quantization
SmoothQuant  Migrates quantization difficulty from activations to weights
ZeroQ        Zero-shot quantization via synthetic calibration data
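The simplest entry in the table, uniform min-max quantization, derives a scale and zero-point from the observed range of a tensor and fake-quantizes it. A minimal NumPy sketch (per-tensor, asymmetric; illustrative only, not the implementation in ptq/uniform.py):

```python
import numpy as np

def uniform_quantize(x, bits=8):
    """Min-max uniform (asymmetric, per-tensor) fake quantization.

    scale = (max - min) / (2^bits - 1); the zero-point shifts the integer
    grid so that the real value 0 maps onto an integer level.
    """
    qmax = 2 ** bits - 1
    lo, hi = float(x.min()), float(x.max())
    scale = max(hi - lo, 1e-8) / qmax
    zero_point = int(np.round(-lo / scale))
    q = np.clip(np.round(x / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale  # dequantized ("fake-quantized") tensor

x = np.linspace(0.0, 1.0, 5)
print(uniform_quantize(x, bits=2))  # 2-bit: only 4 representable levels
```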

Usage Examples

Please use the config files in the configs folder. For PTQ approaches, refer to the README in the ptq directory (ptq/README_PTQ.md).

Example 1: Train KAGN on CIFAR-10 with LSQ 4-bit

python main.py configs/cifar10_simplekagn.yaml

Example 2: Resume Training from Checkpoint

# In the config file
resume:
  path: out/CIFAR10_KAGN_W4A4/checkpoint.pth.tar
  lean: false  # Full resume (optimizer, scheduler, etc.)
python main.py configs/cifar10_simplekagn.yaml

Example 3: Evaluate Only

You can evaluate a checkpoint either by setting eval to true in the config file or by passing the --eval flag:

python main.py configs/cifar10_simplekagn.yaml \
  --eval \
  --resume.path out/CIFAR10_KAGN_W4A4/best.pth.tar

Example 4: Load Pretrained HuggingFace Weights

pretrained:
  load_from_hf: true

Example 5: Custom Bit-Width per Layer

quan:
  excepts:
    conv1:
      weight:
        bit: 8  # First layer at 8-bit
    fc:
      weight:
        bit: 8  # Last layer at 8-bit

Monitoring and Logging

TensorBoard

Training logs are saved to out/<experiment_name>/tb_runs/. View with:

tensorboard --logdir out/<experiment_name>/tb_runs/

Logged Metrics

  • Training/validation loss
  • Top-1 and Top-5 accuracy
  • Learning rate
  • Quantizer statistics (scale, clip ratios, MAE, MSE)
  • Gradient statistics

Checkpoints

Checkpoints are saved to out/<experiment_name>/:

  • <name>_best.pth.tar: Best validation accuracy
  • <name>_checkpoint.pth.tar: Latest checkpoint

Advanced Features

Mixed Precision Training

amp:
  enable: true
  dtype: fp16
  grad_scaler: true

Early Stopping

early_stopping:
  enable: true
  monitor: val_top1
  mode: max
  patience: 15
  min_delta: 0.01
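The fields above follow the usual early-stopping semantics: training stops once the monitored metric (val_top1) has failed to improve by at least min_delta for patience consecutive epochs. A minimal sketch of those semantics (illustrative; not the repo's class):

```python
class EarlyStopping:
    """Stop when a monitored metric fails to improve by min_delta
    for `patience` consecutive checks (mode='max': higher is better)."""

    def __init__(self, patience=15, min_delta=0.01, mode="max"):
        self.patience, self.min_delta, self.mode = patience, min_delta, mode
        self.best = None
        self.bad_epochs = 0

    def step(self, value):
        """Record one epoch's metric; returns True when training should stop."""
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best + self.min_delta)
            or (self.mode == "min" and value < self.best - self.min_delta)
        )
        if improved:
            self.best, self.bad_epochs = value, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=3, min_delta=0.01)
for epoch, top1 in enumerate([70.0, 71.0, 71.0, 71.0, 71.0]):
    if stopper.step(top1):
        print(f"early stop at epoch {epoch}")  # fires after 3 flat epochs
        break
```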

Gradient Clipping

optimizer:
  grad_clip_norm: 1.0
  skip_nonfinite_grads: true

License

Please refer to the repository license file for licensing information.


Acknowledgments
