Train a Diffusion Trajectory Forecaster model on the Waymo Motion dataset (MMLS HSE project).
To create the env and install dependencies:

```shell
conda create -n diffusion_tracker python=3.10
conda activate diffusion_tracker
pip install uv
uv sync
```
You can run the whole project inside Docker with GPU access instead of installing the Python environment locally.
- Build the image from the repository root:

  ```shell
  docker build -t diffusion-trajectory-forecaster .
  ```

- Make the helper script executable:

  ```shell
  chmod +x scripts/docker_run.sh
  ```

- Start an interactive shell inside the container:

  ```shell
  scripts/docker_run.sh bash
  ```

- Run project commands inside that shell:

  ```shell
  uv run python train.py
  uv run python train.py data=processed_v2
  uv run python create_dataset.py data=processed_v1
  uv run pytest
  uv run python visualise_data.py
  uv run dvc pull
  uv run dvc push
  ```

How it works:
- the repository is mounted into the container at `/app`; your code, checkpoints, outputs, and local changes stay on the host machine
- the container uses its own virtual environment at `/opt/venv`, so Docker does not recreate or modify your host `.venv`
- the helper script runs the container with your host UID/GID so generated files remain writable by your user and Git can stage them
- container-side cache and auth files are stored in the gitignored `.docker-cache/`
- DVC secrets should stay in `.dvc/config.local`, not `.dvc/config`
Notes:
- rebuild the image after this change so the container environment is created under `/opt/venv`
To authenticate to your Google account for data downloading (one time):
1. Apply for Waymo Open Dataset access.
2. Install the gcloud CLI.
3. Run `gcloud auth login <your_email>` with the same email used in step 1.
4. Run `gcloud auth application-default login`.
Processed datasets are tracked with DVC as directory artifacts. Git stores the `.dvc` metadata files, while the actual `.pkl` files live locally or in the configured DVC remote.
Remote configuration:
- keep the remote URL in `.dvc/config`
- keep credentials such as `access_key_id` and `secret_access_key` in `.dvc/config.local`
- do not commit `.dvc/config.local`
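For reference, after the `dvc remote modify --local` commands below, `.dvc/config.local` has roughly this shape (assuming the remote is named `myremote`; the key values are placeholders):

```ini
['remote "myremote"']
    access_key_id = <AWS_ACCESS_KEY_ID>
    secret_access_key = <AWS_SECRET_ACCESS_KEY>
```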
Amazon S3 credentials setup:
```shell
uv run dvc remote list
uv run dvc remote modify --local myremote access_key_id <AWS_ACCESS_KEY_ID>
uv run dvc remote modify --local myremote secret_access_key <AWS_SECRET_ACCESS_KEY>
```

Notes:
- the shared repository config already defines the default DVC remote URL and region
- the AWS identity used by DVC should have `s3:ListBucket`, `s3:GetObject`, `s3:PutObject`, and `s3:DeleteObject`
- if the bucket uses SSE-KMS encryption, the same identity also needs the matching KMS permissions
- if the AWS CLI is already configured on the machine, DVC can also reuse that configuration
Pull datasets on a new machine:
```shell
uv run dvc pull
```

Pull one dataset explicitly:

```shell
uv run dvc pull data/processed_v1.dvc
uv run dvc pull data/processed_v2.dvc
uv run dvc pull data/baseline1.dvc
```

Push updated artifacts:

```shell
uv run dvc push
```

Useful checks:

```shell
uv run dvc status
uv run dvc list . data
```

To build processed train/val/test datasets from raw Waymo data:

```shell
uv run python create_dataset.py data=processed_v1 dataset_creation=default
```

This script:
- creates the processed `.pkl` files inside `data/processed_v1/`
- runs `dvc add` on the dataset directory once
- stages the generated `.dvc` file in Git
- runs `dvc push` if a working DVC remote is configured
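The bookkeeping above can be sketched as the sequence of commands the script effectively issues; this is an illustrative dry-run sketch, not the actual implementation in `create_dataset.py`:

```python
# Hedged sketch of the artifact steps create_dataset.py performs after writing
# the .pkl files. The function only builds the commands described in the
# README; it does not execute them.
def artifact_steps(dataset_dir: str) -> list[list[str]]:
    return [
        ["dvc", "add", dataset_dir],           # track the directory once
        ["git", "add", f"{dataset_dir}.dvc"],  # stage the generated .dvc file
        ["dvc", "push"],                       # upload if a remote is configured
    ]

for cmd in artifact_steps("data/processed_v1"):
    print(" ".join(cmd))
```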
Main configs:
- dataset artifact paths: `src/configs/data/processed_v1.yaml`
- dataset creation settings: `src/configs/dataset_creation/default.yaml`
When creating a new dataset version, set both config groups:
```shell
uv run python create_dataset.py data=baseline1 dataset_creation=baseline1
```

Why both are needed:
- `data=...` chooses where the dataset is saved and which `.dvc` file is updated
- `dataset_creation=...` chooses how the dataset is generated

If you reuse the same `data=...` name, the local files and the corresponding `.dvc` artifact are updated to the new content.
You can still override individual settings with Hydra, for example:

```shell
uv run python create_dataset.py dataset_creation.train.num_states=100
uv run python create_dataset.py dataset_creation.val.max_num_objects=16
```

The repository currently supports two processed dataset formats:
- Legacy chunked pickle format
- WebDataset shard format
Legacy chunked pickle format:
- storage format value: `directory_chunks`
- output layout:
  - one manifest file such as `train_processed_v1.pkl`
  - one chunk directory such as `train_processed_v1.pkl.chunks/`
- training path: `src/data_module/legacy_dataset.py`
  - custom chunk cache
  - custom chunk sampler from `src/data_module/sampler.py`
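A chunk-aware sampler of the kind described above can be sketched as follows; this is an illustrative stand-in, not the code in `src/data_module/sampler.py`. It shuffles chunks, then items within each chunk, so only one chunk needs to be resident in the cache at a time:

```python
import random

def chunked_order(num_items: int, chunk_size: int, seed: int = 0) -> list[int]:
    """Return a permutation of item indices in which items from the same
    chunk stay contiguous, keeping chunk-cache hits high."""
    rng = random.Random(seed)
    chunks = [list(range(s, min(s + chunk_size, num_items)))
              for s in range(0, num_items, chunk_size)]
    rng.shuffle(chunks)            # randomize chunk visit order
    order = []
    for chunk in chunks:
        rng.shuffle(chunk)         # randomize items within the chunk
        order.extend(chunk)
    return order

print(sorted(chunked_order(10, 4)) == list(range(10)))  # True: a permutation
```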
WebDataset shard format:
- storage format value: `webdataset`
- output layout:
  - one shard directory such as `train_processed_v1.pkl.wds/`
  - many tar shards inside it, for example `shard-000000.tar`
  - one `index.json`
- training path: `src/data_module/dataset.py`
  - standard `webdataset.WebDataset(...)`
  - no custom chunk sampler
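The shard layout above can be illustrated with a toy directory; `list_shards` and the `index.json` contents here are assumptions made for the sketch, not the repo's actual loader:

```python
import json
import tempfile
from pathlib import Path

def list_shards(shard_dir: Path) -> list[str]:
    """Return tar shard names in sorted order, the order a sequential
    WebDataset-style reader would consume them."""
    return sorted(p.name for p in shard_dir.glob("shard-*.tar"))

# Build a toy layout matching the description above.
root = Path(tempfile.mkdtemp()) / "train_processed_v1.pkl.wds"
root.mkdir(parents=True)
for i in range(3):
    (root / f"shard-{i:06d}.tar").touch()
(root / "index.json").write_text(json.dumps({"num_shards": 3}))

print(list_shards(root))
# ['shard-000000.tar', 'shard-000001.tar', 'shard-000002.tar']
```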
Create the legacy chunked format:
```shell
uv run python create_dataset.py \
  data=small_base_wo_vis \
  dataset_creation=small_base_wo_vis \
  storage_format=directory_chunks
```

Create the WebDataset format:

```shell
uv run python create_dataset.py \
  data=small_base_wo_vis \
  dataset_creation=small_base_wo_vis \
  storage_format=webdataset
```

Train with the legacy chunked format:

```shell
uv run python train.py \
  data=small_base_wo_vis \
  data.train.storage_format=directory_chunks \
  data.val.storage_format=directory_chunks \
  data.test.storage_format=directory_chunks
```

Train with the WebDataset format:

```shell
uv run python train.py \
  data=small_base_wo_vis \
  data.train.storage_format=webdataset \
  data.val.storage_format=webdataset \
  data.test.storage_format=webdataset
```

How the training path is selected:
- `src/configs/data/*.yaml` contains `storage_format` for each split
- if `storage_format=webdataset`, training uses the WebDataset loader
- otherwise it falls back to the legacy chunked dataset loader
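The selection rule above, as a minimal sketch (the loader names are placeholders, not the repo's classes):

```python
# Hypothetical sketch of the storage_format branch described above.
def select_loader(storage_format: str) -> str:
    if storage_format == "webdataset":
        return "webdataset_loader"      # src/data_module/dataset.py path
    return "legacy_chunked_loader"      # fallback: src/data_module/legacy_dataset.py

print(select_loader("webdataset"))        # webdataset_loader
print(select_loader("directory_chunks"))  # legacy_chunked_loader
```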
The project uses Hydra configs from `src/configs/`. The script you run decides which config groups are read.
train.py:
- root config: `src/configs/ddpm_baseline.yaml`
- alternative training preset: `src/configs/ddpm_1.yaml`
- config groups used by training:
  - `model=...`: model class and model-specific hyperparameters such as architecture size, diffusion settings, learning rate, oracle settings, and checkpoint loading
  - `data=...`: which processed dataset artifact is used for train/val/test, including the `.pkl` paths and matching `.dvc` file
  - `dataloaders=...`: batch size, shuffle, workers, and other `DataLoader` settings
  - `logger=...`: experiment logger backend and ClearML project/run settings
  - `metrics=...`: train/validation metrics instantiated during training
  - `visual=...`: visualization and sampling/debug rendering settings used by the model during validation/logging
  - `trainer.*`: top-level training loop settings such as epochs, train/val epoch length, gradient clipping, seed, logging mode, and checkpoint reload flag
ClearML setup:
- initialize ClearML credentials once with `uv run clearml-init`
- default logger config: `src/configs/logger/clearml.yaml`
- `logger.project_name`: shared ClearML project for a family of experiments
- `logger.task_name`: individual run name shown in ClearML
- recommended pattern: keep one fixed `logger.project_name` and override `logger.task_name` per run

Examples:

```shell
uv run clearml-init
uv run python train.py logger.project_name=my_experiments logger.task_name=exp_001
uv run python train.py logger.project_name=my_experiments logger.task_name=attn_v2_lr1e-4
uv run python train.py logger.task_name=baseline_processed_v2
```

Examples:

```shell
uv run python train.py
uv run python train.py data=processed_v2
uv run python train.py model=diffusion_attn_2x trainer.num_epochs=230
```

The project now has separate normal and debug model classes.
Usual attention model:

```shell
uv run python train.py model=diffusion_attn
```

Debug attention model:

```shell
uv run python train.py model=diffusion_attn_debug
```

You can still override individual debug flags when using the debug model, for example:

```shell
uv run python train.py model=diffusion_attn_debug visual.debug_metrics=true
uv run python train.py model=diffusion_attn_debug visual.debug_denoiser_scale=true
uv run python train.py model=diffusion_attn_debug model.oracle_cfg.use_for_sampling=true
```

How model selection works:
- `src/configs/model/diffusion_attn.yaml` uses `src.models.DiffusionAttentionModel`
- `src/configs/model/diffusion_attn_debug.yaml` uses `src.models.DiffusionAttentionDebugModel`
- the normal model inherits `BaseDiffusionModel`
- the debug model inherits `DebuggableBaseDiffusionModel`

What the base model files do:
- `src/models/base_model.py`: normal training, validation, loss, sampling, checkpointing, and metric entry points
- `src/models/base_model_debuggable.py`: debug-only extensions such as extra shape/metric diagnostics, optional oracle paths, and fixed-noise sampling hooks
- `src/models/base_model_debug.py`: small debug helper functions used only by the debug-capable model
- `src/models/base_model_oracle.py`: oracle-only helper functions used by the debug-capable model
- `src/models/base_model_eval.py`: shared metric-evaluation and visualization helpers used by both normal and debug models
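The normal/debug split can be illustrated with stand-in classes (not the code in `src/models/`); the debug-only diagnostics live on the debug-capable base class, which is why debug flags are inert on the normal model:

```python
# Illustrative sketch only: class and key names here are placeholders.
class BaseDiffusionModel:
    def validation_step(self, batch):
        # normal metrics only
        return {"val/loss": 0.0}

class DebuggableBaseDiffusionModel(BaseDiffusionModel):
    def __init__(self, debug_metrics: bool = False):
        self.debug_metrics = debug_metrics

    def validation_step(self, batch):
        out = super().validation_step(batch)
        if self.debug_metrics:
            # extra diagnostics exist only on the debug-capable class
            out["debug/shape_ok"] = True
        return out

print(DebuggableBaseDiffusionModel(debug_metrics=True).validation_step(None))
# {'val/loss': 0.0, 'debug/shape_ok': True}
```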
Note:
- if you train the normal model, debug flags in `visual.*` or `model.oracle_cfg.*` will not activate debug-only code paths
create_dataset.py:
- root config: `src/configs/create_dataset.yaml`
- config groups used for dataset creation:
  - `data=...`: where the generated dataset is written and which `.dvc` file is updated
  - `dataset_creation=...`: how the dataset is generated from raw Waymo data for each split

Inside `dataset_creation=...`, each split (train, val, test) controls:
- `raw_data_url`: source TFRecord shard(s)
- `waymax_conf_version`: Waymo/Waymax dataset version
- `num_states`: how many scenes to process
- `max_num_objects`: scene filtering limit before preprocessing
- `extract_scene`: whether to extract scene data
- `preprocessing.*`: preprocessing parameters such as object cap, polyline limits, current index, point count, log transform, and history removal
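For orientation, one split in a `dataset_creation` config might look roughly like this; the top-level field names mirror the list above, but the values and the nested `preprocessing` key names are placeholders, not the repo's actual defaults:

```yaml
train:
  raw_data_url: <tfrecord_shard_pattern>   # source TFRecord shard(s)
  waymax_conf_version: <dataset_version>   # Waymo/Waymax dataset version
  num_states: 100                          # how many scenes to process
  max_num_objects: 16                      # scene filter before preprocessing
  extract_scene: true                      # whether to extract scene data
  preprocessing:                           # placeholder key names below
    object_cap: 8
    polyline_limit: 256
    current_index: 10
    num_points: 2048
    log_transform: true
    remove_history: false
```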
Rule of thumb:
- change `data=...` when you want a different saved dataset artifact
- change `dataset_creation=...` when you want different dataset contents
- change `model=...`, `trainer.*`, `dataloaders.*`, `metrics`, `logger`, or `visual` when you want different training behavior