diff --git a/MAEs/PMAE_Eric_Reinhardt/train.py b/MAEs/PMAE_Eric_Reinhardt/train.py
index e09a8b1..2a9270c 100644
--- a/MAEs/PMAE_Eric_Reinhardt/train.py
+++ b/MAEs/PMAE_Eric_Reinhardt/train.py
@@ -1,192 +1,171 @@
 import torch
 from validate import validate
 from models.masks import ParticleMask, SpecificParticleMask, KinematicMask
-from argparse import ArgumentParser
 import os
+import logging
+from tqdm import tqdm
+from typing import List, Optional
+
+# ------------------------
+# 1. Setup Logging
+# ------------------------
+logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+logger = logging.getLogger(__name__)
+
+# ------------------------
+# 2. Helper Functions
+# ------------------------
+def get_mask_layer(mask_type: Optional[int], output_vars: int, particle_idx: Optional[int] = None):
+    """Return the appropriate mask layer."""
+    if mask_type is None:
+        return None
+    dim = output_vars + (output_vars % 3)
+    if mask_type == 0:
+        if particle_idx is not None:
+            return SpecificParticleMask(dim, particle_idx)
+        return ParticleMask(dim)
+    return KinematicMask(mask_type)
+
+
+def apply_trivial_resets(outputs: torch.Tensor, masked_inputs: torch.Tensor) -> torch.Tensor:
+    """Apply softmax and reset trivial values for physics."""
+    mask_999 = (masked_inputs[:, :, 3] == 999).float()
+    outputs[:, :, 3:5] = torch.nn.functional.softmax(outputs[:, :, 3:5], dim=2)
+    outputs[:, :, 3] = (1 - mask_999) * outputs[:, :, 3] + mask_999 * 1
+    outputs[:, :, 4] = (1 - mask_999) * outputs[:, :, 4]
+    return outputs
+
+
+# ------------------------
+# 3. Main Training Function
+# ------------------------
+def train(
+    train_loader,
+    val_loader,
+    models: List[torch.nn.Module],
+    device: torch.device,
+    optimizer: torch.optim.Optimizer,
+    criterion,
+    model_type: str,
+    output_vars: int,
+    zero_padded: List[int] = [],
+    mask: Optional[int] = None,
+    epochs: range = range(1),
+    loss_min: float = 999.0,
+    save_path: str = './saved_models',
+    model_name: str = ''
+) -> float:
+
+    os.makedirs(f'./outputs/{model_name}', exist_ok=True)
 
-def train(train_loader, val_loader, models, device, optimizer, criterion, model_type, output_vars, zero_padded=[], mask=None, epochs:range=None, loss_min:int=999, save_path:str='./saved_models', model_name:str=''):
-    # Create an outputs folder to store config files
-    os.makedirs('./outputs/' + model_name, exist_ok=True)
     if len(epochs) <= 0:
-        print("Num epochs <= 0")
+        logger.error("Number of epochs must be greater than 0")
         return 0
-    if model_type == 'autoencoder':
-        tae = models[0]
-        for epoch in epochs:
-            tae.train()
-            running_loss = 0.0
-            for batch_idx, batch in enumerate(train_loader):
-                # Move the data to the device
-                inputs, _ = batch
-                inputs = inputs.to(device)
-                if mask is not None:
-                    if mask == 0:
-                        mask_layer = ParticleMask(output_vars+(output_vars%3))
-                    else:
-                        mask_layer = KinematicMask(mask)
-                    # Mask input data
-                    masked_inputs = mask_layer(inputs)
-
-                # Zero the gradients
-                optimizer.zero_grad()
-
-                # Forward pass
-                outputs = tae(masked_inputs)
-
-                outputs = torch.reshape(outputs, (outputs.size(0),
-                                        outputs.size(1) * outputs.size(2)))
-
-                # Flatten last axes and compute loss
-                if output_vars == 3:
-                    inputs = inputs[:,:,:-1]
-                    inputs = torch.reshape(inputs, (inputs.size(0),
-                                           inputs.size(1) * inputs.size(2)))
-                    loss = criterion.compute_loss(outputs, inputs, zero_padded=[4])
-                elif output_vars == 4:
-                    inputs = torch.reshape(inputs, (inputs.size(0),
-                                           inputs.size(1) * inputs.size(2)))
-                    loss = criterion.compute_loss(outputs, inputs,
-                                                  zero_padded=zero_padded)
-
-                # Backward pass
-                loss.backward()
-
-                # Update the parameters
-                optimizer.step()
-
-                # Update running loss
-                running_loss += loss.item()
-                # Print running loss every 500 batches
-                if (batch_idx + 1) % 500 == 0:
-                    print(f"Epoch [{epoch+1}/{epochs[-1] + 1}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss / 500:.4f}")
-                    running_loss = 0.0
+    # Determine model handles
+    if 'classifier' in model_type:
+        tae, classifier = models[0], models[1]
+    else:
+        tae = models[0]
 
-            loss_min = validate(val_loader, models, device, criterion, model_type, output_vars, mask, epoch, epochs[-1] + 1, loss_min, save_path, model_name)
-        return loss_min
+    # ------------------------
+    # Main Epoch Loop
+    # ------------------------
+    for epoch in epochs:
+        running_loss = 0.0
 
-    elif model_type == 'classifier partial':
-        tae, classifier = models[0], models[1]
-        for epoch in epochs:
+        # Set modes
+        if model_type == 'autoencoder':
+            tae.train()
+        elif 'classifier' in model_type:
             tae.eval()
             classifier.train()
-            running_loss = 0.0
-            for batch_idx, batch in enumerate(train_loader):
-                # Move the data to the device
-                inputs, labels = batch
-                inputs = inputs.to(device)
-                labels = labels.to(device)
-                if mask is not None:
-                    if mask == 0:
-                        mask_layer = ParticleMask(output_vars+(output_vars%3))
-                    else:
-                        mask_layer = KinematicMask(mask)
-                    # Mask input data
-                    masked_inputs = mask_layer(inputs)
-
-                # Forward pass for autoencoder
-                outputs = tae(masked_inputs)
-
-                # Reset trivial values
-                mask_999 = (masked_inputs[:, :, 3] == 999).float()
-                outputs[:,:,3:5] = torch.nn.functional.softmax(outputs[:,:,3:5], dim=2)
-                outputs[:, :, 3] = (1 - mask_999) * outputs[:, :, 3] + mask_999 * 1
-                outputs[:, :, 4] = (1 - mask_999) * outputs[:, :, 4]
-                masked_inputs[:,:,3:5] = torch.nn.functional.softmax(masked_inputs[:,:,3:5], dim=2)
-                masked_inputs[:, :, 3] = (1 - mask_999) * masked_inputs[:, :, 3] + mask_999 * 1
-                masked_inputs[:, :, 4] = (1 - mask_999) * masked_inputs[:, :, 4]
-                # Flatten last axis
-                outputs = torch.reshape(outputs, (outputs.size(0),
-                                        outputs.size(1) * outputs.size(2)))
+        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs[-1]+1}")
 
-                masked_inputs = torch.reshape(masked_inputs, (masked_inputs.size(0),
-                                              masked_inputs.size(1) * masked_inputs.size(2)))
-
-                # Zero the gradients
-                optimizer.zero_grad()
-
-                # Forward pass for classifier
-                outputs_2 = classifier(torch.cat((outputs, masked_inputs), axis=1)).squeeze(1)
-
-                # Caclulate the loss
-                loss = criterion(outputs_2, labels.float())
-
-                # Backward pass
-                loss.backward()
-
-                # Update the parameters
-                optimizer.step()
-
-                # Update running loss
-                running_loss += loss.item()
-
-                # Print running loss every 500 batches
-                if (batch_idx + 1) % 500 == 0:
-                    print(f"Epoch [{epoch+1}/{epochs[-1] + 1}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss / 500:.4f}")
-                    running_loss = 0.0
-
-            loss_min = validate(val_loader, models, device, criterion, model_type, output_vars, mask, epoch, epochs[-1] + 1, loss_min, save_path, model_name)
-        return loss_min
-
-    elif model_type == 'classifier full':
-        tae, classifier = models[0], models[1]
-        for epoch in epochs:
-            tae.eval()
-            classifier.train()
-            running_loss = 0.0
-            for batch_idx, batch in enumerate(train_loader):
-                # Move the data to the device
+        for batch in pbar:
+            # Safe unpacking: handles datasets without labels
+            if model_type == 'autoencoder':
+                inputs = batch[0].to(device)
+                labels = None
+            else:
                 inputs, labels = batch
                 inputs = inputs.to(device)
                 labels = labels.to(device)
-                outputs = torch.zeros(inputs.size(0), 6, output_vars+(output_vars%3)).to(device)
-                for i in range(6):
-                    if mask is not None:
-                        if mask == 0:
-                            mask_layer = SpecificParticleMask(output_vars+(output_vars%3), i)
-                        else:
-                            mask_layer = KinematicMask(mask)
-                        # Mask input data
-                        masked_inputs = mask_layer(inputs)
-                    # Forward pass for autoencoder
-                    temp_outputs = tae(masked_inputs)
-                    outputs[:,i,:] = temp_outputs[:,i,:]
-                # Reset trivial values
-                mask_999 = (masked_inputs[:, :, 3] == 999).float()
-                outputs[:,:,3:5] = torch.nn.functional.softmax(outputs[:,:,3:5], dim=2)
-                outputs[:, :, 3] = (1 - mask_999) * outputs[:, :, 3] + mask_999 * 1
-                outputs[:, :, 4] = (1 - mask_999) * outputs[:, :, 4]
+            optimizer.zero_grad()
 
-                # Flatten last axes of tensors
-                outputs = torch.reshape(outputs, (outputs.size(0),
+            # ------------------------
+            # Model Logic
+            # ------------------------
+            if model_type == 'autoencoder':
+                mask_layer = get_mask_layer(mask, output_vars)
+                masked_inputs = mask_layer(inputs) if mask_layer else inputs
-                                        outputs.size(1) * outputs.size(2)))
+                outputs = tae(masked_inputs)
+                outputs = outputs.flatten(1)
 
-                inputs = torch.reshape(inputs, (inputs.size(0),
-                                       inputs.size(1) * inputs.size(2)))
-
-                # Zero the gradients
-                optimizer.zero_grad()
+                # Prepare targets
+                targets = inputs[:, :, :-1] if output_vars == 3 else inputs
+                targets = targets.flatten(1)
 
-                # Forward pass for classifier
-                outputs_2 = classifier(torch.cat((outputs, inputs), axis=1)).squeeze(1)
+                # Preserve old behavior: output_vars == 3 used zero_padded=[4]
+                loss = criterion.compute_loss(outputs, targets, zero_padded=[4] if output_vars == 3 else zero_padded)
 
-                # Caclulate the loss
-                loss = criterion(outputs_2, labels.float())
+            elif model_type == 'classifier partial':
+                mask_layer = get_mask_layer(mask, output_vars)
+                masked_inputs = mask_layer(inputs) if mask_layer else inputs
 
-                # Backward pass
-                loss.backward()
+                outputs = apply_trivial_resets(tae(masked_inputs), masked_inputs).flatten(1)
+                # Preserve old behavior: masked_inputs also get softmax/resets before concat
+                flat_masked = apply_trivial_resets(masked_inputs, masked_inputs).flatten(1)
 
-                # Update the parameters
-                optimizer.step()
+                preds = classifier(torch.cat((outputs, flat_masked), dim=1)).squeeze(1)
+                loss = criterion(preds, labels.float())
 
-                # Update running loss
-                running_loss += loss.item()
+            elif model_type == 'classifier full':
+                batch_size = inputs.size(0)
+                dim = output_vars + (output_vars % 3)
+                outputs = torch.zeros(batch_size, 6, dim).to(device)
 
-                # Print running loss every 500 batches
-                if (batch_idx + 1) % 500 == 0:
-                    print(f"Epoch [{epoch+1}/{epochs[-1] + 1}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss / 500:.4f}")
-                    running_loss = 0.0
+                for i in range(6):
+                    mask_layer = get_mask_layer(mask, output_vars, particle_idx=i)
+                    masked_inputs = mask_layer(inputs) if mask_layer else inputs
+                    temp_outputs = tae(masked_inputs)
+                    outputs[:, i, :] = temp_outputs[:, i, :]
+
+                outputs = apply_trivial_resets(outputs, inputs).flatten(1)
+                flat_inputs = inputs.flatten(1)
+
+                preds = classifier(torch.cat((outputs, flat_inputs), dim=1)).squeeze(1)
+                loss = criterion(preds, labels.float())
+
+            else:
+                raise ValueError(f"Unknown model_type: {model_type}")
+
+            # ------------------------
+            # Backward Pass
+            # ------------------------
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+        # ------------------------
+        # Validation
+        # ------------------------
+        loss_min = validate(
+            val_loader,
+            models,
+            device,
+            criterion,
+            model_type,
+            output_vars,
+            mask,
+            epoch,
+            epochs[-1] + 1,
+            loss_min,
+            save_path,
+            model_name
+        )
-
-            loss_min = validate(val_loader, models, device, criterion, model_type, output_vars, mask, epoch, epochs[-1] + 1, loss_min, save_path, model_name)
-        return loss_min
\ No newline at end of file
+
+    return loss_min
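
Reviewer note, not part of the patch: a minimal smoke test for the refactored train() signature, exercising the autoencoder path. The toy tensors, the DummyCriterion stub, and the throwaway linear autoencoder are assumptions for illustration; the repo wires in its own dataset, loss object, and transformer autoencoder, and validate() must tolerate the same stand-ins for the run to finish end to end.

# Hypothetical smoke test -- only train()'s signature comes from the diff above.
import torch
from torch.utils.data import DataLoader, TensorDataset

from train import train  # assumes cwd is MAEs/PMAE_Eric_Reinhardt


class DummyCriterion:
    # Minimal object exposing the compute_loss interface train() calls
    def compute_loss(self, outputs, targets, zero_padded=()):
        return torch.nn.functional.mse_loss(outputs, targets)


# Toy events: 6 particles x 5 features each, plus a dummy per-event label
events = torch.randn(256, 6, 5)
labels = torch.randint(0, 2, (256,))
loader = DataLoader(TensorDataset(events, labels), batch_size=32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Throwaway stand-in for the transformer autoencoder (same in/out shape)
tae = torch.nn.Sequential(
    torch.nn.Flatten(),             # (B, 6, 5) -> (B, 30)
    torch.nn.Linear(30, 30),
    torch.nn.Unflatten(1, (6, 5)),  # (B, 30) -> (B, 6, 5)
).to(device)

loss_min = train(
    train_loader=loader,
    val_loader=loader,
    models=[tae],
    device=device,
    optimizer=torch.optim.Adam(tae.parameters(), lr=1e-3),
    criterion=DummyCriterion(),
    model_type='autoencoder',
    output_vars=4,  # dim = 4 + (4 % 3) = 5 features per particle
    mask=None,      # exercises get_mask_layer's None path; real runs pass 0-3
    epochs=range(2),
    model_name='smoke_test',
)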