Multi-label CNN classifier for chess puzzle themes and openings
1,616 labels | ~5M training samples | Best F1: 0.99
This model classifies chess board positions into themes (tactical patterns, game phases) and openings. Performance varies by category, with game-phase themes achieving near-perfect classification and tactical patterns showing moderate accuracy.
| Theme | F1 Score | Support |
|---|---|---|
| pawnEndgame | 0.99 | 307 |
| endgame | 0.98 | 4,795 |
| middlegame | 0.96 | 4,628 |
| queenEndgame | 0.95 | 119 |
| rookEndgame | 0.93 | 550 |
| opening | 0.88 | 553 |
| short | 0.69 | 5,237 |
| bishopEndgame | 0.64 | 127 |
| queenRookEndgame | 0.63 | 83 |
| crushing | 0.63 | 3,974 |
| mate | 0.60 | 2,877 |
| advantage | 0.54 | 2,957 |
| backRankMate | 0.44 | 338 |
| long | 0.41 | 2,497 |
| mateIn2 | 0.37 | 1,272 |

| Opening | F1 Score | Support |
|---|---|---|
| Sicilian Defense | 0.32 | 317 |
| Italian Game | 0.29 | 129 |
| Ruy Lopez | 0.19 | 70 |
| Russian Game | 0.17 | 35 |
| English Opening | 0.15 | 83 |
| French Defense | 0.13 | 131 |
| Caro-Kann Defense | 0.13 | 121 |
| Queens Pawn Game | 0.13 | 126 |
| Scandinavian Defense | 0.11 | 75 |
| Kings Indian Defense | 0.08 | 16 |
Opening recognition is more challenging due to the large number of similar variations and the diminishing positional signatures as games progress.
The F1 score is the harmonic mean of precision and recall, providing a balanced measure of classification performance:
F1 = 2 * (Precision * Recall) / (Precision + Recall)
Where:
- Precision = TP / (TP + FP) - proportion of positive predictions that are correct
- Recall = TP / (TP + FN) - proportion of actual positives correctly identified
- Support = number of true instances for each label in the test set
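As a concrete illustration, these formulas can be computed directly from confusion counts (the counts below are hypothetical):

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall, and F1 from raw confusion counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1

# Example: 80 true positives, 20 false positives, 40 false negatives
p, r, f1 = precision_recall_f1(80, 20, 40)
print(round(p, 3), round(r, 3), round(f1, 3))  # → 0.8 0.667 0.727
```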
This chart shows F1 scores for all 58 chess puzzle themes, sorted by performance. Color gradient (red-yellow-green) maps to F1 score. Game-phase themes (endgame, middlegame, opening) achieve the highest scores, while tactical patterns like forks and pins show moderate performance.
Opening classification is inherently more challenging than theme classification for several reasons:
- High cardinality: 1,554 distinct opening labels vs. 62 themes
- Positional similarity: Many openings share similar early positions
- Transpositions: Different move orders can reach the same position
- Signal decay: Opening signatures weaken as games progress into middlegame
Full distribution of all 1,554 openings (x-axis labels omitted for clarity):
This scatter plot reveals the relationship between class frequency (support) and model performance (F1 score). Each point represents one of the 1,616 labels.
Key observations:
- High-support themes (>1000 samples) cluster at moderate F1 scores (0.4-0.7)
- Endgame themes achieve high F1 despite varying support levels
- Rare openings (support < 50) typically show F1 near zero
- The model generalizes best to structurally distinct patterns regardless of frequency
Precision-Recall (PR) curves visualize the tradeoff between precision and recall at different classification thresholds.
Key metrics:
- AUC-PR (Area Under the PR Curve): Summarizes overall classifier performance across all thresholds. Computed as the integral of precision over recall. Higher is better; 1.0 is perfect.
- Optimal threshold: The classification threshold that maximizes F1 score (marked on each plot)
Interpreting PR curves:
- Ideal curve: Hugs the top-right corner (high precision and recall simultaneously)
- Random classifier: Horizontal line at y = (positive samples / total samples)
Rather than using a single global threshold for all 1,616 labels, we compute per-class optimal thresholds from the PR curves. For each class, the optimal threshold is the point on its PR curve that maximizes F1 score. This adaptive approach significantly improves performance on rare classes that would otherwise be drowned out by a global threshold tuned for common classes.
The per-class thresholds are stored in analysis/f1/per_class_thresholds.csv and used during evaluation.
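A minimal sketch of the per-class threshold search, in pure Python (the real pipeline derives thresholds from the PR curves produced by the evaluation scripts; the function name here is illustrative):

```python
def optimal_threshold(probs, labels):
    """Return the threshold (and its F1) that maximizes F1 for one class.

    probs:  predicted probabilities for this class
    labels: ground-truth 0/1 labels for this class
    """
    best_t, best_f1 = 0.5, 0.0
    for t in sorted(set(probs)):  # each observed probability is a candidate
        tp = sum(1 for p, y in zip(probs, labels) if p >= t and y)
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and not y)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1

# Lowering the threshold to 0.7 captures the third positive without false positives
t, f1 = optimal_threshold([0.9, 0.8, 0.3, 0.2, 0.7], [1, 1, 0, 0, 1])
print(t, f1)  # → 0.7 1.0
```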
The 25 highest-performing labels sorted by F1 score. Top performers show near-ideal PR curves:
Horizontal reflection is used for class-conditional augmentation to address class imbalance in theme labels only. Opening labels are not augmented because opening theory is asymmetric (e.g., 1.e4 positions are fundamentally different from 1.d4 positions, and their mirror images do not preserve opening identity).
This augmentation preserves chess semantics for tactical themes (a mirrored fork is still a fork) while effectively doubling samples for underrepresented theme classes.
| Original | Reflected |
|---|---|
| ![]() | ![]() |
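Reflection amounts to mirroring each rank of the 8x8 tensor, swapping the a- and h-files (a sketch; the project's dataset code may implement this differently):

```python
def reflect_board(board):
    """Mirror an 8x8 board across the vertical axis (a-file <-> h-file).

    Piece values are unchanged, so tactical theme labels are preserved
    (a mirrored fork is still a fork); opening identity is not, which is
    why opening labels are excluded from this augmentation.
    """
    return [list(reversed(rank)) for rank in board]

# A piece on the g-file (index 6) moves to the b-file (index 1)
rank = [0, 0, 0, 0, 0, 0, 2, 0]
board = [rank] + [[0] * 8 for _ in range(7)]
mirrored = reflect_board(board)
print(mirrored[0])  # → [0, 2, 0, 0, 0, 0, 0, 0]
```

Reflecting twice is the identity, which makes the transform easy to sanity-check in tests.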
Type: CNN with Attention and Residual Blocks (~3.2M parameters)
| Component | Details |
|---|---|
| Input | 8x8 board (13 piece vocabulary) |
| Layers | 10 residual blocks with self-attention |
| Output | 1,616 labels (sigmoid multi-label) |
View full architecture diagram
- FEN parsing: Chess positions in FEN notation are converted to 8x8 integer tensors (0-12 piece vocabulary)
- Tensor caching: Preprocessed tensors are cached to disk for fast subsequent access
- Class-conditional augmentation: Underrepresented themes are augmented via horizontal board reflection
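The FEN-to-tensor step can be sketched as follows (the piece-to-index mapping shown is an assumption for illustration, not necessarily the project's actual vocabulary ordering):

```python
# Index 0 = empty square, 1-6 = white pieces, 7-12 = black pieces (assumed order)
PIECES = ".PNBRQKpnbrqk"

def fen_to_tensor(fen):
    """Convert the board field of a FEN string to an 8x8 list of ints (0-12)."""
    board = []
    for rank in fen.split()[0].split("/"):
        row = []
        for ch in rank:
            if ch.isdigit():
                row.extend([0] * int(ch))  # digit = run of empty squares
            else:
                row.append(PIECES.index(ch))
        board.append(row)
    return board

start = fen_to_tensor("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
```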
Each position can have multiple themes (e.g., "mate", "backRankMate", "short") and one opening. The model outputs independent sigmoid probabilities for each of 1,616 labels.
- Augmentation: Selective horizontal flipping for rare theme combinations
- Weighted loss: Optional per-class loss weighting based on frequency
- Adaptive thresholding: Per-class optimal thresholds derived from PR curves
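At inference time, each label's sigmoid probability is compared against its own threshold rather than a shared global one. A minimal sketch (logits and thresholds below are hypothetical):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def predict_labels(logits, thresholds):
    """Multi-label prediction: per-class threshold on each sigmoid output."""
    return [int(sigmoid(z) >= t) for z, t in zip(logits, thresholds)]

# Label 2 has a raised threshold, label 1 a lowered one (e.g. a rare class)
logits = [2.0, -1.0, 0.2, -3.0]
thresholds = [0.5, 0.2, 0.6, 0.5]
print(predict_labels(logits, thresholds))  # → [1, 1, 0, 0]
```

Note that label 1 is predicted positive despite a negative logit, which is exactly how lowered thresholds rescue rare classes.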
```bash
apt update && apt install -y python3-dev python3-pip python3-virtualenv git
git clone git@github.com:jknoll/chess-theme-classifier.git
cd chess-theme-classifier
python -m venv .chess-theme-classifier
source .chess-theme-classifier/bin/activate
```

Note: On systems where python is not found but python3 is available, you may need `apt install python3.10-venv`.

```bash
pip install --upgrade pip
pip install -r requirements.txt
```

The model is trained on the lichess puzzle database (~5M labeled positions as of 2025-06-24).

```bash
wget https://database.lichess.org/lichess_db_puzzle.csv.zst
sudo apt install -y zstd
unzstd lichess_db_puzzle.csv.zst
```

To generate the tensor cache from the downloaded CSV:

```bash
python create_full_dataset_cache.py
```

The pre-processed dataset includes cached tensors for faster training.
Set up AWS credentials:

```bash
# Option A: Environment variables
export AWS_ACCESS_KEY_ID="your_access_key"
export AWS_SECRET_ACCESS_KEY="your_secret_key"

# Option B: AWS CLI
pip install awscli
aws configure
```

Option C: Credentials file (~/.aws/credentials):

```ini
[default]
aws_access_key_id = your_access_key
aws_secret_access_key = your_secret_key
```

Download:

```bash
python download_dataset.py
python download_dataset.py --output-dir custom_directory  # custom location
python download_dataset.py --threads 8 --verify           # parallel + verify
```

Test the training loop with a small dataset:
```bash
python train.py --local --test_mode
```

```bash
torchrun --nproc_per_node=[NUM_GPUs] train.py
python train.py
python train.py --local        # force local mode
python train.py --distributed  # force distributed mode
```

| Argument | Description |
|---|---|
| `--test_mode` | Run with smaller dataset for testing |
| `--wandb` | Enable Weights & Biases logging |
| `--project` | W&B project name (default: chess-theme-classifier) |
| `--name` | W&B run name |
| `--checkpoint_steps` | Steps between checkpoints (default: 50000) |
Generate per-class and global adaptive thresholds:
```bash
python evaluate_model_metrics.py
```

Generate precision-recall curves (run after metrics):

```bash
python evaluate_model_metrics_pr_curves.py
```

```bash
# Adaptive thresholding (default)
python evaluate_model_classification.py --num_samples=100

# Fixed threshold
python evaluate_model_classification.py --num_samples=100 --threshold=0.3

# Verbose output
python evaluate_model_classification.py --num_samples=50 --verbose

# Minimized output
python evaluate_model_classification.py --num_samples=100 --quiet

# Use cached tensors
python evaluate_model_classification.py --use_cache

# Specific checkpoint
python evaluate_model_classification.py --checkpoint=checkpoints/my_checkpoint.pth
```

| Script | Purpose |
|---|---|
| `evaluate_model_fixed.py` | Maps between training/test indices, supports adaptive thresholding |
| `evaluate_model_simple.py` | Focused on key chess themes |
| `evaluate_model_cache.py` | Uses cached tensors directly |
See docs/model_evaluation.md for detailed documentation.
The augmented dataset uses the _conditional suffix:
lichess_db_puzzle_test.csv.tensors.pt_conditional
Generate augmentation for a dataset:
```bash
python -c "from dataset import ChessPuzzleDataset; ChessPuzzleDataset('lichess_db_puzzle_test.csv', class_conditional_augmentation=True)"
```

Train locally with weighted loss:

```bash
python train_locally_single_gpu.py --test_mode --weighted_loss
```

Note: Combining the class-balanced dataset with weighted loss can cause unstable training (Jaccard similarity oscillations).
Inspect the co-occurrence statistics:

```bash
python -c 'import json, pprint; pprint.pprint(json.load(open("lichess_db_puzzle_test.csv.cooccurrence.json")))'
```

Micro averaging:
- Aggregates all TP, FP, FN across classes before calculating
- Gives equal weight to each sample-class pair
- Favors performance on common themes

Macro averaging:
- Calculates metrics per class, then averages
- Each theme contributes equally regardless of frequency
- Use when rare theme performance matters

Weighted averaging:
- Weighted average of per-class metrics by frequency
- Balanced view reflecting dataset distribution
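The difference between pooling counts first (micro) and averaging per-class scores (macro) can be sketched with two hypothetical classes, one common and one rare:

```python
def micro_macro_f1(per_class_counts):
    """per_class_counts: list of (tp, fp, fn) tuples, one per class."""
    def f1(tp, fp, fn):
        denom = 2 * tp + fp + fn
        return 2 * tp / denom if denom else 0.0

    # Micro: pool counts across all classes, then compute F1 once
    tp = sum(c[0] for c in per_class_counts)
    fp = sum(c[1] for c in per_class_counts)
    fn = sum(c[2] for c in per_class_counts)
    micro = f1(tp, fp, fn)

    # Macro: compute F1 per class, then take the unweighted mean
    macro = sum(f1(*c) for c in per_class_counts) / len(per_class_counts)
    return micro, macro

# A common class with strong performance and a rare class with weak performance:
# micro is dominated by the common class, macro is dragged down by the rare one
micro, macro = micro_macro_f1([(900, 50, 50), (1, 5, 5)])
```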
```bash
python -m pytest tests/
```

Tests run automatically on push/PR via GitHub Actions (see .github/workflows/test.yml).
See tests/README.md for details.
```
chess-theme-classifier/
|-- train.py                         # Main training script
|-- model.py                         # CNN architecture
|-- dataset.py                       # Data loading and caching
|-- model_config.yaml                # Model hyperparameters
|-- requirements.txt                 # Dependencies
|-- evaluate_model_*.py              # Evaluation scripts
|-- create_full_dataset_cache.py
|-- download_dataset.py
|
|-- analysis/
|   |-- f1/                          # F1 charts and per-class thresholds
|   |-- pr-curves/                   # Precision-recall curves
|   |-- scatter/                     # F1 vs support plots
|
|-- checkpoints_pretrained/          # Pre-trained model checkpoints
|-- processed_lichess_puzzle_files/  # Cached tensors and datasets
|-- docs/                            # Additional documentation
|-- tests/                           # Unit tests
```
- Model Evaluation Guide
- Adaptive Thresholding
- Per-Class Adaptive Thresholding
- Precision-Recall Curves
- Class Imbalance Work
- Dataset Download from S3
- Checkpoint Management
train.py supports both local and cluster training. The train-isc.py script is deprecated.
The dataset class generates a .tensors.pt cache file on first access. Cache validation checks CSV modification time to ensure consistency. Typical speedup is 2-3x for dataset access.






