SmolWorld is a research repository for training Generative World Models on the OpenAI VPT (Video PreTraining) dataset. It implements a two-stage pipeline:
- Visual Tokenizer (VQ-VAE): Compresses video frames into discrete tokens.
- World Model (Transformer): Predicts the next visual tokens autoregressively based on history and actions.
- VQ-VAE:
  - Architecture: CNN encoder/decoder with ResNet or ConvNeXt blocks.
  - Features: Progressive channel scaling (channels double in the encoder, halve in the decoder), EMA codebook updates.
  - Quantization: Vector quantization with a learnable codebook (custom implementation).
  - Config: Configurable latent grid, codebook size, and dimensions.
- World Model:
  - Architecture: Decoder-only Transformer (Llama-style).
  - Features: RMSNorm (pre-norm), RoPE (Rotary Positional Embeddings), SwiGLU activation.
  - Config: ~150M parameters (24 layers, 576-dim, 9 heads), 1024-token context window.
- Training:
  - Precision: BF16/FP16 mixed precision.
  - Optimization: AdamW with a ReduceLROnPlateau scheduler.
  - Data: Interleaved action/image tokens with optimized lazy loading.
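The "interleaved action/image tokens" layout above can be sketched as follows. This is a hypothetical illustration: the vocabulary offset for action IDs and the function name are assumptions, not the repo's exact scheme.

```python
# Hypothetical sketch of interleaving: one action token precedes the image
# tokens of each frame. Offsetting action IDs past the visual codebook
# (codebook_size) keeps the two vocabularies disjoint; this layout is an
# assumption, not necessarily the repo's exact scheme.
def interleave(actions, frame_tokens, codebook_size=1024):
    """actions: one int per frame; frame_tokens: list of token lists, one per frame."""
    seq = []
    for act, frame in zip(actions, frame_tokens):
        seq.append(codebook_size + act)  # action token, shifted past image vocab
        seq.extend(frame)                # visual tokens for this frame
    return seq

print(interleave([0, 1], [[7, 8], [9, 10]]))  # [1024, 7, 8, 1025, 9, 10]
```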
Clone the repository:

```bash
git clone https://github.com/yourusername/SmolWorld.git
cd SmolWorld
```

Install dependencies:

```bash
pip install .
# Or manually:
pip install torch numpy opencv-python-headless einops vector-quantize-pytorch flash-attn requests tqdm pandas matplotlib
```
Download and preprocess the OpenAI VPT dataset:

```bash
# Download and process a single shard at 64x64 resolution
python data/pipeline.py --resolution 64 --num_shards 1

# Download and process multiple random shards
python data/pipeline.py --resolution 64 --num_shards 10
```

Output: `data/processed/shard_X_res64.pt`
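A processed shard is a standard PyTorch file. The round trip below shows one plausible layout (frames as a `uint8` tensor of shape `[N, H, W, C]`); the exact stored structure is an assumption, so inspect a real shard with `torch.load` before relying on it.

```python
import os
import tempfile

import torch

# Round-trip a dummy shard in the format the pipeline plausibly emits:
# frames as a uint8 tensor of shape [N, H, W, C]. The exact layout is an
# assumption, not confirmed by the repo.
frames = torch.randint(0, 256, (8, 64, 64, 3), dtype=torch.uint8)
path = os.path.join(tempfile.gettempdir(), "shard_demo_res64.pt")
torch.save(frames, path)

loaded = torch.load(path)
print(loaded.shape, loaded.dtype)  # torch.Size([8, 64, 64, 3]) torch.uint8
```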
Train the VQ-VAE to compress images into tokens:

```bash
# Train on a single file
python train_vqvae.py --data_path data/processed/shard_000000_res64.pt ...

# Train on all files in a directory (multi-shard)
# Note: automatically filters files matching *_res{resolution}.pt
python train_vqvae.py \
    --data_path data/processed \
    --resolution 64 \
    --downsamples 2 \
    --base_channels 32 \
    --codebook_size 1024 \
    --codebook_dim 256 \
    --block_type convnext \
    --channel_multiplier 2.0 \
    --batch_size 32 \
    --epochs 100

# Resume training
# Restores the exact epoch, optimizer, and scheduler state
python train_vqvae.py --data_path data/processed --resume checkpoints/last_model.pt
```

Output: `checkpoints/best_model.pt` (best validation loss) and `checkpoints/last_model.pt` (latest state).
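The `--base_channels`, `--channel_multiplier`, and `--downsamples` flags interact to set the encoder's channel progression (channels grow by the multiplier at each downsampling stage, mirrored in the decoder). The helper below is a sketch of that arithmetic; the exact schedule in `models/vqvae.py` is an assumption.

```python
# Sketch: plausible encoder channel progression from the CLI flags above.
# The actual schedule in models/vqvae.py may differ.
def encoder_channels(base_channels, multiplier, downsamples):
    chans = [base_channels]
    for _ in range(downsamples):
        chans.append(int(chans[-1] * multiplier))  # scale at each downsample
    return chans

print(encoder_channels(32, 2.0, 2))  # [32, 64, 128]
print(encoder_channels(32, 1.5, 2))  # [32, 48, 72]
```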
- Multi-Shard Support: Load data from a directory of `.pt` files.
- Resolution Filtering: Automatically selects files matching the target resolution.
- Configurable Architecture:
  - `--block_type`: Choose between `resnet` (with BatchNorm) and `convnext`.
  - `--channel_multiplier`: Float factor for channel scaling (e.g., 1.5 or 2.0).
- Resumption: Seamlessly resume training from checkpoints.
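The resolution filter described above can be sketched with the `*_res{resolution}.pt` pattern quoted in the training command. Function and variable names here are illustrative; the repo's implementation may differ.

```python
import fnmatch

# Sketch of the resolution filter: keep only shards whose filename matches
# *_res{resolution}.pt (pattern taken from the note in the training section).
def filter_shards(filenames, resolution):
    pattern = f"*_res{resolution}.pt"
    return [f for f in filenames if fnmatch.fnmatch(f, pattern)]

files = ["shard_000000_res64.pt", "shard_000001_res64.pt", "shard_000000_res128.pt"]
print(filter_shards(files, 64))  # ['shard_000000_res64.pt', 'shard_000001_res64.pt']
```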
Train the Transformer to predict the future:

```bash
# Train using the trained VQ-VAE
python train_world_model.py \
    --data_path data/processed/shard_res64.pt \
    --vqvae_path checkpoints/best_model.pt \
    --epochs 50 \
    --batch_size 4 \
    --grad_accum 4
```

Output: `checkpoints_wm/world_model_epoch_X.pt`
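With `--batch_size 4` and `--grad_accum 4`, gradients from 4 micro-batches are accumulated before each optimizer step, giving an effective batch size of 16. The toy below illustrates the generic pattern with a scalar "parameter"; it is not the repo's training loop.

```python
# Toy illustration of gradient accumulation: average gradients over
# grad_accum micro-batches, then apply one update. Generic pattern only,
# not the repo's exact training loop.
def accumulate(grads, grad_accum, lr=0.1):
    w, buf, updates = 0.0, 0.0, 0
    for step, g in enumerate(grads):
        buf += g / grad_accum             # scale each micro-batch gradient
        if (step + 1) % grad_accum == 0:  # one update per accumulation window
            w -= lr * buf
            buf, updates = 0.0, updates + 1
    return w, updates

print(accumulate([1.0] * 8, grad_accum=4))  # (-0.2, 2)
```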
Interactively generate future frames based on your input:

```bash
# Interactive mode (requires a GUI)
# Controls: 'W' to move forward, 'Q' to quit.
python inference/play.py --vqvae_path checkpoints/best_model.pt --wm_path checkpoints_wm/world_model_epoch_50.pt --resolution 64 --downsamples 2

# Headless mode (save frames to disk)
python inference/play.py --vqvae_path ... --wm_path ... --headless --steps 100
```

Project structure:

```
SmolWorld/
├── data/
│   ├── pipeline.py          # Unified download & preprocess script
│   └── processed/           # Processed .pt shards
├── models/
│   ├── vqvae.py             # VQ-VAE architecture
│   ├── quantizer.py         # Custom vector quantizer
│   └── transformer.py       # World Model (Llama-style)
├── inference/
│   └── play.py              # Interactive inference script
├── tests/                   # Unit tests
├── train_vqvae.py           # VQ-VAE training script
├── train_world_model.py     # World Model training script
├── pyproject.toml           # Dependencies
└── README.md                # Documentation
```
Actions are mapped to a single integer (0-4355):
- Mouse: Quantized into 11x11 bins (121 values).
- Keys: WASD movement mapped to 9 states.
- Binary: Attack (L-Click), Jump (Space).
- Formula:

```
idx = mx + my*11 + attack*121 + jump*242 + move*484
```
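The formula is a mixed-radix packing: 121 mouse bins × 2 attack × 2 jump × 9 move states = 4356 indices (0-4355). A hypothetical encode/decode pair (function names are illustrative, not from the repo) might look like:

```python
def encode_action(mx, my, attack, jump, move):
    """Pack action components into one index, per the README formula.
    mx, my: mouse bin in 0..10; attack, jump: 0 or 1; move: 0..8."""
    return mx + my * 11 + attack * 121 + jump * 242 + move * 484

def decode_action(idx):
    """Invert the mixed-radix packing back into components."""
    move, rem = divmod(idx, 484)
    jump, rem = divmod(rem, 242)
    attack, rem = divmod(rem, 121)
    my, mx = divmod(rem, 11)
    return mx, my, attack, jump, move

print(encode_action(10, 10, 1, 1, 8))  # 4355 (the maximum index)
print(decode_action(2621))             # (3, 7, 1, 0, 5)
```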