Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 144 additions & 165 deletions MAEs/PMAE_Eric_Reinhardt/train.py
Original file line number Diff line number Diff line change
@@ -1,192 +1,171 @@
import torch
from validate import validate
from models.masks import ParticleMask, SpecificParticleMask, KinematicMask
from argparse import ArgumentParser
import os
import logging
from tqdm import tqdm
from typing import List, Optional

# ------------------------
# 1. Setup Logging
# ------------------------
# Module-level logger for this training script.  basicConfig is a no-op if the
# embedding application has already configured the root logger, so this is safe
# to import from larger programs.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# ------------------------
# 2. Helper Functions
# ------------------------
def get_mask_layer(mask_type: Optional[int], output_vars: int, particle_idx: Optional[int] = None):
    """Build and return the mask layer selected by *mask_type*.

    Returns None when masking is disabled (mask_type is None).
    mask_type == 0 selects particle-level masking: a SpecificParticleMask
    when *particle_idx* is given, otherwise a ParticleMask.  Any other
    value selects a KinematicMask of that type.
    """
    if mask_type is None:
        return None
    # Padded feature width passed to the particle masks
    # (e.g. output_vars=4 -> 5); presumably matches the model's
    # per-particle feature layout — TODO confirm against models/masks.py.
    feature_dim = output_vars + output_vars % 3
    if mask_type != 0:
        return KinematicMask(mask_type)
    if particle_idx is None:
        return ParticleMask(feature_dim)
    return SpecificParticleMask(feature_dim, particle_idx)


def apply_trivial_resets(outputs: torch.Tensor, masked_inputs: torch.Tensor) -> torch.Tensor:
    """Normalise the class-pair columns of *outputs* and pin sentinel slots.

    Columns 3:5 are converted to a probability pair with a softmax over the
    last axis.  Slots whose masked input holds the 999 sentinel in column 3
    (apparently zero-padding — TODO confirm) are forced to the trivial
    distribution (1, 0).  Mutates *outputs* in place and returns it.
    """
    sentinel = masked_inputs[:, :, 3] == 999
    outputs[:, :, 3:5] = torch.softmax(outputs[:, :, 3:5], dim=2)
    outputs[:, :, 3] = torch.where(sentinel, torch.ones_like(outputs[:, :, 3]), outputs[:, :, 3])
    outputs[:, :, 4] = torch.where(sentinel, torch.zeros_like(outputs[:, :, 4]), outputs[:, :, 4])
    return outputs


# ------------------------
# 3. Main Training Function
# ------------------------
def train(
    train_loader,
    val_loader,
    models: List[torch.nn.Module],
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    criterion,
    model_type: str,
    output_vars: int,
    zero_padded: Optional[List[int]] = None,
    mask: Optional[int] = None,
    epochs: range = range(1),
    loss_min: float = 999.0,
    save_path: str = './saved_models',
    model_name: str = ''
) -> float:
    """Train an autoencoder or a classifier built on top of it.

    model_type selects the training regime:
      * 'autoencoder'        — train `models[0]` to reconstruct (masked) inputs.
      * 'classifier partial' — freeze the autoencoder, train `models[1]` on
        the autoencoder output concatenated with the masked input.
      * 'classifier full'    — freeze the autoencoder, reconstruct each of the
        6 particles in turn, train `models[1]` on the assembled output
        concatenated with the raw input.

    Args:
        train_loader / val_loader: iterables yielding (inputs, labels) batches.
        models: [autoencoder] or [autoencoder, classifier].
        device: device the batches are moved to.
        optimizer: optimizer over the trainable model's parameters.
        criterion: reconstruction loss object with `.compute_loss(...)` for the
            autoencoder regime, or a plain callable loss for the classifiers.
        output_vars: number of per-particle output variables (3 or 4).
        zero_padded: indices forwarded to `criterion.compute_loss`; defaults to
            an empty list (mutable-default fixed: None sentinel, same behavior).
        mask: mask type passed to `get_mask_layer`; None disables masking.
        epochs: range of epoch indices to run.
        loss_min: best validation loss seen so far (checkpoint threshold).
        save_path / model_name: forwarded to `validate` for checkpointing.

    Returns:
        The minimum validation loss after training (0 if `epochs` is empty).
    """
    if zero_padded is None:
        zero_padded = []

    # Create an outputs folder to store config files
    os.makedirs(f'./outputs/{model_name}', exist_ok=True)

    if len(epochs) <= 0:
        logger.error("Number of epochs must be greater than 0")
        return 0

    # Determine model handles
    if 'classifier' in model_type:
        tae, classifier = models[0], models[1]
    else:
        tae = models[0]

    # ------------------------
    # Main Epoch Loop
    # ------------------------
    for epoch in epochs:
        running_loss = 0.0

        # Set modes: only the model being optimized is in train() mode
        if model_type == 'autoencoder':
            tae.train()
        elif 'classifier' in model_type:
            tae.eval()
            classifier.train()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs[-1]+1}")

        for batch in pbar:
            # Safe unpacking: handles datasets without labels
            if model_type == 'autoencoder':
                inputs = batch[0].to(device)
                labels = None
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)

            optimizer.zero_grad()

            # ------------------------
            # Model Logic
            # ------------------------
            if model_type == 'autoencoder':
                mask_layer = get_mask_layer(mask, output_vars)
                masked_inputs = mask_layer(inputs) if mask_layer else inputs

                outputs = tae(masked_inputs)
                outputs = outputs.flatten(1)

                # Prepare targets: with 3 output vars the last input column
                # is dropped before comparison.
                targets = inputs[:, :, :-1] if output_vars == 3 else inputs
                targets = targets.flatten(1)

                # NOTE(review): the pre-refactor code hard-coded
                # zero_padded=[4] when output_vars == 3; passing the caller's
                # zero_padded through changes that — confirm it is intended.
                loss = criterion.compute_loss(outputs, targets, zero_padded=zero_padded)

            elif model_type == 'classifier partial':
                mask_layer = get_mask_layer(mask, output_vars)
                masked_inputs = mask_layer(inputs) if mask_layer else inputs

                # NOTE(review): the pre-refactor code also softmaxed/reset the
                # masked inputs before feeding the classifier; here only the
                # autoencoder outputs are reset — confirm this is intended.
                outputs = apply_trivial_resets(tae(masked_inputs), masked_inputs).flatten(1)
                flat_masked = masked_inputs.flatten(1)

                preds = classifier(torch.cat((outputs, flat_masked), dim=1)).squeeze(1)
                loss = criterion(preds, labels.float())

            elif model_type == 'classifier full':
                batch_size = inputs.size(0)
                dim = output_vars + (output_vars % 3)
                outputs = torch.zeros(batch_size, 6, dim).to(device)

                # Reconstruct each of the 6 particles with its own mask and
                # keep only that particle's row from the autoencoder output.
                for i in range(6):
                    mask_layer = get_mask_layer(mask, output_vars, particle_idx=i)
                    masked_inputs = mask_layer(inputs) if mask_layer else inputs
                    temp_outputs = tae(masked_inputs)
                    outputs[:, i, :] = temp_outputs[:, i, :]

                outputs = apply_trivial_resets(outputs, inputs).flatten(1)
                flat_inputs = inputs.flatten(1)

                preds = classifier(torch.cat((outputs, flat_inputs), dim=1)).squeeze(1)
                loss = criterion(preds, labels.float())

            else:
                raise ValueError(f"Unknown model_type: {model_type}")

            # ------------------------
            # Backward Pass
            # ------------------------
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        # ------------------------
        # Validation (once per epoch; validate() checkpoints on improvement)
        # ------------------------
        loss_min = validate(
            val_loader,
            models,
            device,
            criterion,
            model_type,
            output_vars,
            mask,
            epoch,
            epochs[-1] + 1,
            loss_min,
            save_path,
            model_name
        )

    return loss_min