Implement Skill Generator Training Loop#10

Merged
tarikuamisganaw merged 1 commit into main from feature/Skill-Generator-Training-Loop
May 5, 2026
Conversation

@kirubel-Nigussie
Collaborator

Description

Implements the training pipeline for the 2-head MLP SkillGenerator network. The trained model predicts skill outcomes (payoff, motives) from environment observations in milliseconds, enabling fast skill evaluation without running full environment episodes — critical for efficient CDS/PDS certification at scale.
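For concreteness, a 2-head MLP of this shape might look like the sketch below: a shared trunk feeding separate payoff and motive heads. The layer sizes, activation, and class layout here are illustrative assumptions; the PR does not spell out the `SkillGenerator` architecture.

```python
import torch
import torch.nn as nn

class SkillGenerator(nn.Module):
    """Hypothetical 2-head MLP: shared trunk, separate payoff and motive heads.
    Hidden sizes and depth are illustrative, not taken from the PR."""

    def __init__(self, obs_dim: int, motive_dim: int, hidden: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.payoff_head = nn.Linear(hidden, 1)        # scalar payoff prediction
        self.motive_head = nn.Linear(hidden, motive_dim)  # motive-vector prediction

    def forward(self, obs: torch.Tensor):
        h = self.trunk(obs)
        return self.payoff_head(h), self.motive_head(h)
```

A single forward pass through a small MLP like this is what makes millisecond-scale skill evaluation plausible compared to running full environment episodes.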


Files Changed

| File | Change |
|------|--------|
| `generator/losses.py` | NEW — Combined MSE loss utility for the 2-head generator |
| `generator/train_generator.py` | NEW — Full training script (DataLoader, epoch loop, model save, loss plot) |
| `tests/test_generator_training.py` | NEW — 5 automated verification tests |
| `requirements.txt` | Added `matplotlib>=3.7.0` for loss plot generation |
| `.gitignore` | Excluded `data/raw/`, `models/`, `plots/` (generated artifacts) |

What Was Implemented

generator/losses.py

  • GeneratorLoss class with configurable payoff_weight and motive_weight
  • Combined loss: total = payoff_weight · MSE(pred_payoff, actual_payoff) + motive_weight · MSE(pred_motives, actual_motives)
  • breakdown() method returning individual loss components for logging
  • Designed for future extension (regularization, dynamic weighting)
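The bullets above can be sketched as follows. Method signatures and the breakdown dictionary keys are assumptions for illustration; the actual class lives in `generator/losses.py`.

```python
import torch
import torch.nn as nn

class GeneratorLoss(nn.Module):
    """Sketch of a weighted two-head MSE loss with a breakdown() for logging."""

    def __init__(self, payoff_weight: float = 1.0, motive_weight: float = 1.0):
        super().__init__()
        self.payoff_weight = payoff_weight
        self.motive_weight = motive_weight
        self.mse = nn.MSELoss()

    def forward(self, pred_payoff, actual_payoff, pred_motives, actual_motives):
        # Weighted sum of the per-head MSE terms
        payoff_loss = self.mse(pred_payoff, actual_payoff)
        motive_loss = self.mse(pred_motives, actual_motives)
        return self.payoff_weight * payoff_loss + self.motive_weight * motive_loss

    def breakdown(self, pred_payoff, actual_payoff, pred_motives, actual_motives):
        # Individual (unweighted) components, detached for logging
        return {
            "payoff": self.mse(pred_payoff, actual_payoff).item(),
            "motive": self.mse(pred_motives, actual_motives).item(),
        }
```

Keeping the weights as constructor arguments is what leaves room for the dynamic-weighting extension mentioned above.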

generator/train_generator.py

  • SkillDataset — PyTorch Dataset that loads all .npz files from data/raw/
  • train_one_epoch() — single epoch training function
  • train() — full 50-epoch loop with Adam optimizer (lr=1e-3, batch_size=32)
  • Epoch-by-epoch console output (total, payoff, and motive losses)
  • Saves trained weights → models/generator.pt
  • Saves training curve → plots/generator_training.png
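A minimal sketch of the dataset and per-epoch loop described above. The `.npz` key names (`obs`, `payoff`, `motives`) and the loss-function signature are assumptions; the real implementation is in `generator/train_generator.py`.

```python
import glob
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SkillDataset(Dataset):
    """Loads (observation, payoff, motives) triples from every .npz in a directory.
    Array key names are assumed, not confirmed by the PR."""

    def __init__(self, data_dir: str = "data/raw"):
        self.samples = []
        for path in sorted(glob.glob(os.path.join(data_dir, "*.npz"))):
            with np.load(path) as f:
                self.samples.append((
                    f["obs"].astype(np.float32),
                    f["payoff"].astype(np.float32),
                    f["motives"].astype(np.float32),
                ))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        obs, payoff, motives = self.samples[i]
        return tuple(torch.from_numpy(a) for a in (obs, payoff, motives))

def train_one_epoch(model, loader, loss_fn, optimizer):
    """One pass over the loader; returns mean per-sample loss."""
    model.train()
    running = 0.0
    for obs, payoff, motives in loader:
        optimizer.zero_grad()
        pred_payoff, pred_motives = model(obs)
        loss = loss_fn(pred_payoff, payoff, pred_motives, motives)
        loss.backward()
        optimizer.step()
        running += loss.item() * obs.size(0)
    return running / len(loader.dataset)
```

The full `train()` would wrap this in a 50-epoch loop with `torch.optim.Adam(model.parameters(), lr=1e-3)` and `batch_size=32`, then save the weights and the matplotlib loss curve.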

tests/test_generator_training.py

  • test_training_runs_without_error — full pipeline completes without raising
  • test_loss_decreases_over_epochs — final loss < initial loss
  • test_trained_model_beats_random — post-training MSE < untrained MSE on held-out data
  • test_model_saves_and_loads — save/load round-trip produces identical predictions
  • test_loss_plot_is_generated — plot file exists and is non-empty
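The save/load round-trip check is the least obvious of these tests, so here is one way it could be phrased, using an in-memory buffer instead of a file. This is a sketch in the spirit of `test_model_saves_and_loads`, not the PR's actual test code.

```python
import io
import torch
import torch.nn as nn

def predictions_match_after_roundtrip(model: nn.Module, fresh_model: nn.Module,
                                      example_input: torch.Tensor) -> bool:
    """Save model's state_dict, load it into a fresh instance of the same
    architecture, and verify the two produce identical predictions."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    buf.seek(0)
    fresh_model.load_state_dict(torch.load(buf))
    model.eval()
    fresh_model.eval()
    with torch.no_grad():
        return torch.equal(model(example_input), fresh_model(example_input))
```

Comparing with `torch.equal` (exact elementwise equality) is reasonable here because the same weights and the same deterministic forward pass should produce bitwise-identical outputs on the same device.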

@tarikuamisganaw tarikuamisganaw merged commit b666916 into main May 5, 2026
1 check passed