diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index bb112da..2b0d162 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -6,7 +6,7 @@ from torchvision.datasets import MNIST from torchvision.transforms import ToTensor -from torchsupport.training.vae import VAETraining +from torchsupport.training.vae import VAETraining, LoggingTypes, LoggerTypes def normalize(image): return (image - image.min()) / (image.max() - image.min()) @@ -59,11 +59,14 @@ def __init__(self, z=32): def forward(self, sample): return self.decoder(sample).view(-1, 1, 28, 28) + class MNISTVAETraining(VAETraining): def run_networks(self, data, *args): mean, logvar, reconstruction, data = super().run_networks(data, *args) - self.writer.add_image("target", normalize(data[0]), self.step_id) - self.writer.add_image("reconstruction", normalize(reconstruction[0].sigmoid()), self.step_id) + # self.writer.add_image("target", normalize(data[0]), self.step_id) + self.logger.log(LoggingTypes.IMAGE, "target", normalize(data[0]), self.step_id) + # self.writer.add_image("reconstruction", normalize(reconstruction[0].sigmoid()), self.step_id) + self.logger.log(LoggingTypes.IMAGE, "reconstruction", normalize(reconstruction[0].sigmoid()), self.step_id) return mean, logvar, reconstruction, data if __name__ == "__main__": diff --git a/torchsupport/training/vae.py b/torchsupport/training/vae.py index e409b04..e4a5315 100644 --- a/torchsupport/training/vae.py +++ b/torchsupport/training/vae.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import numpy as np +from enum import Enum import torch +import mlflow from torch import nn from torch.nn import functional as func from torch.distributions import Normal, RelaxedOneHotCategorical @@ -15,6 +17,68 @@ from torchsupport.data.io import netwrite, to_device from torchsupport.data.collate import DataLoader + +class LoggerTypes(Enum): + TENSORBOARD = "tensorboard" + MLFLOW = "mlflow" + NONE = "none" + + +class LoggingTypes(Enum): + PARAM = "param" + METRIC = "metric" + IMAGE = "image" + MODEL = "model" + + +class LoggingSystem(object): + def __init__(self): + pass + + def log(self, type, key, content, step_id=None): + pass + + +class TensorboardLogger(LoggingSystem): + writer = None + + def __init__(self, network_id): + self.writer = SummaryWriter(network_id) + + def log(self, logging_type, key, content, step_id=None): + if logging_type == LoggingTypes.METRIC: + self.writer.add_scalar(key, content, step_id) + elif logging_type == LoggingTypes.IMAGE: + self.writer.add_image(key, content, step_id) + + +class MlflowLogger(LoggingSystem): + network_id = None + def __init__(self, network_id, mlflow_tracking_uri=None): + if mlflow_tracking_uri: + mlflow.set_tracking_uri(mlflow_tracking_uri) + mlflow.set_experiment(network_id) + mlflow.start_run() + self.network_id = network_id + + def log(self, logging_type, key, content, step_id=None): + if logging_type == LoggingTypes.METRIC: + assert type(key) is str, f"Logging a metric with MLflow requires a string key. Type is: ${type(key)}" + assert type(content) is float, f"Logging a metric with MLflow requires a float value. Type is: ${type(content)}" + if step_id is not None: + assert type(step_id) is int, f"Logging a metric with MLflow requires a integer step ID. Type is: ${type(step_id)}" + mlflow.log_metric(key, content, step_id) + elif logging_type == LoggingTypes.PARAM: + mlflow.log_param(key, content) + elif logging_type == LoggingTypes.IMAGE: + print("WARNING: Cannot log images with the torchsupport MLflow backend yet. Skipping.") + elif logging_type == LoggingTypes.MODEL: + import mlflow.pytorch + mlflow.pytorch.log_model(content, key) + else: + raise NotImplementedError(f"Invalid logging type: ${logging_type}") + + class AbstractVAETraining(Training): """Abstract base class for VAE training.""" checkpoint_parameters = Training.checkpoint_parameters + [ @@ -32,6 +96,7 @@ def __init__(self, networks, data, valid=None, network_name="network", report_interval=10, checkpoint_interval=1000, + logger_type=LoggerTypes.TENSORBOARD, verbose=False): """Generic training setup for variational autoencoders. @@ -73,7 +138,17 @@ def __init__(self, networks, data, valid=None, self.current_losses = {} self.network_name = network_name self.writer = SummaryWriter(network_name) - + if logger_type == LoggerTypes.TENSORBOARD: + self.logger = TensorboardLogger(network_name) + elif logger_type == LoggerTypes.MLFLOW: + self.logger = MlflowLogger(network_name) + elif logger_type == LoggerTypes.NONE: + self.logger = LoggingSystem() + else: + raise ValueError("Logger must be one of the types listed in LoggingTypes.") + self.logger.log(LoggingTypes.PARAM, "max_epochs", self.max_epochs) + self.logger.log(LoggingTypes.PARAM, "batch_size", self.batch_size) + self.logger.log(LoggingTypes.PARAM, "device", self.device) self.epoch_id = 0 self.step_id = 0 @@ -127,8 +202,10 @@ def step(self, data): if self.verbose: for loss_name in self.current_losses: loss_float = self.current_losses[loss_name] - self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id) - self.writer.add_scalar("total loss", float(loss_val), self.step_id) + self.logger.log(LoggingTypes.METRIC, f"{loss_name} loss", loss_float, self.step_id) + # self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id) + self.logger.log(LoggingTypes.METRIC, "total loss", float(loss_val), self.step_id) + # self.writer.add_scalar("total loss", float(loss_val), self.step_id) loss_val.backward() self.optimizer.step() @@ -166,6 +243,7 @@ def checkpoint(self): the_net = getattr(self, name) if isinstance(the_net, torch.nn.DataParallel): the_net = the_net.module + self.logger.log(LoggingTypes.MODEL, f"{name}-epoch-{self.epoch_id}-step-{self.step_id}", the_net) netwrite( the_net, f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch" @@ -174,7 +252,8 @@ def checkpoint(self): def validate(self, data): loss = self.valid_step(data) - self.writer.add_scalar("valid loss", loss, self.step_id) + # self.writer.add_scalar("valid loss", loss, self.step_id) + self.logger.log(LoggingTypes.METRIC, "valid loss", loss, self.step_id) self.each_validate() def train(self): @@ -479,20 +558,22 @@ def step(self, data): pass_through, *critic_args = self.run_critic(data, *netargs) loss_val = self.critic_loss(*critic_args) loss_val.backward(retain_graph=True) - self.writer.add_scalar("critic loss", float(loss_val), self.step_id) + self.logger.log(LoggingTypes.METRIC, "critic loss", float(loss_val), self.step_id) + # self.writer.add_scalar("critic loss", float(loss_val), self.step_id) self.critic_optimizer.step() self.generator_optimizer.zero_grad() generator_args = self.run_generator(data, *pass_through, *netargs) loss_val = self.generator_loss(*generator_args) loss_val.backward() - self.writer.add_scalar("generator loss", float(loss_val), self.step_id) + self.logger.log(LoggingTypes.METRIC, "generator loss", float(loss_val), self.step_id) self.generator_optimizer.step() if self.verbose: for loss_name in self.current_losses: loss_float = self.current_losses[loss_name] - self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id) + self.logger.log(LoggingTypes.METRIC, f"{loss_name} loss", loss_float, self.step_id) + # self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id) self.each_step() @@ -636,7 +717,7 @@ def step(self, data): ce = self.reconstruction_loss(reconstruction, data) tc = self.divergence_loss((mean, logvar), (self.decoder, sample)) loss_val = ce + tc - self.writer.add_scalar("total loss", float(loss_val), self.step_id) + self.logger.log(LoggingTypes.METRIC, "total loss", float(loss_val), self.step_id) self.optimizer.zero_grad() loss_val.backward() @@ -651,7 +732,7 @@ def step(self, data): discriminator_loss.backward() self.discriminator_optimizer.step() - self.writer.add_scalar("discriminator loss", float(discriminator_loss), self.step_id) + self.logger.log(LoggingTypes.METRIC, "discriminator loss", float(discriminator_loss), self.step_id) self.each_step() class ConditionalVAETraining(VAETraining):