QuantKAN: A Unified Quantization Framework for Kolmogorov–Arnold Networks

A comprehensive framework for training and quantizing Kolmogorov-Arnold Networks (KANs) with support for both Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).


Overview

QuantKAN provides a unified framework for quantizing Kolmogorov-Arnold Networks, enabling efficient deployment on resource-constrained hardware. The framework supports multiple KAN variants and a wide range of quantization techniques for both weights and activations.


Features

  • Multiple KAN Implementations:

    • EfficientKAN (B-spline based)
    • FastKAN (RBF-based)
    • PyKAN (Original KAN implementation)
    • KAGN (Gram-based KAN with convolutions)
  • Quantization-Aware Training (QAT):

    • LSQ (Learned Step Size Quantization)
    • LSQ+ (Enhanced LSQ)
    • PACT (Parameterized Clipping Activation)
    • DoReFa
    • DSQ (Differentiable Soft Quantization)
    • QIL (Quantization Interval Learning)
  • Post-Training Quantization (PTQ):

    • Uniform Quantization
    • GPTQ (Hessian-based quantization, derived from Optimal Brain Quantization)
    • AdaRound (Adaptive Rounding)
    • AWQ (Activation-aware Weight Quantization)
    • HAWQ-v2 (Hessian-Aware Quantization)
    • BRECQ (Block Reconstruction Quantization)
    • SmoothQuant
    • ZeroQ (Zero-shot Quantization)
  • Training Features:

    • Mixed precision training (AMP)
    • Gradient accumulation
    • Early stopping
    • EMA (Exponential Moving Average)
    • Multiple learning rate schedulers
    • TensorBoard logging
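All of the QAT methods above share the same fake-quantization backbone: tensors are rounded to a low-bit grid in the forward pass while gradients bypass the non-differentiable rounding via a straight-through estimator (STE). A minimal PyTorch sketch of that shared mechanism (illustrative, not the repo's implementation):

```python
import torch

def fake_quant_ste(x, scale, n_bits=4, symmetric=True):
    # Uniform fake quantization: round to a low-bit integer grid in the
    # forward pass, pass gradients through unchanged (straight-through).
    if symmetric:
        qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    else:
        qmin, qmax = 0, 2 ** n_bits - 1
    x_q = torch.clamp(torch.round(x / scale), qmin, qmax) * scale
    return x + (x_q - x).detach()  # forward: x_q, backward: identity

w = torch.randn(8, requires_grad=True)
w_q = fake_quant_ste(w, scale=torch.tensor(0.1))
w_q.sum().backward()  # w.grad is all ones thanks to the STE
```

The individual QAT methods differ mainly in how `scale` (and any clipping bounds) are chosen or learned.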

Project Structure

├── main.py                 # Main training entry point (QAT)
├── runner.py               # PTQ runner for post-training quantization
├── ptq_eval.py             # PTQ evaluation script
├── process.py              # Training/validation loops
├── config.yaml             # Default configuration
├── logging.conf            # Logging configuration
│
├── configs/                # Experiment Configurations
│   ├── mnist_*.yaml            # MNIST experiment configs
│   ├── cifar10_*.yaml          # CIFAR-10 experiment configs
│   ├── cifar100_*.yaml         # CIFAR-100 experiment configs
│   ├── tinyimagenet_*.yaml     # TinyImageNet experiment configs
│   └── imagenet_*.yaml         # ImageNet experiment configs
│
├── datasets/               # Dataset storage directory
│
├── models/                 # Model Definitions
│   ├── model.py                # Model factory
│   ├── kan_models.py           # KAN model architectures
│   ├── vgg_kan_cifar.py        # VGG-KAN for CIFAR
│   └── vgg_kan_imagenet.py     # VGG-KAN for ImageNet
│
├── kans/                   # KAN Layer Implementations
│   ├── efficient_kan.py        # EfficientKAN layers
│   ├── fastkan.py              # FastKAN layers
│   ├── KANLayer.py             # PyKAN layers
│   ├── kagn_kagn_conv.py       # KAGN convolutional layers
│   └── conv_kagn.py            # KAN convolution implementations
│
├── qat/                    # QAT Quantizers
│   ├── quantizers/
│   │   ├── quantizer.py        # Base quantizer class
│   │   ├── lsq.py              # Learned Step Size Quantization
│   │   ├── lsq_plus.py         # LSQ+ implementation
│   │   ├── pact.py             # PACT quantizer
│   │   ├── dorefa.py           # DoReFa quantizer
│   │   ├── dsq.py              # DSQ quantizer
│   │   └── qil.py              # QIL quantizer
│   ├── quant_nn.py             # Quantized standard NN layers
│   ├── quant_efficient_kan.py  # Quantized EfficientKAN
│   ├── quant_fastkan.py        # Quantized FastKAN
│   ├── quant_kagn.py           # Quantized KAGN
│   └── quant_pykan.py          # Quantized PyKAN
│
├── ptq/                    # PTQ Methods
│   ├── uniform.py              # Uniform quantization
│   ├── gptq.py                 # GPTQ implementation
│   ├── gptq_strict.py          # Strict GPTQ variant
│   ├── adaround.py             # AdaRound implementation
│   ├── awq.py                  # AWQ implementation
│   ├── hawq_v2.py              # HAWQ-v2 implementation
│   ├── brecq.py                # BRECQ implementation
│   ├── smoothquant.py          # SmoothQuant implementation
│   ├── zeroq.py                # ZeroQ implementation
│   └── actquant.py             # Activation quantization utilities


Installation

Requirements

# Core dependencies
pip install torch torchvision
pip install pyyaml munch scikit-learn tensorboard matplotlib tqdm pandas einops huggingface_hub

Clone and Setup

git clone <repository-url>
cd QuantKAN

Quick Start

Training with QAT

# Train a KAN model on CIFAR-10 with 4-bit quantization
python main.py configs/cifar10_simplekagn.yaml

Post-Training Quantization (see ptq/README_PTQ.md)

# Run PTQ on a pretrained model
python runner.py ptq \
  --method gptq \
  --gptq_impl block \
  --gptq_mode block \
  --block_size 128 \
  --config configs/mnist_eff_fc.yaml \
  --ckpt out/MNIST_KAN_EFF_FC/MNIST_KAN_EFF_FC_best.pth.tar \
  --output_ckpt runs/kan_unified/kan_gptq_block_block_w4a32.pt \
  --nsamples 2048 \
  --damping 1e-4

Evaluation

# Evaluate a quantized model
python main.py configs/cifar10_simplekagn.yaml --eval --resume.path path/to/checkpoint.pth.tar

Configuration

Configuration is managed through YAML files. Create a custom config by copying and modifying an existing template.

Key Configuration Sections

# Experiment name
name: CIFAR10_KAGN_W4A4

# Dataset configuration
dataloader:
  dataset: cifar10          # mnist, cifar10, cifar100, tinyimagenet, imagenet
  num_classes: 10
  path: datasets
  batch_size: 128
  val_split: 0.0

# Model architecture
arch: kagn_simple_cifar10

# Quantization settings
quan:
  quantization: true
  act:
    mode: lsq               # lsq, dorefa, pact, qil, lsq_plus, dsq
    bit: 4
    per_channel: false
    symmetric: false
    all_positive: true
  weight:
    mode: lsq
    bit: 4
    per_channel: false
    symmetric: true

# Training settings
epochs: 250
optimizer:
  name: adamw
  learning_rate: 0.0001
  weight_decay: 0.00001

lr_scheduler:
  name: exp
  gamma: 0.975

# PTQ settings (for runner.py)
ptq:
  w_bit: 4
  a_bit_default: 8
  calib_batches: 32
  per_channel: true
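Since PyYAML is listed among the install requirements, a config fragment like the one above can be sanity-checked with a short snippet (a sketch; the framework's actual config loader may wrap this further, e.g. with munch):

```python
import yaml

# Parse a fragment of the config shown above into a plain dict.
cfg_text = """
name: CIFAR10_KAGN_W4A4
quan:
  quantization: true
  weight:
    mode: lsq
    bit: 4
"""
cfg = yaml.safe_load(cfg_text)
print(cfg["quan"]["weight"]["bit"])  # prints 4
```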

Supported Models

KAN Architectures

Model                   Description         Dataset
kan_mlp_mnist           Simple KAN MLP      MNIST
kan_mlp_mnist_fastkan   FastKAN MLP         MNIST
kagn_simple_cifar10     KAGN for CIFAR-10   CIFAR-10
kagn_v2                 VGG-KAGN v2         CIFAR-100, ImageNet
kagn_v4                 VGG-KAGN v4         ImageNet

Layer Types

  • KANLinear: B-spline based KAN layer (EfficientKAN)
  • FastKANLayer: RBF-based fast KAN layer
  • GRAMLayer: Gram polynomial based KAN layer
  • KAGNConv2DLayer: KAN convolutional layer
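To make the layer types concrete, here is an illustrative RBF expansion in the spirit of FastKANLayer (a simplified sketch, not the repo's code): each input feature is projected onto Gaussian bases centered on a fixed grid, and a linear layer mixes the basis responses.

```python
import torch
import torch.nn as nn

class RBFLayer(nn.Module):
    # Illustrative RBF-based KAN-style layer: learnable univariate functions
    # are represented as linear combinations of fixed Gaussian bases.
    def __init__(self, in_features, out_features, num_grids=8, grid_range=(-2.0, 2.0)):
        super().__init__()
        self.register_buffer("grid", torch.linspace(grid_range[0], grid_range[1], num_grids))
        self.width = (grid_range[1] - grid_range[0]) / (num_grids - 1)
        self.mix = nn.Linear(in_features * num_grids, out_features, bias=False)

    def forward(self, x):
        # (batch, in) -> (batch, in, num_grids) -> (batch, in * num_grids)
        basis = torch.exp(-((x[..., None] - self.grid) / self.width) ** 2)
        return self.mix(basis.flatten(start_dim=-2))

layer = RBFLayer(4, 3)
y = layer(torch.randn(2, 4))  # shape (2, 3)
```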

Quantization Methods

QAT Methods

Method   Description
LSQ      Learned Step Size Quantization
LSQ+     Enhanced LSQ with asymmetric quantization
PACT     Parameterized Clipping Activation
DoReFa   Low-bitwidth weight, activation, and gradient quantization
DSQ      Differentiable Soft Quantization
QIL      Quantization Interval Learning
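As an example of how these methods differ, LSQ treats the quantizer step size as a learnable parameter trained jointly with the weights, scaling its gradient by 1/sqrt(N * q_p). A minimal per-tensor, symmetric sketch (illustrative; the repo's qat/quantizers/lsq.py is the authoritative version):

```python
import math
import torch

def grad_scale(x, s):
    # forward: x; backward: gradient multiplied by s
    return (x - x * s).detach() + x * s

def round_pass(x):
    # straight-through rounding
    return (x.round() - x).detach() + x

def lsq_quantize(w, step, n_bits=4):
    # Symmetric per-tensor LSQ: clamp to the signed grid, round with an STE,
    # and scale the step-size gradient by 1 / sqrt(numel * q_p) as in the paper.
    qn, qp = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    s = grad_scale(step, 1.0 / math.sqrt(w.numel() * qp))
    return round_pass(torch.clamp(w / s, qn, qp)) * s

w = torch.randn(64)
# Paper's initialization: 2 * mean(|w|) / sqrt(q_p)
step = torch.tensor(2 * w.abs().mean().item() / math.sqrt(7), requires_grad=True)
w_q = lsq_quantize(w, step)
```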

PTQ Methods

Method       Description
Uniform      Min-max uniform quantization
GPTQ         Hessian-based layer-wise quantization (derived from Optimal Brain Quantization)
AdaRound     Adaptive rounding optimization
AWQ          Activation-aware weight quantization
HAWQ-v2      Hessian-aware mixed precision
BRECQ        Block reconstruction quantization
SmoothQuant  Migrates quantization difficulty from activations to weights
ZeroQ        Zero-shot quantization via synthetic calibration data
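The simplest entry in this table, min-max uniform quantization, needs no calibration beyond observing the tensor's range. A sketch of asymmetric, per-output-channel weight quantization (illustrative; see ptq/uniform.py for the repo's version):

```python
import torch

def uniform_ptq_weight(w, n_bits=4, per_channel=True):
    # Asymmetric min-max quantization: derive scale and zero-point from the
    # observed range (per output channel), quantize, then dequantize.
    qmax = 2 ** n_bits - 1
    if per_channel:
        dims = tuple(range(1, w.dim()))
        wmin = w.amin(dim=dims, keepdim=True)
        wmax = w.amax(dim=dims, keepdim=True)
    else:
        wmin, wmax = w.min(), w.max()
    scale = (wmax - wmin).clamp(min=1e-8) / qmax
    zero_point = torch.round(-wmin / scale)
    q = torch.clamp(torch.round(w / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale

w = torch.randn(16, 8)
w_dq = uniform_ptq_weight(w)
max_err = (w - w_dq).abs().max()  # bounded by roughly one quantization step
```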

Usage Examples

Please use the config files in the configs folder. For PTQ approaches, refer to the README in the ptq directory.

Example 1: Train KAGN on CIFAR-10 with LSQ 4-bit

python main.py configs/cifar10_simplekagn.yaml

Example 2: Resume Training from Checkpoint

# In the config file
resume:
  path: out/CIFAR10_KAGN_W4A4/checkpoint.pth.tar
  lean: false  # Full resume (optimizer, scheduler, etc.)
python main.py configs/cifar10_simplekagn.yaml

Example 3: Evaluate Only

You can evaluate a checkpoint by setting eval: true in the config file, or by passing the flags on the command line:

python main.py configs/cifar10_simplekagn.yaml \
  --eval \
  --resume.path out/CIFAR10_KAGN_W4A4/best.pth.tar

Example 4: Load Pretrained HuggingFace Weights

pretrained:
  load_from_hf: true

Example 5: Custom Bit-Width per Layer

quan:
  excepts:
    conv1:
      weight:
        bit: 8  # First layer at 8-bit
    fc:
      weight:
        bit: 8  # Last layer at 8-bit
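A hypothetical sketch of how such per-layer exceptions might resolve (the lookup below is illustrative; consult the framework's quantizer construction code for the real behavior):

```python
# The default bit-width applies unless the layer name appears under `excepts`.
default_bits = 4
excepts = {"conv1": {"weight": {"bit": 8}}, "fc": {"weight": {"bit": 8}}}

def weight_bits(layer_name):
    # Fall back through each missing level to the global default.
    return excepts.get(layer_name, {}).get("weight", {}).get("bit", default_bits)

print(weight_bits("conv1"), weight_bits("conv2"))  # prints 8 4
```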

Monitoring and Logging

TensorBoard

Training logs are saved to out/<experiment_name>/tb_runs/. View with:

tensorboard --logdir out/<experiment_name>/tb_runs/

Logged Metrics

  • Training/validation loss
  • Top-1 and Top-5 accuracy
  • Learning rate
  • Quantizer statistics (scale, clip ratios, MAE, MSE)
  • Gradient statistics
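The quantizer statistics above can be computed directly by comparing a tensor against its fake-quantized counterpart. An illustrative sketch for a 4-bit symmetric quantizer (metric definitions are assumptions, not the repo's exact code):

```python
import torch

# Compare a tensor with its fake-quantized version and derive the metrics.
w = torch.randn(1000)
scale = 0.1
w_q = torch.clamp(torch.round(w / scale), -8, 7) * scale

mae = (w - w_q).abs().mean()                              # mean absolute error
mse = ((w - w_q) ** 2).mean()                             # mean squared error
clip_ratio = ((w / scale < -8) | (w / scale > 7)).float().mean()  # fraction clipped
```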

Checkpoints

Checkpoints are saved to out/<experiment_name>/:

  • <name>_best.pth.tar: Best validation accuracy
  • <name>_checkpoint.pth.tar: Latest checkpoint

Advanced Features

Mixed Precision Training

amp:
  enable: true
  dtype: fp16
  grad_scaler: true
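These settings correspond to the standard PyTorch AMP pattern with torch.autocast and a GradScaler (a generic sketch; the repo's actual loop lives in process.py). Since fp16 autocast requires CUDA, the sketch falls back to bfloat16 on CPU:

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = torch.nn.Linear(10, 2).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler rescales fp16 losses to avoid gradient underflow; it is a
# no-op when disabled on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x, y = torch.randn(4, 10, device=device), torch.randint(0, 2, (4,), device=device)
with torch.autocast(device_type=device, dtype=dtype):
    loss = F.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
```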

Early Stopping

early_stopping:
  enable: true
  monitor: val_top1
  mode: max
  patience: 15
  min_delta: 0.01
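The rule encoded by this config: stop once the monitored metric fails to improve by at least min_delta for patience consecutive epochs. A self-contained sketch (the framework's implementation may differ in details):

```python
def should_stop(history, patience=15, min_delta=0.01, mode="max"):
    # Track the best value seen; count epochs without sufficient improvement.
    best, wait = None, 0
    for value in history:
        improved = best is None or (
            value > best + min_delta if mode == "max" else value < best - min_delta
        )
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return True
    return False

should_stop([70.0, 70.5, 70.5, 70.5], patience=2)  # True: stalled for 2 epochs
```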

Gradient Clipping

optimizer:
  grad_clip_norm: 1.0
  skip_nonfinite_grads: true
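The grad_clip_norm option maps onto PyTorch's global-norm clipping utility (a generic sketch of the mechanism, applied between backward() and the optimizer step):

```python
import torch

# grad_clip_norm: 1.0 corresponds to clipping the global gradient norm with
# torch.nn.utils.clip_grad_norm_ before the optimizer step.
p = torch.randn(5, requires_grad=True)
loss = (p ** 2).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
clipped = p.grad.norm()  # <= 1.0, up to floating-point slack
```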

License

Please refer to the repository license file for licensing information.


Acknowledgments
