This project fine-tunes the vit_base_patch16_224 Vision Transformer from timm for plant disease classification using Parameter-Efficient Fine-Tuning (PEFT). By leveraging Low-Rank Adaptation (LoRA), we drastically reduce the computational resources required for training while maintaining high accuracy on a 10-class subset of the PlantVillage dataset (Apple Scab, Apple Black Rot, Apple Cedar Rust, Apple Healthy, Corn Common Rust, Grape Black Rot, Grape Healthy, Potato Early Blight, Tomato Bacterial Spot, Tomato Healthy). The model identifies plant health states, a capability vital for precision agriculture and early disease intervention.
The headline achievement of this pipeline is its efficiency: only 0.35% of the model's parameters are trained. Instead of fine-tuning the entire network, LoRA injects trainable low-rank matrices into the qkv attention projections of all 12 transformer blocks, and only these adapters and the classifier head are updated. Because the base ViT weights stay frozen, the pretrained features are preserved rather than overwritten (avoiding catastrophic forgetting), and training fits on a single T4 GPU in Google Colab.
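To make the low-rank update concrete, here is a minimal NumPy sketch of a single LoRA-adapted linear layer. It is an illustration of the general technique, not the project's PEFT-based code: the class name `LoRALinear` is hypothetical, the frozen weight is a random stand-in, and the 768 → 2304 shape assumes the fused qkv projection of ViT-Base. Rank 8 and alpha 16 are the values this README reports.

```python
import numpy as np


class LoRALinear:
    """Frozen weight W plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, in_features, out_features, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        # Frozen pretrained weight (random stand-in for this sketch).
        self.W = rng.standard_normal((out_features, in_features)) * 0.02
        # Trainable factors: A gets a small random init, B starts at zero,
        # so the adapted layer initially matches the frozen layer exactly.
        self.A = rng.standard_normal((r, in_features)) * 0.01
        self.B = np.zeros((out_features, r))
        self.scaling = alpha / r

    def __call__(self, x):
        # x: (batch, in_features) -> (batch, out_features)
        frozen = x @ self.W.T
        update = (x @ self.A.T) @ self.B.T * self.scaling
        return frozen + update

    def trainable_params(self):
        return self.A.size + self.B.size


layer = LoRALinear(in_features=768, out_features=2304)
x = np.ones((2, 768))
out = layer(x)
print(out.shape)                 # (2, 2304)
print(layer.trainable_params())  # 24576, versus 1769472 frozen weights
```

Starting B at zero means training begins from the pretrained behaviour, and only the two small factors receive gradient updates.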
To bridge the gap between model accuracy and real-world trust, this project integrates Grad-CAM explainability. For AI systems deployed in critical domains like agriculture, simply outputting a prediction is not enough. Grad-CAM generates heatmaps highlighting which image regions (e.g., leaf spots, discoloration) most influenced the model's decision, allowing agronomists and end-users to visually check the network's reasoning.
Input Image
     │
     ▼
[ Preprocessing ] (Resize 256 → CenterCrop 224 → Normalize)
     │
     ▼
┌──────────────────────┐
│  ViT Base (Frozen)   │ ◄── Pretrained timm vit_base_patch16_224
│ ┌──────────────────┐ │
│ │  Block 0 ... 11  │ │
│ │    [Q, K, V]     │ │ ◄── LoRA Adapters injected (Trainable)
│ └──────────────────┘ │
└──────────┬───────────┘
           │
           ▼
[ Classifier Head ] (Trainable Linear Layer)
           │
           ├──> [ Prediction (Class & Confidence) ]
           │
           ▼
[ Grad-CAM Engine ] (Targets block 11 norm1, custom reshape)
           │
           ▼
[ Heatmap Overlay ] (Visual explanation of model focus)
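The "custom reshape" in the Grad-CAM stage is needed because a ViT block outputs a flat token sequence rather than a 2D feature map. The sketch below shows that step under the standard ViT-Base/16 layout at 224×224 input: a 14×14 patch grid plus one CLS token, with 768 channels. NumPy arrays stand in for torch tensors so the shapes can be checked without a GPU stack, and the function name `reshape_tokens_to_grid` is hypothetical rather than taken from src/gradcam.py.

```python
import numpy as np

IMAGE_SIZE = 224
PATCH_SIZE = 16
GRID = IMAGE_SIZE // PATCH_SIZE      # 14 patches per side
NUM_TOKENS = GRID * GRID + 1         # 196 patch tokens + 1 CLS token = 197


def reshape_tokens_to_grid(tokens, grid=GRID):
    """(batch, 1 + grid*grid, channels) -> (batch, channels, grid, grid)."""
    batch, _, channels = tokens.shape
    patches = tokens[:, 1:, :]                        # drop the CLS token
    patches = patches.reshape(batch, grid, grid, channels)
    return patches.transpose(0, 3, 1, 2)              # channels-first layout


tokens = np.arange(2 * NUM_TOKENS * 768, dtype=np.float32).reshape(2, NUM_TOKENS, 768)
feature_map = reshape_tokens_to_grid(tokens)
print(feature_map.shape)  # (2, 768, 14, 14)
```

Patch tokens are laid out row by row, so token 1 lands in the top-left cell of the grid and token 196 in the bottom-right.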
| Metric | Value |
|---|---|
| Test Accuracy | 100% |
| Number of Classes | 10 |
| Trainable Parameters | 302,602 (0.35%) |
| Total Parameters | 86,101,258 |
| Training Device | Google Colab T4 GPU |
| LoRA Rank | 8 |
| LoRA Alpha | 16 |
| Base Model | vit_base_patch16_224 |
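The trainable-parameter figure in the table can be reproduced with a few lines of arithmetic. The sketch assumes LoRA is applied only to the fused qkv projection (768 → 2304) in each of the 12 blocks and that the 10-class linear head, with bias, is the only other trainable module; the total of 86,101,258 is taken directly from the table.

```python
EMBED_DIM = 768            # ViT-Base hidden size
NUM_BLOCKS = 12            # transformer blocks, each with one fused qkv layer
QKV_OUT = 3 * EMBED_DIM    # fused query/key/value projection: 768 -> 2304
RANK = 8
NUM_CLASSES = 10

# Each adapted layer adds A (rank x in) and B (out x rank).
lora_per_block = RANK * EMBED_DIM + QKV_OUT * RANK
lora_total = NUM_BLOCKS * lora_per_block

# Linear classifier head: weights plus biases.
head = EMBED_DIM * NUM_CLASSES + NUM_CLASSES

trainable = lora_total + head
total = 86_101_258         # total parameters reported in the table

print(lora_total)                          # 294912
print(head)                                # 7690
print(trainable)                           # 302602
print(f"{100 * trainable / total:.2f}%")   # 0.35%
```

Under these assumptions the adapters account for about 97% of the trainable parameters, and the classifier head for the rest.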
How to interpret these heatmaps: The colors show how strongly each image region contributed to the model's prediction.
- Red/Yellow regions: High importance. These regions (e.g., disease spots or lesions) contributed most to the predicted class.
- Blue/Cool regions: Low importance. These areas had little influence on the final prediction.
Note: Tomato Bacterial Spot shows some confusion with Potato Early Blight. Both present as brown lesion patterns, a known challenge in plant pathology datasets.
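For readers curious how such a heatmap is derived, the core Grad-CAM computation can be sketched as follows. This is a simplified NumPy illustration of the general technique, not the code in src/gradcam.py (which, per the dependency list, builds on the grad-cam package); the function name and the toy inputs are hypothetical.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM map from one layer's activations and class-score gradients.

    Both inputs have shape (channels, height, width).
    Returns a (height, width) map scaled to [0, 1].
    """
    # Channel importance: average gradient over the spatial positions.
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of activation maps, keeping only positive evidence (ReLU).
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    # Rescale so the strongest region is 1.0 and the weakest is 0.0.
    span = cam.max() - cam.min()
    return (cam - cam.min()) / span if span > 0 else np.zeros_like(cam)


# Toy example: channel 0 supports the class, channel 1 counts against it.
activations = np.array([[[1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 2.0]]])
gradients = np.stack([np.ones((2, 2)), -np.ones((2, 2))])
heatmap = grad_cam(activations, gradients)
print(heatmap.shape)          # (2, 2)
print(float(heatmap[0, 0]))   # 1.0
```

In the toy example only the top-left cell lights up, because the activation in the bottom-right cell belongs to a channel whose gradient is negative and is removed by the ReLU.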
lora-image-classifier/
├── api/
│   └── main.py                  # FastAPI application with predict/explain endpoints
├── app.py                       # Gradio demo app with Predict and Explain tabs
├── notebooks/
│   └── colab_training.ipynb     # Full Colab training notebook with Drive integration
├── src/
│   ├── dataset.py               # PyTorch Datasets, dataloaders, and transforms
│   ├── model.py                 # Model loading, freezing, and PEFT LoRA adapter setup
│   ├── train.py                 # Training loop with validation and early stopping
│   ├── evaluate.py              # Evaluation logic, confusion matrix, and misclassifications
│   ├── gradcam.py               # Grad-CAM implementation with ViT reshape transform
│   └── download_dataset.py      # Kaggle API script to fetch PlantVillage dataset
├── config.py                    # Centralized configuration dataclass
├── requirements.txt             # Project dependencies
├── .gitignore                   # Ignored files and output directories
└── README.md                    # Project documentation
git clone https://github.com/anantha037/lora-image-classifier.git
cd lora-image-classifier
pip install -r requirements.txt
python src/download_dataset.py

Training is designed to be executed in Google Colab. Open notebooks/colab_training.ipynb in Colab, mount your Google Drive, and run the cells sequentially. The trained adapter weights will be saved to your Google Drive. Download them and place them in the outputs/checkpoints/ directory.

Once weights are present, you can generate the summary grid of Grad-CAM explanations:

python src/gradcam.py

Launch the FastAPI server to serve predictions and explainability natively:

uvicorn api.main:app --reload --port 8000

Interactive API docs will be available at: http://localhost:8000/docs

Launch the interactive Gradio demo for visual testing:

uvicorn api.main:app --reload --port 8080

Then in a separate terminal:

python app.py

Demo will be available at: http://localhost:7860
| Endpoint | Method | Description | Input | Output |
|---|---|---|---|---|
| /health | GET | Check API and model status | None | JSON with status, classes count, and device |
| /classes | GET | List all supported disease classes | None | JSON list of class names |
| /predict | POST | Predict disease from an image | Form-Data Image | JSON with predicted class, confidences, and time |
| /explain | POST | Generate prediction + Grad-CAM overlay | Form-Data Image | JSON with prediction and base64 encoded PNG heatmap |
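Because /explain returns its heatmap as a base64-encoded PNG inside JSON, a client has to decode it before it can be viewed. The helper below is a small sketch of that step using only the standard library. The field name `heatmap_base64` is an assumption made for illustration; the actual response schema is shown in the interactive docs at http://localhost:8000/docs.

```python
import base64
import json
import os
import tempfile

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def save_heatmap(response_text, out_path, key="heatmap_base64"):
    """Decode the base64 PNG in an /explain response and write it to disk.

    `key` is a hypothetical field name; adjust it to the real schema.
    Returns the number of bytes written.
    """
    payload = json.loads(response_text)
    png_bytes = base64.b64decode(payload[key])
    if not png_bytes.startswith(PNG_SIGNATURE):
        raise ValueError("decoded payload is not a PNG image")
    with open(out_path, "wb") as fh:
        fh.write(png_bytes)
    return len(png_bytes)


# Round-trip demo with a fake response, so no running server is needed.
fake_body = base64.b64encode(PNG_SIGNATURE + b"demo").decode("ascii")
fake_response = json.dumps({"heatmap_base64": fake_body})
out_file = os.path.join(tempfile.gettempdir(), "heatmap_demo.png")
written = save_heatmap(fake_response, out_file)
print(written)  # 12
```

Checking the PNG signature before writing catches the common mistake of decoding the wrong field.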
Predict Endpoint:

curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: multipart/form-data' \
  -F 'image=@data/plantvillage/Apple___Apple_scab/image_name.jpg;type=image/jpeg'

Explain Endpoint:

curl -X 'POST' \
  'http://localhost:8000/explain' \
  -H 'accept: application/json' \
  -H 'Content-Type: multipart/form-data' \
  -F 'image=@data/plantvillage/Apple___Apple_scab/image_name.jpg;type=image/jpeg'

| Library | Version | Purpose |
|---|---|---|
| PyTorch | 2.1.0+ | Deep learning framework and tensor operations |
| timm | 0.9.12 | Pre-trained Vision Transformer models |
| PEFT | 0.7.1 | Low-Rank Adaptation (LoRA) implementation |
| grad-cam | 1.5.0 | Explaining model predictions visually |
| FastAPI | 0.108.0 | High-performance async web API framework |
| scikit-learn | 1.3.2 | Stratified train/val/test data splitting |
| Gradio | 4.0+ | Interactive ML demo interface |
- Parameter-Efficient Fine-Tuning with LoRA: I learned how to adapt large foundation models (like ViT) at a fraction of the computational cost by freezing the base model and injecting small trainable rank-decomposition matrices.
- ViT Architecture and Attention Mechanism: I gained deep insights into how Vision Transformers divide images into patches, process sequence tokens, and route information through Multi-Head Self-Attention layers.
- Grad-CAM Reshape Transform: I tackled the specific challenge of implementing Grad-CAM for token-based architectures. This required writing custom transformations to bypass the CLS token and reconstruct 2D spatial feature grids from flat sequences.
- Production API Design with FastAPI Lifespan: I learned to serve PyTorch models in a robust production setting, specifically using an async context manager (lifespan) to load the model once at startup and release it cleanly at shutdown.
- Gradio Demo Interface: I built an interactive web demo using Gradio with separate tabs for prediction and Grad-CAM explanation, making the model accessible without any frontend code.
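The lifespan pattern mentioned in the list above can be sketched with the standard library alone. The function is an ordinary async context manager: code before `yield` runs at startup, code after it runs at shutdown. Here `load_model` and the `resources` dictionary are hypothetical stand-ins, not names from api/main.py, and the FastAPI wiring is shown only as a comment so the example runs without the framework installed.

```python
import asyncio
from contextlib import asynccontextmanager

resources = {}  # shared state filled at startup, read by request handlers


def load_model():
    """Hypothetical stand-in for loading the ViT and its LoRA adapter weights."""
    return object()


@asynccontextmanager
async def lifespan(app):
    # Startup: runs once, before the first request is served.
    resources["model"] = load_model()
    yield
    # Shutdown: runs once, after the server stops accepting requests.
    resources.clear()


# With FastAPI installed, the function is passed at construction time:
#     from fastapi import FastAPI
#     app = FastAPI(lifespan=lifespan)


async def demo():
    # Drive the context manager by hand to show the startup/shutdown order.
    async with lifespan(app=None):
        loaded_during = "model" in resources
    return loaded_during, "model" in resources


print(asyncio.run(demo()))  # (True, False)
```

Loading inside the lifespan keeps the model out of import time and avoids reloading it on every request.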
This project is licensed under the MIT License - see the LICENSE file for details.
