Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions examples/mnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from torchsupport.training.vae import VAETraining
from torchsupport.training.vae import VAETraining, LoggingTypes, LoggerTypes

def normalize(image):
return (image - image.min()) / (image.max() - image.min())
Expand Down Expand Up @@ -59,11 +59,14 @@ def __init__(self, z=32):
def forward(self, sample):
return self.decoder(sample).view(-1, 1, 28, 28)


class MNISTVAETraining(VAETraining):
def run_networks(self, data, *args):
mean, logvar, reconstruction, data = super().run_networks(data, *args)
self.writer.add_image("target", normalize(data[0]), self.step_id)
self.writer.add_image("reconstruction", normalize(reconstruction[0].sigmoid()), self.step_id)
# self.writer.add_image("target", normalize(data[0]), self.step_id)
self.logger.log(LoggingTypes.IMAGE, "target", normalize(data[0]), self.step_id)
# self.writer.add_image("reconstruction", normalize(reconstruction[0].sigmoid()), self.step_id)
self.logger.log(LoggingTypes.IMAGE, "reconstruction", normalize(reconstruction[0].sigmoid()), self.step_id)
return mean, logvar, reconstruction, data

if __name__ == "__main__":
Expand Down
99 changes: 90 additions & 9 deletions torchsupport/training/vae.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
import numpy as np
from enum import Enum
import torch
import mlflow
from torch import nn
from torch.nn import functional as func
from torch.distributions import Normal, RelaxedOneHotCategorical
Expand All @@ -15,6 +17,68 @@
from torchsupport.data.io import netwrite, to_device
from torchsupport.data.collate import DataLoader


class LoggerTypes(Enum):
TENSORBOARD = "tensorboard"
MLFLOW = "mlflow"
NONE = "none"


class LoggingTypes(Enum):
PARAM = "param"
METRIC = "metric"
IMAGE = "image"
MODEL = "model"


class LoggingSystem(object):
def __init__(self):
pass

def log(self, type, key, content, step_id=None):
pass


class TensorboardLogger(LoggingSystem):
writer = None

def __init__(self, network_id):
self.writer = SummaryWriter(network_id)

def log(self, logging_type, key, content, step_id=None):
if logging_type == LoggingTypes.METRIC:
self.writer.add_scalar(key, content, step_id)
elif logging_type == LoggingTypes.IMAGE:
self.writer.add_image(key, content, step_id)


class MlflowLogger(LoggingSystem):
network_id = None
def __init__(self, network_id, mlflow_tracking_uri=None):
if mlflow_tracking_uri:
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment(network_id)
mlflow.start_run()
self.network_id = network_id

def log(self, logging_type, key, content, step_id=None):
if logging_type == LoggingTypes.METRIC:
assert type(key) is str, f"Logging a metric with MLflow requires a string key. Type is: ${type(key)}"
assert type(content) is float, f"Logging a metric with MLflow requires a float value. Type is: ${type(content)}"
if step_id is not None:
assert type(step_id) is int, f"Logging a metric with MLflow requires a integer step ID. Type is: ${type(step_id)}"
mlflow.log_metric(key, content, step_id)
elif logging_type == LoggingTypes.PARAM:
mlflow.log_param(key, content)
elif logging_type == LoggingTypes.IMAGE:
print("WARNING: Cannot log images with the torchsupport MLflow backend yet. Skipping.")
elif logging_type == LoggingTypes.MODEL:
import mlflow.pytorch
mlflow.pytorch.log_model(content, key)
else:
raise NotImplementedError(f"Invalid logging type: ${logging_type}")


class AbstractVAETraining(Training):
"""Abstract base class for VAE training."""
checkpoint_parameters = Training.checkpoint_parameters + [
Expand All @@ -32,6 +96,7 @@ def __init__(self, networks, data, valid=None,
network_name="network",
report_interval=10,
checkpoint_interval=1000,
logger_type=LoggerTypes.TENSORBOARD,
verbose=False):
"""Generic training setup for variational autoencoders.

Expand Down Expand Up @@ -73,7 +138,17 @@ def __init__(self, networks, data, valid=None,
self.current_losses = {}
self.network_name = network_name
self.writer = SummaryWriter(network_name)

if logger_type == LoggerTypes.TENSORBOARD:
self.logger = TensorboardLogger(network_name)
elif logger_type == LoggerTypes.MLFLOW:
self.logger = MlflowLogger(network_name)
elif logger_type == LoggerTypes.NONE:
self.logger = LoggingSystem()
else:
raise ValueError("Logger must be one of the types listed in LoggingTypes.")
self.logger.log(LoggingTypes.PARAM, "max_epochs", self.max_epochs)
self.logger.log(LoggingTypes.PARAM, "batch_size", self.batch_size)
self.logger.log(LoggingTypes.PARAM, "device", self.device)
self.epoch_id = 0
self.step_id = 0

Expand Down Expand Up @@ -127,8 +202,10 @@ def step(self, data):
if self.verbose:
for loss_name in self.current_losses:
loss_float = self.current_losses[loss_name]
self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
self.writer.add_scalar("total loss", float(loss_val), self.step_id)
self.logger.log(LoggingTypes.METRIC, f"{loss_name} loss", loss_float, self.step_id)
# self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
self.logger.log(LoggingTypes.METRIC, "total loss", float(loss_val), self.step_id)
# self.writer.add_scalar("total loss", float(loss_val), self.step_id)

loss_val.backward()
self.optimizer.step()
Expand Down Expand Up @@ -166,6 +243,7 @@ def checkpoint(self):
the_net = getattr(self, name)
if isinstance(the_net, torch.nn.DataParallel):
the_net = the_net.module
self.logger.log(LoggingTypes.MODEL, f"{name}-epoch-{self.epoch_id}-step-{self.step_id}", the_net)
netwrite(
the_net,
f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
Expand All @@ -174,7 +252,8 @@ def checkpoint(self):

def validate(self, data):
loss = self.valid_step(data)
self.writer.add_scalar("valid loss", loss, self.step_id)
# self.writer.add_scalar("valid loss", loss, self.step_id)
self.logger.log(LoggingTypes.METRIC, "valid loss", loss, self.step_id)
self.each_validate()

def train(self):
Expand Down Expand Up @@ -479,20 +558,22 @@ def step(self, data):
pass_through, *critic_args = self.run_critic(data, *netargs)
loss_val = self.critic_loss(*critic_args)
loss_val.backward(retain_graph=True)
self.writer.add_scalar("critic loss", float(loss_val), self.step_id)
self.logger.log(LoggingTypes.METRIC, "critic loss", float(loss_val), self.step_id)
# self.writer.add_scalar("critic loss", float(loss_val), self.step_id)
self.critic_optimizer.step()

self.generator_optimizer.zero_grad()
generator_args = self.run_generator(data, *pass_through, *netargs)
loss_val = self.generator_loss(*generator_args)
loss_val.backward()
self.writer.add_scalar("generator loss", float(loss_val), self.step_id)
self.logger.log(LoggingTypes.METRIC, "generator loss", float(loss_val), self.step_id)
self.generator_optimizer.step()

if self.verbose:
for loss_name in self.current_losses:
loss_float = self.current_losses[loss_name]
self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
self.logger.log(LoggingTypes.METRIC, f"{loss_name} loss", loss_float, self.step_id)
# self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)

self.each_step()

Expand Down Expand Up @@ -636,7 +717,7 @@ def step(self, data):
ce = self.reconstruction_loss(reconstruction, data)
tc = self.divergence_loss((mean, logvar), (self.decoder, sample))
loss_val = ce + tc
self.writer.add_scalar("total loss", float(loss_val), self.step_id)
self.logger.log(LoggingTypes.METRIC, "total loss", float(loss_val), self.step_id)

self.optimizer.zero_grad()
loss_val.backward()
Expand All @@ -651,7 +732,7 @@ def step(self, data):
discriminator_loss.backward()
self.discriminator_optimizer.step()

self.writer.add_scalar("discriminator loss", float(discriminator_loss), self.step_id)
self.logger.log(LoggingTypes.METRIC, "discriminator loss", float(discriminator_loss), self.step_id)
self.each_step()

class ConditionalVAETraining(VAETraining):
Expand Down