Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions 2D_Classifier/README.md

This file was deleted.

10 changes: 0 additions & 10 deletions 2D_Classifier/train/requirements.txt

This file was deleted.

File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ WORKDIR /train
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && \
apt-get install -y --no-install-recommends build-essential git rsync software-properties-common ffmpeg libsm6 libxext6 && \
apt-get install -y --no-install-recommends build-essential git rsync ffmpeg libsm6 libxext6 && \
rm -rf /var/lib/apt/lists/*

ENV PYTHONPATH="/mlflow/projects/code/:$PYTHONPATH"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# 2D Classifier THIS NEEDS UPDATING

## Example Workflow:

#### 1. Adapt XNATDataImport.py for your data
Expand Down Expand Up @@ -58,3 +60,5 @@ This dockerfile sets up the Docker image that the MLOps run will utilise.

In the example this is just a simple environment running python version 3.10.
You will most likely need to adapt this for your project.


Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import sys
import configparser
import json
import logging
import os
import multiprocessing
import json
import os

import mlflow
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from ray.air.integrations.mlflow import setup_mlflow
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from torch.cuda import is_available as cuda_available

from src.DataModule import DataModule
from src.Network import Network
from src.DataModule import label_dict
from src.XNATDataImport import XNATDataImport
from src.data_import_xnat import DataImportXNAT
from src.datamodule import DataModule, label_dict
from src.network import Network

logger = logging.getLogger(__name__)

Expand All @@ -36,7 +33,7 @@ def train(config):
else multiprocessing.cpu_count()
)

importer = XNATDataImport(
importer = DataImportXNAT(
xnat_configuration = xnat_configuration,
num_workers = num_workers
)
Expand All @@ -48,14 +45,6 @@ def train(config):
data = importer.xnat_image_download(raw_data)

# Set up mflow experiment
setup_mlflow(
tracking_uri=mlflow.get_tracking_uri(),
experiment_id=mlflow.get_experiment_by_name(
config["project"]["name"]
).experiment_id
if mlflow.get_experiment_by_name(config["project"]["name"])
else mlflow.create_experiment(config["project"]["name"]),
)
with mlflow.start_run(nested=True):
save_best_model = True

Expand Down Expand Up @@ -93,6 +82,10 @@ def train(config):
label_smoothing = float(config['params']['label_smoothing']),
)

# Callbacks
checkpoint_metric = config['params']['checkpoint_metric']
checkpoint_mode = "min" if checkpoint_metric == "val_loss" else "max"

# Callbacks
callbacks = []
callbacks.append(LearningRateMonitor(logging_interval="step"))
Expand All @@ -104,6 +97,14 @@ def train(config):
)
callbacks.append(checkpoint_callback)

early_stopping_callback = EarlyStopping(
monitor=checkpoint_metric,
patience=10,
mode=checkpoint_mode,
verbose=True,
)
callbacks.append(early_stopping_callback)

# configure trainer
trainer = pl.Trainer(
precision="32" if cuda_available() else "16",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@

import configparser
import logging
import multiprocessing
import os

import mlflow
import optuna
import pytorch_lightning as pl
from ray.air.integrations.mlflow import setup_mlflow
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from torch.cuda import is_available as cuda_available

from project.DataModule import DataModule
from project.DataModule import label_dict
from project.Network import Network
from project.XNATDataImport import XNATDataImport
from src.data_import_xnat import DataImportXNAT
from src.datamodule import DataModule, label_dict
from src.network import Network

import optuna
logger = logging.getLogger(__name__)

# Obtain hyperparameters for this trial
Expand Down Expand Up @@ -105,6 +101,14 @@ def objective(trial,data,config):
)
callbacks.append(checkpoint_callback)

early_stopping_callback = EarlyStopping(
monitor="val_loss",
patience=10,
mode="min",
verbose=True,
)
callbacks.append(early_stopping_callback)

# configure trainer
trainer = pl.Trainer(
precision="32" if cuda_available() else "16",
Expand Down Expand Up @@ -147,7 +151,7 @@ def tune(config):
else multiprocessing.cpu_count()
)

importer = XNATDataImport(
importer = DataImportXNAT(
xnat_configuration = xnat_configuration,
num_workers = num_workers
)
Expand All @@ -158,20 +162,11 @@ def tune(config):
# Download images from XNAT
data = importer.xnat_image_download(raw_data)

# Set up mflow experiment
setup_mlflow(
tracking_uri=mlflow.get_tracking_uri(),
experiment_id=mlflow.get_experiment_by_name(
config["project"]["name"]
).experiment_id
if mlflow.get_experiment_by_name(config["project"]["name"])
else mlflow.create_experiment(config["project"]["name"]),
)

mlflow.pytorch.autolog(log_models=False)

# Create optuna study (hyperparameter tuning framework)
study = optuna.create_study(study_name="scaphx-tune", direction="minimize")
study = optuna.create_study(study_name="project-tune", direction="minimize")
study.optimize(lambda trial: objective(trial, data, config), n_trials=50)

with open(('tune_log.txt'), 'w') as f:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from tqdm import tqdm
from typing import List

from utils.tools import DataBuilderXNAT
from monai.data import Dataset
from torch.utils.data import DataLoader
from xnat.mixin import ImageScanData, SubjectData

from src.transforms import load_xnat
from monai.data import Dataset
from torch.utils.data import DataLoader
from src.utils.tools import DataBuilderXNAT

logger = logging.getLogger(__name__)

class XNATDataImport():
class DataImportXNAT():

def __init__(self, xnat_configuration: dict = None, num_workers: int = 4, test_batch: int = 0,
n_month_data_window=9999, run_type: str='train'):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import logging
from collections import Counter
from typing import List, Optional

import mlflow
import numpy as np
import pytorch_lightning
import torch

from monai.data import CacheDataset, Dataset
from monai.data import pad_list_data_collate
from monai.data import CacheDataset, Dataset, pad_list_data_collate
from monai.transforms import Compose

from sklearn.model_selection import train_test_split
from torch.cuda import is_available
from torch.utils.data import DataLoader

from src.transforms import normalise, train_augment, output
from src.transforms.SafeWrapper import SafeWrapperTransform
from src.transforms import normalise, output, train_augment
from src.transforms.safe_wrapper import SafeWrapperTransform

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
import logging
import pytorch_lightning
from abc import ABC

import mlflow
import numpy as np
import pytorch_lightning
import torch
from abc import ABC
from monai.data import decollate_batch
from monai.transforms import (
AsDiscrete,
Compose,
Activations,
)
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, recall_score
from monai.transforms import Activations, AsDiscrete, Compose
from sklearn.metrics import ClassificationReport, ConfusionMatrixDisplay, confusion_matrix, recall_score
from timm import create_model
from timm.data import Mixup
from torch.nn import CrossEntropyLoss
from torchmetrics import Accuracy, F1Score
from src.DataModule import label_dict
from torchmetrics.classification import MulticlassAUROC
import numpy as np

logger = logging.getLogger(__name__)
from src.datamodule import label_dict

logger = logging.getLogger(__name__)

class Network(pytorch_lightning.LightningModule, ABC):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import torch
from monai.transforms import (
LoadImage,
SqueezeDimd,
EnsureChannelFirstd,
CropForegroundd,
Resized,
ScaleIntensityd,
CastToTyped,
RandFlipd,
RandZoomd,
RandRotated,
CropForegroundd,
EnsureChannelFirstd,
EnsureTyped,
LoadImage,
NormalizeIntensityd,
RandAdjustContrastd,
RandAffined,
RandCoarseDropoutd,
RandFlipd,
RandGaussianNoised,
RandGaussianSmoothd,
RandRotated,
RandScaleIntensityd,
RandAdjustContrastd,
RandCoarseDropoutd,
RandZoomd,
ResizeWithPadOrCropd,
ToTensord,
Resized,
ScaleIntensityd,
ScaleIntensityRangePercentilesd,
SelectItemsd,
EnsureTyped,
Spacingd,
SqueezeDimd,
ToTensord,
)

from src.transforms.LoadImageXNATd import LoadImageXNATd
from src.transforms.load_image_xnatd import LoadImageXNATd

def load_xnat(xnat_configuration: dict):
"""
Expand All @@ -47,9 +49,8 @@ def normalise(image_size):
EnsureChannelFirstd(keys=['image']),
CropForegroundd(keys=['image'], source_key='image'),
Resized(keys=['image'], size_mode='longest', spatial_size=image_size+20),
#Maybe limit top intensity in case of big spikes?
ScaleIntensityd(keys=["image"], minv=0.0, maxv=255.0),
CastToTyped(keys=["image"], dtype=torch.uint8),
ScaleIntensityRangePercentilesd(keys=["image"], lower=0, upper=99, b_min=0.0, b_max=255.0, clip=True),
CastToTyped(keys=["image"], dtype=torch.float32),
]

def train_augment(image_size):
Expand All @@ -59,20 +60,20 @@ def train_augment(image_size):
"""
return [
RandFlipd(keys=['image'], spatial_axis=0, prob=0.5),
RandZoomd(keys=['image'], prob=0.2, min_zoom=1.05,max_zoom=1.1),
RandRotated(keys=['image'], prob=0.2, range_x=0.4),
RandAffined(keys=['image'], prob=0.2, padding_mode='zeros'),
RandGaussianNoised(keys=['image'], prob=0.1, mean=0.0, std=0.1),
RandGaussianSmoothd(keys=['image'], prob=0.2, sigma_x=(0.5,1.0)),
RandScaleIntensityd(keys=['image'], prob=0.15, factors=(0.75,1.25)),
RandAdjustContrastd(keys=['image'], prob=0.1, gamma=(0.5,2), retain_stats=True, invert_image=True),
RandAdjustContrastd(keys=['image'], prob=0.3, gamma=(0.5,2), retain_stats=True, invert_image=False),
RandZoomd(keys=['image'], prob=0.4, min_zoom=1.05,max_zoom=1.1),
RandRotated(keys=['image'], prob=0.4, range_x=0.4),
RandAffined(keys=['image'], prob=0.3, padding_mode='zeros'),
RandGaussianNoised(keys=['image'], prob=0.3, mean=0.0, std=10.0),
RandGaussianSmoothd(keys=['image'], prob=0.35, sigma_x=(0.5,1.0), sigma_y=(0.5,1.0)),
RandScaleIntensityd(keys=['image'], prob=0.3, factors=(0.75,1.25)),
RandAdjustContrastd(keys=['image'], prob=0.2, gamma=(0.5,2), retain_stats=True, invert_image=True),
RandAdjustContrastd(keys=['image'], prob=0.4, gamma=(0.5,2), retain_stats=True, invert_image=False),
ResizeWithPadOrCropd(
keys=["image"],
spatial_size=(image_size,image_size),
mode='replicate'
),
RandCoarseDropoutd(keys=['image'], prob=0.5, fill_value=0, holes=8, max_holes=16, spatial_size=(10,10), max_spatial_size=(36,36)),
RandCoarseDropoutd(keys=['image'], prob=0.35, fill_value=0, holes=8, max_holes=16, spatial_size=(10,10), max_spatial_size=(15,15)),
]

def output(image_size):
Expand All @@ -86,6 +87,9 @@ def output(image_size):
mode='replicate'
),
ScaleIntensityd(keys=["image"], minv=0.0, maxv=1),
# Normalize with grayscale-averaged ImageNet stats (mean=0.449, std=0.226)
# Required for pretrained ImageNet models
NormalizeIntensityd(keys=["image"], subtrahend=0.449, divisor=0.226),
ToTensord(keys=['image', 'label']),
SelectItemsd(keys=['subject_id', 'image', 'label']),
EnsureTyped(keys=['image', 'label'], track_meta=False),
Expand Down
Loading
Loading