diff --git a/.gitignore b/.gitignore index 03e8789..1804d68 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,7 @@ Training/DeepSurv*.ckpt *.txt wandb/ Training/*.xml +Configs/Classification/*.ini +Configs/Regression/*.ini +*.png +Training/test*.py diff --git a/DataGenerator/.DataGenerator.py.swp b/DataGenerator/.DataGenerator.py.swp deleted file mode 100644 index e69de29..0000000 diff --git a/DataGenerator/DataGenerator.py b/DataGenerator/DataGenerator.py index 2988877..1f8e159 100644 --- a/DataGenerator/DataGenerator.py +++ b/DataGenerator/DataGenerator.py @@ -24,11 +24,10 @@ class DataGenerator(torch.utils.data.Dataset): - def __init__(self, SubjectList,config=None, keys=['CT'], transform=None, inference=False, + def __init__(self, SubjectList, config=None, keys=['CT'], transform=None, inference=False, clinical_cols=None, session=None, **kwargs): super().__init__() self.config = config - self.session = session self.SubjectList = SubjectList self.keys = keys self.transform = transform @@ -44,125 +43,75 @@ def __getitem__(self, i): meta = {} subject_id = self.SubjectList.loc[i, 'subjectid'] slabel = self.SubjectList.loc[i, 'subject_label'] + data['slabel'] = slabel ## Load CT if 'CT' in self.keys: CTPath = self.SubjectList.loc[i, 'CT_Path'] - if self.config['DATA']['Nifty']: - CTPath = Path(CTPath, 'ct.nii.gz') - data['CT'], meta['CT'] = LoadImage()(CTPath) - else: - data['CT'], meta['CT'] = LoadImage()(CTPath) - CTSession = ReadDicom(CTPath) - CTArray = sitk.GetArrayFromImage(CTSession) - if not(CTArray.shape == data['CT'].shape): - CTArray = CTArray.transpose((2, 1, 0)) - CTArray = np.flip(CTArray, axis=2) - mCT = MetaTensor(CTArray.copy(), meta=meta['CT']) - data['CT'] = mCT + CTPath = Path(CTPath, 'CT.nii.gz') + data['CT'], meta['CT'] = LoadImage()(CTPath) ## Load Dose if 'Dose' in self.keys: DosePath = self.SubjectList.loc[i, 'Dose_Path'] - if self.config['DATA']['Nifty']: - DosePath = Path(DosePath, 'dose.nii.gz') + DosePath = Path(DosePath, 'Dose.nii.gz') data['Dose'], meta['Dose'] = LoadImage()(DosePath) - data['Dose'] = data['Dose']/67 - if not self.config['DATA']['Nifty']: - data['Dose'] = data['Dose'] * np.double(meta['Dose']['3004|000e'])/67 + data['Dose'] = data['Dose'] / 67 ## Probably need to make it a variable ## Load PET if 'PET' in self.keys: PETPath = self.SubjectList.loc[i, 'PET_Path'] if self.config['DATA']['Nifty']: - PETPath = Path(PETPath, 'dose.nii.gz') + PETPath = Path(PETPath, 'pet.nii.gz') data['PET'], meta['PET'] = LoadImage()(PETPath) ## Load Mask if 'Structs' in self.keys: RSPath = self.SubjectList.loc[i, 'Structs_Path'] - if self.config['DATA']['Nifty']: - #for roi in self.config['DATA']['Structs']: - # data['Struct_' + roi], meta['Struct_' + roi] = LoadImage()(Path(RSPath,roi+'.nii.gz')) - # dt = distance_transform_edt(data['Struct_' + roi]) - # data['Struct_' + roi] = MetaTensor(dt, meta = meta['CT']) - masks_img = np.zeros_like(data['CT']) - masks_img = get_nii_masks(slabel, masks_img, RSPath, self.config['DATA']['Structs']) - masks_img = MetaTensor(masks_img.copy(), meta=meta['CT']) - data['Structs'] = masks_img - else: - ## mask in multichannel - RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSPath) - #roi_names = RS.get_roi_names() - #for roi in self.config['DATA']['Structs']: - # if roi in roi_names: - # mask_img = RS.get_roi_mask_by_name(roi) - # mask_img = distance_transform_edt(mask_img) - # else: - # message = "No ROI of name " + self.targetROI + " found in RTStruct" - # raise ValueError(message) - # mask_img = np.rot90(mask_img) - # mask_img = np.flip(mask_img, 2) - # mask_img = np.flip(mask_img, 0) - # mask = MetaTensor(mask_img.copy(), meta = meta['CT']) - # data['Struct_' + roi] = mask - - ### masks images - masks_img = np.zeros_like(data['CT']) - masks_img = get_RS_masks(slabel, CTPath, masks_img, RSPath, self.config['DATA']['Structs']) - masks_img = np.rot90(masks_img) - masks_img = np.flip(masks_img, 0) - masks_img = MetaTensor(masks_img.copy(), meta = meta['CT']) - data['Structs'] = masks_img - else: - data['Structs'] = np.ones_like(data['CT']) ## No ROI target defined + data['Structs'], meta['Structs'] = LoadImage()(Path(RSPath, self.config['DATA']['Structs'])) - ## Apply transforms on all if self.transform: data = self.transform(data) - # mask_imgs = np.zeros_like(CTArray) - # for key in data.keys(): - # if 'Mask' in key: - # mask_imgs = mask_imgs + data[key] - - #for key in data.keys(): - # data[key] = get_masked_img_voxel(data[key], data['Mask']) - # Decide between multi-branch single-channel/multi-channel single-branch if self.config['DATA']['Multichannel']: - old_keys = list(data.keys()) - data['Image'] = np.concatenate([data[key] for key in data.keys()], axis=0) + old_keys = list(self.keys) + data['Image'] = np.concatenate([data[key] for key in old_keys], axis=0) for key in old_keys: data.pop(key) else: - data.pop('Structs') ## No need for mask in single-channel multi-branch - - #data = ResizeWithPadOrCropd(keys=data.keys(), spatial_size=self.config['DATA']['dim'])(data) + if 'Structs' in data.keys(): + data.pop('Structs') ## No need for mask in single-channel multi-branch ## Add clinical record at the end if 'Records' in self.config.keys(): data['Records'] = torch.tensor(self.SubjectList.loc[i, self.clinical_cols], - dtype=torch.float32) + dtype=torch.float32) if self.inference: return data - else: - label = torch.tensor(np.float(self.SubjectList.loc[i, "xnat_subjectdata_field_map_" + self.config['DATA']['target']])) - if self.config['DATA']['threshold'] is not None: label = torch.where( - label > self.config['DATA']['threshold'], 1, 0) - label = torch.as_tensor(label, dtype=torch.float32) - return data, label + else: ##Training + label = torch.tensor( + np.float(self.SubjectList.loc[i, "xnat_subjectdata_field_map_" + self.config['DATA']['target']])) + censor_status = not (np.int8( + self.SubjectList.loc[i, 'xnat_subjectdata_field_map_' + self.config['DATA']['censor_label']]).astype( + 'bool')) + if 'threshold' in self.config['DATA'].keys(): ## Classification + label = torch.where(label > self.config['DATA']['threshold'], 1, 0) + label = torch.as_tensor(label, dtype=torch.float32) + return data, censor_status, label ### DataLoader class DataModule(LightningDataModule): - def __init__(self, SubjectList, config=None, train_transform=None, val_transform=None, train_size=0.7, - val_size=0.2, test_size=0.1, num_workers=10, **kwargs): + def __init__(self, SubjectList, config=None, train_transform=None, val_transform=None, train_size=0.85, + num_workers=10, **kwargs): super().__init__() self.batch_size = config['MODEL']['batch_size'] self.num_workers = num_workers data_trans = class_stratify(SubjectList, config) ## Split Test with fixed seed - train_val_list, test_list = train_test_split(SubjectList, test_size=0.15, random_state=42, stratify=data_trans) + train_val_list, test_list = train_test_split(SubjectList, train_size=train_size, random_state=42, + stratify=data_trans) data_trans = class_stratify(train_val_list, config) ## Split train-val with random seed - train_list, val_list = train_test_split(train_val_list, test_size=0.15, random_state=np.random.randint(10000), + train_list, val_list = train_test_split(train_val_list, train_size=train_size, + random_state=np.random.randint(10000), stratify=data_trans) train_list = train_list.reset_index(drop=True) @@ -170,7 +119,7 @@ def __init__(self, SubjectList, config=None, train_transform=None, val_transform test_list = test_list.reset_index(drop=True) self.train_data = DataGenerator(train_list, config=config, transform=train_transform, **kwargs) - self.val_data = DataGenerator(val_list,config=config, transform=val_transform, **kwargs) + self.val_data = DataGenerator(val_list, config=config, transform=val_transform, **kwargs) self.test_data = DataGenerator(test_list, config=config, transform=val_transform, **kwargs) def train_dataloader(self): return DataLoader(self.train_data, batch_size=self.batch_size, @@ -193,6 +142,11 @@ def QuerySubjectList(config, session): XML.Add_search_field( {"element_name": "xnat:subjectData", "field_ID": "XNAT_SUBJECTDATA_FIELD_MAP=" + str(config['DATA']['target']), "sequence": "1", "type": "int"}) + if 'censor_label' in config['DATA'].keys(): + XML.Add_search_field( + {"element_name": "xnat:subjectData", + "field_ID": "XNAT_SUBJECTDATA_FIELD_MAP=" + str(config['DATA']['censor_label']), + "sequence": "1", "type": "int"}) ## Label XML.Add_search_field( {"element_name": "xnat:subjectData", "field_ID": "SUBJECT_LABEL", "sequence": "1", "type": "string"}) @@ -237,6 +191,14 @@ def QuerySubjectList(config, session): return SubjectList +def class_stratify(SubjectList, config): + ptarget = SubjectList['xnat_subjectdata_field_map_' + config['DATA']['target']] + kbins = KBinsDiscretizer(n_bins=15, encode='ordinal', strategy='uniform') + ptarget = np.array(ptarget).reshape((len(ptarget), 1)) + data_trans = kbins.fit_transform(ptarget) + return data_trans + + def SynchronizeData(config, SubjectList): session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'], password=config['SERVER']['Password']) @@ -253,46 +215,15 @@ def get_subject_info(config, session, subjectid): return data -def QuerySubjectInfo(config, SubjectList, session): - if config['DATA']['Nifty']: - for i in range(len(SubjectList)): - subject_label = SubjectList.loc[i,'subject_label'] - for key in config['MODALITY'].keys(): +def QuerySubjectInfo(config, SubjectList): + for i in range(len(SubjectList)): + subject_label = SubjectList.loc[i, 'subject_label'] + for key in config['MODALITY'].keys(): + if key == 'Structs': + SubjectList.loc[i, key + '_Path'] = Path(config['DATA']['DataFolder'], subject_label, 'struct_TS') + else: SubjectList.loc[i, key + '_Path'] = Path(config['DATA']['DataFolder'], subject_label) - else: - with ThreadPoolExecutor(max_workers=10) as executor: - future_to_url = {executor.submit(get_subject_info, config, session, subjectid) for subjectid in - SubjectList['subjectid']} - executor.shutdown(wait=True) - for future in concurrent.futures.as_completed(future_to_url): - subjectdata = future.result() - subjectid = subjectdata["xnat:Subject"][0]["@ID"] - for key in config['MODALITY'].keys(): - path = GeneratePath(subjectdata, Modality=key, config=config) - if key == 'CT': - SubjectList.loc[SubjectList.subjectid == subjectid, key + '_Path'] = path - else: - spath = glob.glob(path + '/*dcm') - SubjectList.loc[SubjectList.subjectid == subjectid, key + '_Path'] = spath[0] - -def GeneratePath(subjectdata, Modality, config): - subject = subjectdata['xnat:Subject'][0] - subject_label = subject['@label'] - experiments = subject['xnat:experiments'][0]['xnat:experiment'] - - ## Won't work with many experiments yet - for experiment in experiments: - experiment_label = experiment['@label'] - scans = experiment['xnat:scans'][0]['xnat:scan'] - for scan in scans: - if (scan['@type'] in Modality): - scan_label = scan['@ID'] + '-' + scan['@type'] - resources_label = scan['xnat:file'][0]['@label'] - if resources_label == 'SNAPSHOTS': - resources_label = scan['xnat:file'][1]['@label'] - path = os.path.join(config['DATA']['DataFolder'], subject_label, experiment_label, 'scans', - scan_label, 'resources', resources_label, 'files') - return path + def LoadClinicalData(config, PatientList): category_cols = [] @@ -313,7 +244,7 @@ def LoadClinicalData(config, PatientList): yc = X[category_cols].astype('float32') X[category_cols] = yc.fillna(yc.mean().astype('int')) yn = X[numerical_cols].astype('float32') - X[numerical_cols] = yn.fillna(yn.mean()) #X.loc[:, numerical_cols] = yn.fillna(yn.mean()) + X[numerical_cols] = yn.fillna(yn.mean()) # X.loc[:, numerical_cols] = yn.fillna(yn.mean()) X_trans = ct.fit_transform(X) if not isinstance(X_trans, (np.ndarray, np.generic)): X_trans = X_trans.toarray() @@ -322,4 +253,8 @@ def LoadClinicalData(config, PatientList): df_trans['xnat_subjectdata_field_map_' + target] = PatientList.loc[:, 'xnat_subjectdata_field_map_' + target] df_trans['subject_label'] = PatientList.loc[:, 'subject_label'] df_trans['subjectid'] = PatientList.loc[:, 'subjectid'] + if 'censor_label' in config['DATA'].keys(): + df_trans['xnat_subjectdata_field_map_' + config['DATA']['censor_label']] = PatientList.loc[:, + 'xnat_subjectdata_field_map_' + + config['DATA']['censor_label']] return df_trans, clinical_col diff --git a/DefaultConfiguration.ini b/DefaultConfiguration.ini deleted file mode 100644 index 7f9a696..0000000 --- a/DefaultConfiguration.ini +++ /dev/null @@ -1,65 +0,0 @@ -[MODEL] -BaseModel = "AutoEncoder" -Model_Save_Path = "./" -batch_size = 2 -RANDOM_SEED = 42 -Loss_Function = "CrossEntropyLoss"#"MSELoss" -Activation = "Sigmoid" -Max_Epochs = 2 -Precision = 32 -Backbone = "densenet121" -Pretrained = true -Drop_Rate = 0.1 -wf = 4 -depth = 6 -activation = "Identity" -inference = true -emb_size = 1000 - -[MODEL_PARAMETERS] -spatial_dims = 3 -block_config = [1, 2, 4, 1] -in_channels = 1 -out_channels = 1 - -[Dose_MODEL_PARAMETERS] -in_channels = 1 -wf = 3 -depth = 3 - -[MODALITY] -CT = 1 -Dose = 1 - -[SERVER] -Address = 'http://128.16.11.124:8080/xnat' -Projects = ["RTOG_0617"] -User = "***" -Password = "***" - -[DATA] -DataFolder = "./Data" -n_per_sample = 5000 -n_classes = 2 -n_channel = 3 -sub_patch_size = 16 -dim = [100,256, 256] -vis = [0] -train_size = 0.7 -val_size = 0.3 -target = "survival_months" -threshold = 20 -#Mask = '' -Multichannel = false - -[CRITERIA] -#survival_status = 1 -#arm = 1 - -[CHECKPOINT] -monitor = "val_loss" #"val_acc_epoch" -mode = "max" -matrix = ['ROC', 'Specificity'] - -[FILTER] -#patient_id = ['0617-444138','0617-449451'] diff --git a/Models/AutoEncoder2D.py b/Models/AutoEncoder2D.py deleted file mode 100644 index 73c389a..0000000 --- a/Models/AutoEncoder2D.py +++ /dev/null @@ -1,87 +0,0 @@ -import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule, LightningModule -import numpy as np -import torch -from collections import Counter -import torchvision -from torchvision import datasets, models, transforms -from torchvision import transforms -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning import LightningDataModule, LightningModule, Trainer,seed_everything -from torchsummary import summary -import sys -import torchio as tio -import sklearn -from pytorch_lightning import loggers as pl_loggers -import torchmetrics - -## Module - Dataloaders -from Dataloader.Dataloader import DataModule, DataGenerator, LoadSortDataLabel - -## Model -class Classifier2D(LightningModule): - def __init__(self): - super().__init__() - self.n_classes = 1 - self.backbone = models.resnet50(pretrained=True) - self.model= torch.nn.Sequential( - self.unet_model, - torch.nn.LazyLinear(128), - torch.nn.LazyLinear(self.n_classes) - ) - summary(self.model.to('cuda'), (2,160,160,40)) - self.accuracy = torchmetrics.AUC(reorder=True) - self.loss_fcn = torch.nn.BCEWithLogitsLoss() - - def forward(self, x): - return self.model(x) - - def training_step(self, batch,batch_idx): - image,label = batch - prediction = self.forward(image) - loss = self.loss_fcn(prediction.squeeze(), label) - self.log("loss", loss) - return {"loss":loss,"prediction":prediction.squeeze(),"label":label} - - def validation_step(self, batch,batch_idx): - image,label = batch - prediction = self.forward(image) - loss = self.loss_fcn(prediction.squeeze(), label) - return {"loss":loss,"prediction":prediction.squeeze(),"label":label} - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) - return [optimizer], [scheduler] - -if __name__ == "__main__": - -## Main -train_transform = tio.Compose([ - tio.RandomAffine(), - # tio.RescaleIntensity(out_min_max=(0, 1)) -]) - -val_transform = tio.Compose([ - tio.RandomAffine(), - # tio.RescaleIntensity(out_min_max=(0, 1)) -]) -callbacks = [ - ModelCheckpoint( - dirpath='./', - monitor='val_loss', - filename="model_DeepSurv",#.{epoch:02d}-{val_loss:.2f}.h5", - save_top_k=1, - mode='min'), - EarlyStopping(monitor='val_loss') -] - -data_file = np.load(sys.argv[1]) -label_file = sys.argv[2] -label_name = sys.argv[3] - -data,label = LoadSortDataLabel(label_name, label_file, data_file) -trainer = Trainer(gpus=1, max_epochs=20)#,callbacks=callbacks) -model = DeepSurv() -dataloader = DataModule(data, label, train_transform = train_transform, val_transform = val_transform, batch_size=4, inference=False) -trainer.fit(model, dataloader) diff --git a/Models/AutoEncoder3D.py b/Models/AutoEncoder3D.py deleted file mode 100644 index 27f450e..0000000 --- a/Models/AutoEncoder3D.py +++ /dev/null @@ -1,91 +0,0 @@ -import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule, LightningModule -import numpy as np -import torch -from collections import Counter -import torchvision -from torchvision import datasets, models, transforms -from torchvision import transforms -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -from torchsummary import summary -import sys -import torchio as tio -import sklearn -from pytorch_lightning import loggers as pl_loggers -import torchmetrics - -## Module - Dataloaders -from Dataloader.Dataloader import DataModule, DataGenerator, LoadSortDataLabel - - -## Model -class Classifier2D(LightningModule): - def __init__(self): - super().__init__() - self.n_classes = 1 - self.backbone = models.resnet50(pretrained=True) - self.model = torch.nn.Sequential( - self.unet_model, - torch.nn.LazyLinear(128), - torch.nn.LazyLinear(self.n_classes) - ) - summary(self.model.to('cuda'), (2, 160, 160, 40)) - self.accuracy = torchmetrics.AUC(reorder=True) - self.loss_fcn = torch.nn.BCEWithLogitsLoss() - - def forward(self, x): - return self.model(x) - - def training_step(self, batch, batch_idx): - image, label = batch - prediction = self.forward(image) - loss = self.loss_fcn(prediction.squeeze(), label) - self.log("loss", loss) - return {"loss": loss, "prediction": prediction.squeeze(), "label": label} - - def validation_step(self, batch, batch_idx): - image, label = batch - prediction = self.forward(image) - loss = self.loss_fcn(prediction.squeeze(), label) - return {"loss": loss, "prediction": prediction.squeeze(), "label": label} - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) - return [optimizer], [scheduler] - - -if __name__ == "__main__": - -## Main -train_transform = tio.Compose([ - tio.RandomAffine(), - # tio.RescaleIntensity(out_min_max=(0, 1)) -]) - -val_transform = tio.Compose([ - tio.RandomAffine(), - # tio.RescaleIntensity(out_min_max=(0, 1)) -]) -callbacks = [ - ModelCheckpoint( - dirpath='./', - monitor='val_loss', - filename="model_DeepSurv", - # .{epoch:02d}-{val_loss:.2f}.h5", - save_top_k=1, - mode='min'), - EarlyStopping(monitor='val_loss') -] - -data_file = np.load(sys.argv[1]) -label_file = sys.argv[2] -label_name = sys.argv[3] - -data, label = LoadSortDataLabel(label_name, label_file, data_file) -trainer = Trainer(gpus=1, max_epochs=20) # ,callbacks=callbacks) -model = DeepSurv() -dataloader = DataModule(data, label, train_transform=train_transform, val_transform=val_transform, batch_size=4, - inference=False) -trainer.fit(model, dataloader) \ No newline at end of file diff --git a/Models/Classifier.py b/Models/Classifier.py index c341b87..052ae36 100644 --- a/Models/Classifier.py +++ b/Models/Classifier.py @@ -5,8 +5,6 @@ from torch import nn import torchmetrics from monai.networks import blocks, nets -from Models.UnetEncoder import UnetEncoder -from Models.PretrainedEncoder3D import PretrainedEncoder3D ## Model class Classifier(LightningModule): def __init__(self, config, module_str): @@ -14,7 +12,6 @@ def __init__(self, config, module_str): model = config['MODEL']['Backbone'] parameters = config['MODEL_PARAMETERS'] - # only use network for features if model == 'torchvision': model_name = config['MODEL'][module_str + '_model_name'] @@ -24,23 +21,14 @@ def __init__(self, config, module_str): model_str = 'nets.' + model + '(**parameters)' self.backbone = eval(model_str) - layers = list(self.backbone.children())[:-1] - self.model = nn.Sequential(*layers) + if 'out_channels' in config['MODEL_PARAMETERS'].keys(): + self.out_feat = config['MODEL_PARAMETERS']['out_channels'] + elif 'num_classes' in config['MODEL_PARAMETERS'].keys(): + self.out_feat = config['MODEL_PARAMETERS']['num_classes'] - self.flatten = nn.Sequential( - # nn.Dropout(0.3), - # nn.AdaptiveAvgPool3d(output_size=(4, 4, 4)), - nn.Dropout(0.3), - nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)), - nn.Flatten(), - ) - self.model.apply(self.weights_init) - self.accuracy = torchmetrics.AUROC(task="binary") - self.loss_fcn = torch.nn.BCEWithLogitsLoss() def forward(self, x): - features = self.model(x) - return self.flatten(features) + return self.backbone(x) def weights_init(self, m): if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): diff --git a/Models/Linear.py b/Models/Linear.py index 21eb750..0cf2600 100644 --- a/Models/Linear.py +++ b/Models/Linear.py @@ -17,16 +17,17 @@ ## Model class Linear(pl.LightningModule): - def __init__(self): + def __init__(self, out_feat=42, in_feat=58): super().__init__() + self.loss_fcn = nn.CrossEntropyLoss() + self.out_feat = out_feat + self.model = nn.Sequential( - nn.Linear(51, 42), + nn.Linear(in_feat, out_feat), nn.Dropout(0.3), - nn.LayerNorm(42), + nn.LayerNorm(out_feat), nn.ReLU(), - ) - self.loss_fcn = nn.CrossEntropyLoss() - + ) def forward(self, x): return self.model(x.float()) diff --git a/Models/MixModel.py b/Models/MixModel.py index 41a3eca..d185d30 100644 --- a/Models/MixModel.py +++ b/Models/MixModel.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch +import numpy as np import copy from torch import nn from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything @@ -11,18 +12,20 @@ class MixModel(LightningModule): def __init__(self, module_dict, config, loss_fcn=torch.nn.BCEWithLogitsLoss()): super().__init__() self.module_dict = module_dict + out_feat = np.sum([model.out_feat for model in module_dict.values()]) self.config = config - self.loss_fcn = getattr(torch.nn, self.config["MODEL"]["Loss_Function"])(pos_weight=torch.tensor(1.21)) + self.loss_fcn = getattr(torch.nn, self.config["MODEL"]["Loss_Function"])(pos_weight=torch.tensor(1.18)) self.activation = getattr(torch.nn, self.config["MODEL"]["Activation"])() self.classifier = nn.Sequential( - nn.Linear(198, 120), - nn.Dropout(0.3), + nn.Linear(out_feat, 256), + nn.Dropout(0.05), + nn.Linear(256, 120), + nn.Dropout(0.05), nn.Linear(120, 40), - nn.Dropout(0.3), + nn.Dropout(0.05), nn.Linear(40, config['DATA']['n_classes']), - #self.activation + self.activation ) - self.classifier.apply(self.weights_init) def forward(self, data_dict): features = torch.cat([self.module_dict[key](data_dict[key]) for key in self.module_dict.keys()], dim=1) @@ -30,54 +33,56 @@ def forward(self, data_dict): return prediction def training_step(self, batch, batch_idx): - out = {} - data_dict, label = batch - prediction = self.forward(data_dict) - loss = self.loss_fcn(prediction.squeeze(dim=1), label) + data_dict, censor_status, label = batch + prediction = self.forward(data_dict).squeeze(dim=1) + loss = self.loss_fcn(prediction, label) self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True) - MAE = torch.abs(prediction.flatten(0) - label) - out['MAE'] = MAE.detach() + MAE = torch.abs(prediction - label) out = copy.deepcopy(data_dict) + out['MAE'] = MAE.detach() out['prediction'] = prediction.detach() - out['label'] = label - out['loss'] = loss + out['label'] = label + out['censor_status'] = censor_status + out['loss'] = loss return out def training_epoch_end(self, step_outputs): labels = torch.cat([out['label'] for i, out in enumerate(step_outputs)], dim=0) + censor_status = torch.cat([out['censor_status'] for i, out in enumerate(step_outputs)], dim=0) prediction = torch.cat([out['prediction'] for i, out in enumerate(step_outputs)], dim=0) - self.logger.report_epoch(prediction, labels, step_outputs,self.current_epoch, 'train_epoch_') - + self.logger.report_epoch(prediction, censor_status, labels, step_outputs,self.current_epoch, 'train_epoch_') + def validation_step(self, batch, batch_idx): - out = {} - data_dict, label = batch - prediction = self.forward(data_dict) - loss = self.loss_fcn(prediction.squeeze(dim=1), label) + data_dict, censor_status, label = batch + prediction = self.forward(data_dict).squeeze(dim=1) + loss = self.loss_fcn(prediction, label) self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) - MAE = torch.abs(prediction.flatten(0) - label) - out['MAE'] = MAE + MAE = torch.abs(prediction - label) out = copy.deepcopy(data_dict) + out['MAE'] = MAE out['prediction'] = prediction + out['censor_status'] = censor_status out['label'] = label out['loss'] = loss return out def validation_epoch_end(self, step_outputs): labels = torch.cat([out['label'] for i, out in enumerate(step_outputs)], dim=0) + censor_status = torch.cat([out['censor_status'] for i, out in enumerate(step_outputs)], dim=0) prediction = torch.cat([out['prediction'] for i, out in enumerate(step_outputs)], dim=0) - self.logger.report_epoch(prediction.squeeze(), labels, step_outputs, self.current_epoch,'val_epoch_') + self.logger.report_epoch(prediction, censor_status, labels, step_outputs, self.current_epoch, 'val_epoch_') def test_step(self, batch, batch_idx): - data_dict, label = batch - prediction = self.forward(data_dict) - loss = self.loss_fcn(prediction.squeeze(dim=1), label) - out = {} - MAE = torch.abs(prediction.flatten(0) - label) - out['MAE'] = MAE + data_dict, censor_status, label = batch + prediction = self.forward(data_dict).squeeze(dim=1) + loss = self.loss_fcn(prediction, label) + MAE = torch.abs(prediction - label) out = copy.deepcopy(data_dict) - out['prediction'] = prediction.squeeze(dim=1) + out['MAE'] = MAE + out['prediction'] = prediction out['label'] = label - out['loss'] = loss + out['censor_status'] = censor_status + out['loss'] = loss return out def weights_init(self, m): @@ -89,6 +94,6 @@ def weights_reset(self, m): m.reset_parameters() def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-5) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) return [optimizer], [scheduler] diff --git a/Models/ModelCAE.py b/Models/ModelCAE.py deleted file mode 100644 index fa527db..0000000 --- a/Models/ModelCAE.py +++ /dev/null @@ -1,146 +0,0 @@ -import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule, LightningModule -import numpy as np -import torch -from torch import nn -from collections import Counter -import torchvision -from torchvision import datasets, models, transforms -from torchvision import transforms -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -import sys -import torchio as tio -import sklearn -from pytorch_lightning import loggers as pl_loggers -import torchmetrics -from Losses.loss import WeightedMSE - -## Models -from Models.Linear import Linear -from Models.Classifier2D import Classifier2D -from Models.Classifier3D import Classifier3D -from Models.TransformerEncoder import PositionEncoding, PatchEmbedding, TransformerBlock -from Models.fds import FDS - -# Please refer paper CAE-TRANSFORMER: TRANSFORMER-BASED MODEL TO PREDICT INVASIVENESS -# OF LUNG ADENOCARCINOMA SUBSOLID NODULES FROM NON-THIN SECTION 3D -# CT SCANS - - -class ModelCAE(LightningModule): - def __init__(self, config): - super().__init__() - ## define backbone - # backbone = torchvision.models.resnet18(pretrained=True) - backbone = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True) - layers = list(backbone.children())[:-1] - self.feature_extractor = nn.Sequential(*layers) - self.feature_extractor.eval() - - for param in self.feature_extractor.parameters(): - param.requires_grad = False - # for param in self.feature_extractor[0][0:10].parameters(): - # param.requires_grad = False - - self.linear1 = nn.LazyLinear(config['transformer_embed_dim']) - # self.FDS = FDS(feature_dim=1024, start_update=0, start_smooth=1, kernel='gaussian', ks=7, sigma=3) - - self.pe = PositionEncoding(img_size=config['img_sizes'], patch_size=config['patch_size'], - in_channel=config['in_channels'], - embed_dim=config['transformer_embed_dim'], dropout=config['dropout'], img_dim=2, iftoken=False) - - self.transformers = nn.ModuleList( - [TransformerBlock(num_heads=config['transformer_head'], embed_dim=config['transformer_embed_dim'], - mlp_dim=config['transformer_mlp_dim'], - dropout=config['dropout']) for _ in range(config['transformer_layer'])]) - self.pool_top = nn.MaxPool2d(4) - - # def WeightedMSE(self, prediction, labels): - # loss = 0 - # for i, label in enumerate(labels): - # idx = (self.label_range == int(label.cpu().numpy())).nonzero() - # if (idx is not None) and (idx[0][0] < 60): - # # print(idx[0][0]) - # loss = loss + (prediction[i] - label) ** 2 * self.weights[idx[0][0]] - # else: - # loss = loss + (prediction[i] - label) ** 2 * self.weights[-1] - # loss = loss / (i + 1) - # return loss - - def convert2d(self, x): - y = x.repeat(1, 3, 1, 1) - features = self.feature_extractor(y) - features = features.flatten(1) - features = features.unsqueeze(0) - features = features.unsqueeze(1) - # features = features.permute(2, 3, 0, 1) - features = self.linear1(features) - return features - - def forward(self, x): - features = torch.cat([self.convert2d(b.transpose(0, 1)) for i, b in enumerate(x)], dim=0) - # features = self.pe(features) - features = features.permute(0, 2, 3, 1).flatten(2) - for transformer in self.transformers: - features = transformer(features) - x = self.pool_top(features) - features = x.flatten(start_dim=1) - return features - - # def training_step(self, batch, batch_idx): - # datadict, label = batch - # forward_cal = self.forward(datadict) - # prediction = forward_cal['prediction'] - # print(prediction, label) - # if self.config['REGULARIZATION']['Label_smoothing']: - # loss = self.WeightedMSE(prediction.squeeze(dim=1), batch[-1]) - # else: - # loss = self.loss_fcn(prediction.squeeze(dim=1), batch[-1]) - # self.log("loss", loss, on_epoch=True) - # out = {'loss': loss, 'features': forward_cal['features'], 'label': label} - # return out - - # def training_epoch_end(self, training_step_outputs): - # if self.config['REGULARIZATION']['Feature_smoothing']: - # training_features = torch.cat([out['features'] for i, out in enumerate(training_step_outputs)], dim=0) - # training_labels = torch.cat([out['label'] for i, out in enumerate(training_step_outputs)], dim=0) - # if self.current_epoch >= 0: - # self.FDS.update_last_epoch_stats(self.current_epoch) - # self.FDS.update_running_stats(training_features, training_labels, self.current_epoch) - - # def validation_step(self, batch, batch_idx): - # datadict, label = batch - # forward_cal = self.forward(datadict, label) - # prediction = forward_cal['prediction'] - # val_loss = self.loss_fcn(prediction.squeeze(dim=1), batch[-1]) - # self.log("val_loss", val_loss, on_epoch=True) - # MAE = torch.abs(prediction.flatten(0) - label) - # out = {'MAE': MAE, 'img': datadict['Anatomy']} - # return out - - # def validation_epoch_end(self, validation_step_outputs): - # worst_MAE = 0 - # for i, data in enumerate(validation_step_outputs): - # loss = data['MAE'] - # idx = torch.argmax(loss) - # if loss[idx] > worst_MAE: - # worst_img = data['img'][idx] - # worst_MAE = loss[idx] - # self.log('worst_MAE', worst_MAE) - # grid = self.generate_report(worst_img) - # self.logger.experiment.add_image('validate_worst_case_img', grid, self.current_epoch) - - # def test_step(self, batch, batch_idx): - # datadict, label = batch - # forward_cal = self.forward(datadict, label) - # prediction = forward_cal['prediction'] - # test_loss = self.loss_fcn(prediction.squeeze(dim=1), batch[-1]) - # print('test_prediction:', prediction, label) - # self.log('test_loss:', test_loss) - # return test_loss - - # def configure_optimizers(self): - # optimizer = torch.optim.Adam(self.parameters(), lr=1e-5) - # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) - # return [optimizer], [scheduler] diff --git a/Models/ModelCoTr.py b/Models/ModelCoTr.py deleted file mode 100644 index cd32cdb..0000000 --- a/Models/ModelCoTr.py +++ /dev/null @@ -1,59 +0,0 @@ -import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule, LightningModule -import numpy as np -import torch -from torch import nn -from collections import Counter -import torchvision -from torchvision import datasets, models, transforms -from torchvision import transforms -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -import sys -import torchio as tio -import sklearn -from pytorch_lightning import loggers as pl_loggers -import torchmetrics -from Models.UnetEncoder import UnetEncoder -## Models -from Models.Linear import Linear -from Models.Classifier2D import Classifier2D -from Models.Classifier3D import Classifier3D -from Models.TransformerEncoder import PositionEncoding, PatchEmbedding, TransformerBlock - -# Please refer to model CoTr: Efficiently Bridging CNN and Transformer for 3D Medical Image Segmentation. - - -class ModelCoTr(LightningModule): - def __init__(self, config): - super().__init__() - parameters = config['MODEL_PARAMETERS'] - depth = parameters['depth'] - wf = parameters['width'] - self.model = UnetEncoder(**parameters) - self.pe = nn.ModuleList( - [PositionEncoding(img_size=config['img_sizes'][i], patch_size=config['patch_size'], in_channel=2 ** (wf + i), - embed_dim=config['transformer_embed_dim'], - img_dim=3, dropout=config['dropout'], iftoken=True) for i in range(depth)] - ) - self.transformers = nn.ModuleList( - [TransformerBlock(num_heads=config['transformer_head'], embed_dim=config['transformer_embed_dim'], - mlp_dim=config['transformer_mlp_dim'], - dropout=config['dropout']) for _ in range(config['transformer_layer'])]) - - def forward(self, x): - flg = 0 - for i, down in enumerate(self.model.encoder): - x = down(x) - if flg == 0: - feature = self.pe[i](x) - flg = 1 - else: - out_trans = self.pe[i](x) - feature = torch.cat((feature, out_trans), dim=1) - - for transformer in self.transformers: - feature = transformer(feature) - features = feature.flatten(start_dim=1) - - return features diff --git a/Models/ModelTransUnet.py b/Models/ModelTransUnet.py deleted file mode 100644 index 4e8ecb7..0000000 --- a/Models/ModelTransUnet.py +++ /dev/null @@ -1,56 +0,0 @@ -import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule, LightningModule -import numpy as np -import torch -from torch import nn -from collections import Counter -import torchvision -from torchvision import datasets, models, transforms -from torchvision import transforms -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -import sys -import torchio as tio -import sklearn -from pytorch_lightning import loggers as pl_loggers -import torchmetrics -from Models.UnetEncoder import UnetEncoder -from torchinfo import summary -## Models -from Models.Linear import Linear -from Models.Classifier2D import Classifier2D -from Models.Classifier3D import Classifier3D -from Models.TransformerEncoder import PositionEncoding, PatchEmbedding, TransformerBlock - -# Please refer to model TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation -# img_sizes=256, patch_size=4, embed_dim=256, in_channels=1, -# num_layers=3, num_heads=8, dropout=0.5, mlp_dim=128 - -class ModelTransUnet(LightningModule): - def __init__(self, config): - super().__init__() - parameters = config['MODEL_PARAMETERS'] - self.model = UnetEncoder(**parameters) - self.model.apply(self.weights_init) - summary(self.model.to('cuda'), (3, 1, 32, 128, 128), col_names=["input_size", "output_size"], depth=5) - - self.pe = PositionEncoding(img_size=config['img_sizes'], patch_size=config['patch_size'], in_channel=config['in_channels'], - embed_dim=config['transformer_embed_dim'], img_dim=3, dropout=config['dropout'], iftoken=True) - self.transformers = nn.ModuleList( - [TransformerBlock(num_heads=config['transformer_head'], embed_dim=config['transformer_embed_dim'], - mlp_dim=config['transformer_mlp_dim'], - dropout=config['dropout']) for _ in range(config['transformer_layer'])]) - - def forward(self, x): - for i, down in enumerate(self.model.encoder): - x = down(x) - feature = self.pe(x) - for transformer in self.transformers: - feature = transformer(feature) - features = feature.flatten(start_dim=1) - - return features - - def weights_init(self, m): - if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight.data) diff --git a/Models/PretrainedEncoder3D.py b/Models/PretrainedEncoder3D.py deleted file mode 100644 index d977481..0000000 --- a/Models/PretrainedEncoder3D.py +++ /dev/null @@ -1,64 +0,0 @@ -import matplotlib.pyplot as plt -import torch -from pytorch_lightning import LightningModule -from torch import nn -import torchmetrics -from monai.networks import blocks, nets -from Models.UnetEncoder import UnetEncoder - - -class PretrainedEncoder3D(LightningModule): - def __init__(self, config, module_str): - super().__init__() - self.n_classes = 1 - model_str = config['MODEL'][module_str + '_Backbone'] - parameters = config[module_str + '_MODEL_PARAMETERS'] - self.config = config - - self.loss_fcn = getattr(torch.nn, self.config["MODEL"]["Loss_Function"])() - self.activation = getattr(torch.nn, self.config["MODEL"]["Activation"])() - - model_str = 'nets.' + model_str + '(**parameters)' - full_model = eval(model_str) - vit_dict = torch.load(config['MODEL'][module_str + '_ckpt_path']) - vit_weights = vit_dict['state_dict'] - model_dict = full_model.state_dict() - vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict} - model_dict.update(vit_weights) - full_model.load_state_dict(model_dict) - del model_dict, vit_weights, vit_dict - - self.backbone = full_model.vit - self.hidden_size = full_model.hidden_size - self.feat_size = full_model.feat_size - self.encoder2 = full_model.encoder2 - self.encoder1 = full_model.encoder1 - - for param in self.backbone.parameters(): - param.requires_grad = False - - for param in self.encoder1.parameters(): - param.requires_grad = False - - for param in self.encoder2.parameters(): - param.requires_grad = False - - self.accuracy = torchmetrics.AUC(reorder=True) - - def forward(self, img): - out1 = self.backbone(img) - enc1 = self.encoder1(img) - f1 = nn.AdaptiveMaxPool3d((1, 1, 48))(enc1) - f1f = f1.flatten(start_dim=1) - enc2 = self.encoder2(self.proj_feat(out1[1][3], self.hidden_size, self.feat_size)) - f2 = nn.AdaptiveAvgPool3d((1, 1, 12))(enc2) - f2f = f2.flatten(start_dim=1) - connect_features = torch.cat((f1f, f2f), dim=1) - return connect_features - - def proj_feat(self, x, hidden_size, feat_size): - new_view = (x.size(0), *feat_size, hidden_size) - x = x.view(new_view) - new_axes = (0, len(x.shape) - 1) + tuple(d + 1 for d in range(len(feat_size))) - x = x.permute(new_axes).contiguous() - return x diff --git a/Models/TransformerEncoder.py b/Models/TransformerEncoder.py deleted file mode 100644 index 5ab26ee..0000000 --- a/Models/TransformerEncoder.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import torch.nn as nn - - -class PositionEncoding(nn.Module): - - def __init__(self, img_size, patch_size, in_channel, embed_dim, dropout=0.8, img_dim=3, iftoken=True): - super().__init__() - self.img_size = img_size - self.in_channel = in_channel - self.iftoken = iftoken - self.patch_embed = PatchEmbedding(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, in_channel=in_channel, img_dim=img_dim) - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - if iftoken: - self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) - else: - self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, embed_dim)) - self.pos_drop = nn.Dropout(dropout) - - def forward(self, x): - B = x.shape[0] - x = self.patch_embed(x) - if self.iftoken: - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embed - x = self.pos_drop(x) - return x - - def get_attention_maps(self, x): - attention_maps = [] - for l in self.layers: - _, attn_map = l.self_attn(x, return_attention=True) - attention_maps.append(attn_map) - x = l(x) - return attention_maps - - -class TransformerBlock(nn.Module): - - def __init__(self, num_heads, embed_dim, mlp_dim, dropout=0.0): - """ - Inputs: - input_dim - Dimensionality of the input - num_heads - Number of heads to use in the attention block - dim_feedforward - Dimensionality of the hidden layer in the MLP - dropout - Dropout probability to use in the dropout layers - """ - super().__init__() - - # Attention layer - self.query = nn.Linear(embed_dim, embed_dim) - self.key = nn.Linear(embed_dim, embed_dim) - self.value = self.query = nn.Linear(embed_dim, embed_dim) - - self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout) - - # Two-layer MLP - self.mlp = nn.Sequential( - nn.Linear(embed_dim, mlp_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(mlp_dim, embed_dim), - nn.Dropout(dropout), - ) - - # Layers to apply in between the main layers - self.norm0 = nn.LayerNorm(embed_dim) - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - query = self.query(self.norm0(x)) - key = self.key(self.norm0(x)) - value = self.value(self.norm0(x)) - out, attention = self.attn(query, key, value) - x = x + self.dropout(out) - x = self.norm1(x) - - # MLP part - linear_out = self.mlp(x) - x = x + linear_out - x = self.norm2(x) - - return x - -## -#class TransformerRegression(nn.Module): -# def __init__(self, num_layers, input_dim, num_heads, embed_dim, mlp_dim, dropout=0.0, -# input_dropout=0.0): -# super().__init__() -# self.linear_net = nn.Sequential( -# nn.Dropout(input_dim), -# nn.Linear(input_dim, embed_dim) -# ) -# self.transformer = TransformerEncoder(num_layers, num_heads, embed_dim, mlp_dim, dropout=0.0) -# -# def forward(self, x): -# linear_out = self.linear_net(x) -# return self.transformer(linear_out) - - -class PatchEmbedding(nn.Module): - def __init__(self, img_size=64, patch_size=4, embed_dim = 64, in_channel=1, img_dim=3): - super().__init__() - if len(img_size) == 1: - num_patches = (img_size[0] // patch_size) ** 3 ## for 3D image - if len(img_size) == 2: - num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) - if len(img_size) == 3: - num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) * (img_size[2] // patch_size) - self.img_size = img_size - self.img_dim = img_dim - self.patch_size = patch_size - self.num_patches = int(num_patches) - self.in_channel = in_channel - self.embed_dim = embed_dim - if img_dim == 3: - self.proj = nn.Conv3d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size) - else: - self.proj = nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x): - #B, C, H, W, D = x.shape - x = self.proj(x).flatten(2).transpose(1, 2) - return x diff --git a/Models/UnetEDcoder.py b/Models/UnetEDcoder.py deleted file mode 100644 index 0f273a7..0000000 --- a/Models/UnetEDcoder.py +++ /dev/null @@ -1,42 +0,0 @@ -import monai.networks.nets -import torch -from torch import nn -from monai.networks import blocks, nets -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything - - -class UnetEDcoder(LightningModule): - def __init__(self, config): - super().__init__() - parameters = config['MODEL_PARAMETERS'] - model = eval('monai.networks.nets.BasicUNet(**parameters)') - - self.encoder = nn.Sequential( - model.conv_0, - model.down_1, - model.down_2, - model.down_3, - model.down_4, - ) - self.decoder = nn.Sequential( - model.upcat_4, - model.upcat_3, - model.upcat_2, - model.upcat_1, - model.final_conv, - ) - - def forward(self, x): - x0 = self.encoder[0](x) - x1 = self.encoder[1](x0) - x2 = self.encoder[2](x1) - x3 = self.encoder[3](x2) - x4 = self.encoder[4](x3) - - u4 = self.decoder[0](x4, x3) - u3 = self.decoder[1](u4, x2) - u2 = self.decoder[2](u3, x1) - u1 = self.decoder[3](u2, x0) - out = self.decoder[4](u1) - return out diff --git a/Models/UnetEncoder.py b/Models/UnetEncoder.py deleted file mode 100644 index 798aae6..0000000 --- a/Models/UnetEncoder.py +++ /dev/null @@ -1,27 +0,0 @@ -from torch import nn -from monai.networks import blocks, nets - -class UnetEncoder(nn.Module): - def __init__(self, depth, wf, in_channels, spatial_dims=3, kernel_size=None, stride=None): - super(UnetEncoder, self).__init__() - self.encoder = nn.ModuleList() - for i in range(depth): - out_channels = 2 ** (wf + i) - down_block = blocks.UnetResBlock(spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, norm_name='batch', dropout=0.5) - self.encoder.append(down_block) - in_channels = out_channels - - self.out_channels = in_channels - - def forward(self, x): - for i, down in enumerate(self.encoder): - x = down(x) - return x - - - def weights_init(self, m): - if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight.data) \ No newline at end of file diff --git a/Models/fds.py b/Models/fds.py deleted file mode 100644 index 0f82167..0000000 --- a/Models/fds.py +++ /dev/null @@ -1,155 +0,0 @@ -import logging -import numpy as np -from scipy.ndimage import gaussian_filter1d -from scipy.signal.windows import triang -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class FDS(nn.Module): - - def __init__(self, feature_dim, bucket_num=50, bucket_start=0, start_update=0, start_smooth=1, - kernel='gaussian', ks=5, sigma=2, momentum=0.9): - super(FDS, self).__init__() - self.feature_dim = feature_dim - self.bucket_num = bucket_num - self.bucket_start = bucket_start - self.kernel_window = self._get_kernel_window(kernel, ks, sigma) - self.half_ks = (ks - 1) // 2 - self.momentum = momentum - self.start_update = start_update - self.start_smooth = start_smooth - - self.register_buffer('epoch', torch.zeros(1).fill_(start_update).cuda()) - self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim).cuda()) - self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim).cuda()) - self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim).cuda()) - self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim).cuda()) - self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim).cuda()) - self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim).cuda()) - self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start).cuda()) - - @staticmethod - def _get_kernel_window(kernel, ks, sigma): - assert kernel in ['gaussian', 'triang', 'laplace'] - half_ks = (ks - 1) // 2 - if kernel == 'gaussian': - base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks - base_kernel = np.array(base_kernel, dtype=np.float32) - kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum( - gaussian_filter1d(base_kernel, sigma=sigma)) - elif kernel == 'triang': - kernel_window = triang(ks) / sum(triang(ks)) - else: - laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma) - kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum( - map(laplace, np.arange(-half_ks, half_ks + 1))) - - print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') - return torch.tensor(kernel_window, dtype=torch.float32).cuda() - - @staticmethod - def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10): - if torch.sum(v1) < 1e-10: - return matrix - if (v1 == 0.).any(): - valid = (v1 != 0.) - factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max) - matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid] - return matrix - - factor = torch.clamp(v2 / v1, clip_min, clip_max) - return (matrix - m1) * torch.sqrt(factor) + m2 - - def _update_last_epoch_stats(self): - self.running_mean_last_epoch = self.running_mean - self.running_var_last_epoch = self.running_var - - self.smoothed_mean_last_epoch = F.conv1d( - input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), - pad=(self.half_ks, self.half_ks), mode='reflect'), - weight=self.kernel_window.view(1, 1, -1), padding=0 - ).permute(2, 1, 0).squeeze(1) - self.smoothed_var_last_epoch = F.conv1d( - input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), - pad=(self.half_ks, self.half_ks), mode='reflect'), - weight=self.kernel_window.view(1, 1, -1), padding=0 - ).permute(2, 1, 0).squeeze(1) - - def reset(self): - self.running_mean.zero_() - self.running_var.fill_(1) - self.running_mean_last_epoch.zero_() - self.running_var_last_epoch.fill_(1) - self.smoothed_mean_last_epoch.zero_() - self.smoothed_var_last_epoch.fill_(1) - self.num_samples_tracked.zero_() - - def update_last_epoch_stats(self, epoch): - if epoch == self.epoch + 1: - self.epoch += 1 - self._update_last_epoch_stats() - print(f"Updated smoothed statistics on Epoch [{epoch}]!") - - def update_running_stats(self, features, labels, epoch): - if epoch < self.epoch: - return - - assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!" - assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" - - for label in torch.unique(labels): - if label > self.bucket_num - 1 or label < self.bucket_start: - continue - elif label == self.bucket_start: - curr_feats = features[labels <= label] - elif label == self.bucket_num - 1: - curr_feats = features[labels >= label] - else: - curr_feats = features[labels == label] - curr_num_sample = curr_feats.size(0) - curr_mean = torch.mean(curr_feats, 0) - curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) - - self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample - factor = self.momentum if self.momentum is not None else \ - (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)])) - factor = 0 if epoch == self.start_update else factor - self.running_mean[int(label - self.bucket_start)] = \ - (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)] - self.running_var[int(label - self.bucket_start)] = \ - (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)] - - print(f"Updated running statistics with Epoch [{epoch}] features!") - - def smooth(self, features, labels, epoch): - if epoch < self.start_smooth: - return features - - # labels = labels.squeeze(1) - for label in torch.unique(labels): - if label > self.bucket_num - 1 or label < self.bucket_start: - continue - elif label == self.bucket_start: - features[labels <= label] = self.calibrate_mean_var( - features[labels <= label], - self.running_mean_last_epoch[int(label - self.bucket_start)], - self.running_var_last_epoch[int(label - self.bucket_start)], - self.smoothed_mean_last_epoch[int(label - self.bucket_start)], - self.smoothed_var_last_epoch[int(label - self.bucket_start)]) - elif label == self.bucket_num - 1: - features[labels >= label] = self.calibrate_mean_var( - features[labels >= label], - self.running_mean_last_epoch[int(label - self.bucket_start)], - self.running_var_last_epoch[int(label - self.bucket_start)], - self.smoothed_mean_last_epoch[int(label - self.bucket_start)], - self.smoothed_var_last_epoch[int(label - self.bucket_start)]) - else: - features.cuda()[labels == label] = self.calibrate_mean_var( - features.cuda()[labels == label].cuda(), - self.running_mean_last_epoch[int(label - self.bucket_start)].cuda(), - self.running_var_last_epoch[int(label - self.bucket_start)].cuda(), - self.smoothed_mean_last_epoch[int(label - self.bucket_start)].cuda(), - self.smoothed_var_last_epoch[int(label - self.bucket_start)].cuda()).cuda() - return features diff --git a/Training/#InnerEye.py# b/Training/#InnerEye.py# deleted file mode 100644 index ff8e824..0000000 --- a/Training/#InnerEye.py# +++ /dev/null @@ -1,73 +0,0 @@ -import sys -from DataGenerator.DataGenerator import QuerySubjectList, QuerySubjectInfo, GeneratePath, SynchronizeData -import toml -import nibabel as nib -from Utils.DicomTools import * -from pathlib import Path -import os -from rt_utils import RTStructBuilder -import xnat -session = xnat.connect('http://128.16.11.124:8080/xnat', user='admin', password='mortavar1977') -import csv -from monai.transforms import Spacing, LoadImage, EnsureChannelFirst -from monai.data import MetaTensor -import itk -config = toml.load(sys.argv[1]) -SubjectList = QuerySubjectList(config) -#SynchronizeData(config, SubjectList) -SubjectInfo = QuerySubjectInfo(config, SubjectList) -# roi = ['Heart', 'Oesophagus', 'Spinal Canal', 'Prox bronch tree', 'Proximal trachea', 'Cwall & ribs', 'L Lung', -# 'R Lung', 'Brachial Plexus'] -# roi_name = ['heart', 'oesophagus', 'spinal_canal', 'prox_bronch_tree', 'proximal_trachea', 'cwall_ribs', 'lt_lung', -# 'rt_lung', 'brachial_plexus'] - -roi_series = ['HEART', 'PTV', 'LUNG_IPSI', 'LUNG_CNTR', 'SPINAL_CORD', 'ESOPHAGUS'] -roi_name = ['HEART', 'PTV', 'LUNG_IPSI', 'LUNG_CNTR', 'SPINAL_CORD', 'ESOPHAGUS'] - -path = config['DATA']['NiiFolder'] -sPatient = SubjectList -r1 = [x.lower() for x in roi_series] - -for i in range(0, len(sPatient), 1): - subject_id = sPatient.loc[i, 'subjectid'] - subject_label = sPatient.loc[i,'subject_label'] - - print('No.{}:'.format(i) + subject_label) - - CTPath = GeneratePath(SubjectInfo, config, subject_id, 'CT') - CTArray, meta = LoadImage(reader='PydicomReader')(CTPath) - #ct = EnsureChannelFirst()(CTArray) - #ct = Spacing(pixdim=(1, 1, 3))(ct) - - """ - ct_array = ct.array.squeeze() - # First define the ROI based on target - RSPath = glob.glob(GeneratePath(SubjectInfo, config, subject_id, 'Structs') + '/*dcm') - RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSPath[0]) - # print(RS) - #%% - roi_names = RS.get_roi_names() - strList = [x.lower() for x in roi_names] - ni_img = nib.Nifti1Image(np.double(ct_array), affine=ct.affine) - spath = Path(path, subject_label) - - if not os.path.isdir(spath): - os.mkdir(spath) - nib.save(ni_img, Path(spath, 'ct.nii.gz')) - - for i in range(len(r1)): - roi = r1[i] - if roi in strList: - index = strList.index(roi.lower()) - mask_img = RS.get_roi_mask_by_name(roi_names[index]) - mask_img = np.rot90(mask_img) - mask_img = np.flip(mask_img, 2) - mask_img = np.flip(mask_img, 0) - mask = MetaTensor(mask_img.copy(), meta=meta) - mask = EnsureChannelFirst()(mask) - mask = Spacing(pixdim=(1, 1, 3))(mask) - mask_array = mask.array.squeeze() - ni_mask = nib.Nifti1Image(mask_array.astype('int'), affine=mask.affine, dtype='uint8') - nib.save(ni_mask, Path(spath, roi_name[i].lower() + '.nii.gz')) - """ - diff --git a/Training/Classification.py b/Training/Classification.py new file mode 100644 index 0000000..40f3fc9 --- /dev/null +++ b/Training/Classification.py @@ -0,0 +1,128 @@ +import torch +import torchvision +from torch import nn +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything +import sys, os +import monai + +torch.cuda.empty_cache() +## Module - Dataloaders +from DataGenerator.DataGenerator import * +from Models.Classifier import Classifier +from Models.Linear import Linear +from Models.MixModel import MixModel +from monai.transforms import EnsureChannelFirstd, ScaleIntensityd, ResampleToMatchd, BoundingRectd +## Main +from sklearn.preprocessing import StandardScaler, OneHotEncoder +import toml +from Utils.GenerateSmoothLabel import get_smoothed_label_distribution, get_module +from Utils.PredictionReports import PredictionReports +from pathlib import Path # +from torchmetrics import ConfusionMatrix +import torchmetrics + +config = toml.load(sys.argv[1]) + + +def threshold_at_one(x): + return x > 0 + + +## 2D transform +img_keys = list(config['MODALITY'].keys()) + +if config['MODALITY'].values(): + train_transform = torchvision.transforms.Compose([ + EnsureChannelFirstd(keys=img_keys), + monai.transforms.CropForegroundd(keys=img_keys, source_key='Structs', select_fn=threshold_at_one), + monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), + monai.transforms.ScaleIntensityd(keys=list(set(img_keys).difference(set(['Dose'])))), + # monai.transforms.RandAffined(keys=img_keys), + # monai.transforms.RandHistogramShiftd(keys=img_keys), + # monai.transforms.RandAdjustContrastd(keys=img_keys), + # monai.transforms.RandGaussianNoised(keys=img_keys), + + ]) + + val_transform = torchvision.transforms.Compose([ + EnsureChannelFirstd(keys=img_keys), + monai.transforms.CropForegroundd(keys=img_keys, source_key='Structs', select_fn=threshold_at_one), + monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), + monai.transforms.ScaleIntensityd(list(set(img_keys).difference(set(['Dose'])))), + ]) +else: + train_transform = None + val_transform = None + +## First Connect to XNAT +session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'], + password=config['SERVER']['Password']) +SubjectList = QuerySubjectList(config, session) +SynchronizeData(config, SubjectList) +SubjectList.dropna(subset=['xnat_subjectdata_field_map_survival_months'], inplace=True) + +module_dict = nn.ModuleDict() +if config['DATA']['Multichannel']: ## Single-Model Multichannel learning + if config['MODALITY'].keys(): + module_dict['Image'] = Classifier(config, 'Image') +else: + for key in config['MODALITY'].keys(): # Multi-Model Single Channel learning + module_dict[key] = Classifier(config, key) + +if 'Records' in config.keys(): + SubjectList, clinical_cols = LoadClinicalData(config, SubjectList) + module_dict['Records'] = Linear(in_feat=len(clinical_cols), out_feat=42) +else: + clinical_cols = None + +## GeneratePath +for key in config['MODALITY'].keys(): + SubjectList[key + '_Path'] = "" +QuerySubjectInfo(config, SubjectList) +print(SubjectList) + +for iter in range(0, 3, 1): + seed_everything(np.random.randint(0, 10000), workers=True) + dataloader = DataModule(SubjectList, + config=config, + keys=config['MODALITY'].keys(), + train_transform=train_transform, + val_transform=val_transform, + clinical_cols=clinical_cols, + inference=False) + + model = MixModel(module_dict, config) + model.apply(model.weights_reset) + filename = config['DATA']['LogFolder'] + + logger = PredictionReports(config=config, save_dir='lightning_logs', name=filename) + logger.log_text() + logger._version = iter + callbacks = [ + ModelCheckpoint(dirpath=Path(logger.log_dir, 'ckpt'), + monitor='val_loss', + filename='Iter_' + str(iter), + save_top_k=2, + mode='min'), + # EarlyStopping(monitor='val_loss', + # check_finite=True), + ] + + trainer = Trainer( + accelerator="gpu", + devices=[0, 1, 2, 3], + strategy=DDPStrategy(find_unused_parameters=True), + max_epochs=40, + logger=logger, + callbacks=callbacks, + ) + trainer.fit(model, dataloader) + +with open(logger.root_dir + "/Config.ini", "w+") as toml_file: + toml.dump(config, toml_file) + toml_file.write("Train transform:\n") + toml_file.write(str(train_transform)) + toml_file.write("Val/Test transform:\n") + toml_file.write(str(val_transform)) diff --git a/Training/Regression.py b/Training/Regression.py deleted file mode 100644 index ccdbccb..0000000 --- a/Training/Regression.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -import torchvision -from torch import nn -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning.strategies import DDPStrategy -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -import sys, os -#import torchio as tio -import monai -torch.cuda.empty_cache() -## Module - Dataloaders -from DataGenerator.DataGenerator import * -from Models.Classifier import Classifier -from Models.Linear import Linear -from Models.MixModel import MixModel -from monai.transforms import EnsureChannelFirstd, ScaleIntensityd, ResampleToMatchd -## Main -from sklearn.preprocessing import StandardScaler, OneHotEncoder -import toml -from Utils.GenerateSmoothLabel import get_smoothed_label_distribution, get_module -from Utils.PredictionReports import PredictionReports -from pathlib import Path -from Utils.DicomTools import img_train_transform, img_val_transform -#import torchio as tio -from torchmetrics import ConfusionMatrix -import torchmetrics - -config = toml.load(sys.argv[1]) -total_backbone = config['MODEL']['Backbone'] + '_bitset7_seed_42' -## 2D transform -img_keys = list(config['MODALITY'].keys()) -## Multichannel masks -#img_keys.remove('Structs') -#if 'Structs' in config['DATA'].keys(): -# for roi in config['DATA']['Structs']: -# img_keys.append('Struct_' + roi) - -train_transform = torchvision.transforms.Compose([ - EnsureChannelFirstd(keys=img_keys), - #ResampleToMatchd(list(set(img_keys).difference(set(['CT']))), key_dst='CT'), - monai.transforms.ScaleIntensityd(keys=list(set(img_keys).difference(set(['Dose'])))), - # monai.transforms.ResizeWithPadOrCropd(keys=img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.RandAffined(keys=img_keys), - monai.transforms.RandHistogramShiftd(keys=img_keys), - monai.transforms.RandAdjustContrastd(keys=img_keys), - monai.transforms.RandGaussianNoised(keys=img_keys), - -]) - -val_transform = torchvision.transforms.Compose([ - EnsureChannelFirstd(keys=img_keys), - #ResampleToMatchd(list(set(img_keys).difference(set(['CT']))), key_dst='CT'), - monai.transforms.ScaleIntensityd(list(set(img_keys).difference(set(['Dose'])))), - # monai.transforms.ResizeWithPadOrCropd(img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), -]) - - -## First Connect to XNAT -session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'],password=config['SERVER']['Password']) - - -SubjectList = QuerySubjectList(config, session) -SynchronizeData(config, SubjectList) -SubjectList.dropna(subset=['xnat_subjectdata_field_map_survival_months'], inplace=True) - -module_dict = nn.ModuleDict() -if config['DATA']['Multichannel']: ## Single-Model Multichannel learning - if config['MODALITY'].keys(): - module_dict['Image'] = Classifier(config, 'Image') -else: - for key in config['MODALITY'].keys():# Multi-Model Single Channel learning - module_dict[key] = Classifier(config, key) - -if 'Records' in config.keys(): - module_dict['Records'] = Linear() - SubjectList, clinical_cols = LoadClinicalData(config, SubjectList) - -else: - clinical_cols = None - -## GeneratePath -for key in config['MODALITY'].keys(): - SubjectList[key+'_Path'] = "" -QuerySubjectInfo(config, SubjectList, session) -print(SubjectList) - -threshold = config['DATA']['threshold'] -ckpt_path = Path('./', total_backbone + '_ckpt') -rd = [2300, 5700, 998, 24, 7865, 9273] -for iter in range(0,5,1): - seed_everything(rd[iter]) - dataloader = DataModule(SubjectList, - config=config, - keys=config['MODALITY'].keys(), - train_transform=train_transform, - val_transform=val_transform, - clinical_cols=clinical_cols, - inference=False, - session = session) - - model = MixModel(module_dict, config) - model.apply(model.weights_reset) - #full_ckpt_path = Path(ckpt_path, 'Iter_'+ str(iter) + '.ckpt') - #model.load_state_dict(torch.load(full_ckpt_path)['state_dict']) - - filename = total_backbone - logger = PredictionReports(config=config, save_dir='lightning_logs', name=filename) - logger.log_text() - callbacks = [ - ModelCheckpoint(dirpath=ckpt_path, - monitor='val_loss', - filename='Iter_' + str(iter), - save_top_k=1, - mode='min'), - # EarlyStopping(monitor='val_loss', - # check_finite=True), - ] - - trainer = Trainer( - #gpus=1, - accelerator="gpu", - devices=[0,1,2,3], - strategy=DDPStrategy(find_unused_parameters=True), - max_epochs=30, - logger=logger, - callbacks=callbacks - ) - #model = torch.compile(model) - trainer.fit(model, dataloader) - torch.save({'state_dict': model.state_dict(),}, Path('ckpt_test_bitset7_r42', 'Iter_' + str(iter) + '.ckpt')) diff --git a/Training/val.py b/Training/val.py deleted file mode 100644 index a57d4d3..0000000 --- a/Training/val.py +++ /dev/null @@ -1,157 +0,0 @@ -import torch -import torchvision -from torch import nn -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning.strategies import DDPStrategy -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -import sys, os -import torchio as tio -import monai -torch.cuda.empty_cache() -## Module - Dataloaders -from DataGenerator.DataGenerator import * -from Models.Classifier import Classifier -from Models.Linear import Linear -from Models.MixModel import MixModel -from monai.transforms import EnsureChannelFirstd, ScaleIntensityd, ResampleToMatchd -## Main -from sklearn.preprocessing import StandardScaler, OneHotEncoder -import toml -from Utils.GenerateSmoothLabel import get_smoothed_label_distribution, get_module -from Utils.PredictionReports import PredictionReports -from pathlib import Path -from Utils.DicomTools import img_train_transform, img_val_transform -import torchio as tio -from torchmetrics import ConfusionMatrix -import torchmetrics - -config = toml.load(sys.argv[1]) -total_backbone = "" -## 2D transform -img_keys = list(config['MODALITY'].keys()) -img_keys.remove('Structs') -if 'Structs' in config['DATA'].keys(): - for roi in config['DATA']['Structs']: - img_keys.append('Struct_' + roi) - -#img_keys = list(config['MODALITY'].keys()) -#if 'Structs' in config['DATA'].keys(): -# img_keys.append('Mask') - -train_transform = torchvision.transforms.Compose([ - EnsureChannelFirstd(keys=img_keys), - ResampleToMatchd(list(set(img_keys).difference(set(['CT']))), key_dst='CT'), - monai.transforms.ScaleIntensityd(keys=img_keys), - # monai.transforms.ResizeWithPadOrCropd(keys=img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.RandAffined(keys=img_keys), - monai.transforms.RandHistogramShiftd(keys=img_keys), - monai.transforms.RandAdjustContrastd(keys=img_keys), - monai.transforms.RandGaussianNoised(keys=img_keys), - -]) - -val_transform = torchvision.transforms.Compose([ - EnsureChannelFirstd(keys=img_keys), - ResampleToMatchd(list(set(img_keys).difference(set(['CT']))), key_dst='CT'), - monai.transforms.ScaleIntensityd(img_keys), - # monai.transforms.ResizeWithPadOrCropd(img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), -]) - - -## First Connect to XNAT -session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'],password=config['SERVER']['Password']) - - -SubjectList = QuerySubjectList(config, session) -SynchronizeData(config, SubjectList) - -module_dict = nn.ModuleDict() -if config['DATA']['Multichannel']: ## Single-Model Multichannel learning - if config['MODALITY'].keys(): - module_dict['Image'] = Classifier(config, 'Image') -else: - for key in config['MODALITY'].keys():# Multi-Model Single Channel learning - module_dict[key] = Classifier(config, key) - -if 'Records' in config.keys(): - module_dict['Records'] = Linear() - SubjectList, clinical_cols = LoadClinicalData(config, SubjectList) - -else: - clinical_cols = None - -## GeneratePath -for key in config['MODALITY'].keys(): - SubjectList[key+'_Path'] = "" -QuerySubjectInfo(config, SubjectList, session) -print(SubjectList) - -threshold = config['DATA']['threshold'] -ckpt_path = Path('./', total_backbone + '_ckpt') -roc_list = [] -sp_list = [] -sensi_list = [] -acc_list = [] -pre_list = [] - -tprs = [] -roc = torchmetrics.ROC() -auroc = torchmetrics.AUROC() -fig = plt.figure() -base_fpr = np.linspace(0, 1, 39) -cm = ConfusionMatrix(num_classes=2) -prediction_labels_full_list = [] - -for iter in range(0,1,1): - # seed_everything(4200) - dataloader = DataModule(SubjectList, - config=config, - keys=config['MODALITY'].keys(), - train_transform=train_transform, - val_transform=val_transform, - clinical_cols=clinical_cols, - inference=False, - session = session) - - model = MixModel(module_dict, config) - full_ckpt_path = Path(ckpt_path, 'Iter_'+ str(iter) + '.ckpt') - # full_ckpt_path = Path('Classification_4_ckpt', 'Iter_' + str(iter) + '.ckpt') - # full_ckpt_path = 'ckpt_test/Iter_' + str(iter) + '.ckpt' - model.load_state_dict(torch.load(full_ckpt_path)['state_dict']) - # model.load_state_dict(torch.load(full_ckpt_path, map_location='cpu')['state_dict']) - model.eval() - print('start testing...') - worstCase = 0 - with torch.no_grad(): - outs = [] - for i, data in enumerate(dataloader.test_dataloader()): - truth = data[1] - x = data[0] - output = model.test_step(data, i) - outs.append(output) - - validation_labels_full = torch.cat([out['label'] for i, out in enumerate(outs)], dim=0) - prediction_labels_full = torch.cat([out['prediction'] for i, out in enumerate(outs)], dim=0) - prediction_labels_full_list.append(prediction_labels_full.tolist()) - -prediction_labels = torch.tensor(prediction_labels_full_list).mean(dim=0) -validation_labels = validation_labels_full -roc = auroc(prediction_labels, validation_labels.int()) -bcm = cm(prediction_labels.round(), validation_labels.int()) -tn = bcm[0][0] -tp = bcm[1][1] -fp = bcm[0][1] -fn = bcm[1][0] -acc = bcm.diag().sum() / bcm.sum() -sensitivity = tp / (tp + fn) -precision = tp / (tp + fp) -spec = tn / (tn + fp) - -print('avg_roc', str(roc)) -print('avg_specificity', str(spec)) -print('avg_sensitivity', str(sensitivity)) -print('avg_accuracy', str(acc)) -print('avg_precision', str(precision)) -print('finish test') diff --git a/Training/val2.py b/Training/val2.py deleted file mode 100644 index 6a72d19..0000000 --- a/Training/val2.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch -import torchvision -from torch import nn -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning.strategies import DDPStrategy -from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything -import sys, os -#import torchio as tio -import monai -torch.cuda.empty_cache() -## Module - Dataloaders -from DataGenerator.DataGenerator import * -from Models.Classifier import Classifier -from Models.Linear import Linear -from Models.MixModel import MixModel -from monai.transforms import EnsureChannelFirstd, ScaleIntensityd, ResampleToMatchd -## Main -from sklearn.preprocessing import StandardScaler, OneHotEncoder -import toml -from Utils.GenerateSmoothLabel import get_smoothed_label_distribution, get_module -from Utils.PredictionReports import PredictionReports -from pathlib import Path -from Utils.DicomTools import img_train_transform, img_val_transform -#import torchio as tio -from torchmetrics import ConfusionMatrix -import torchmetrics - -config = toml.load(sys.argv[1]) -total_backbone = config['MODEL']['Backbone'] + '_bitset5_seed_42' -## 2D transform -img_keys = list(config['MODALITY'].keys()) -#img_keys.remove('Structs') -#if 'Structs' in config['DATA'].keys(): -# for roi in config['DATA']['Structs']: -# img_keys.append('Struct_' + roi) - -train_transform = torchvision.transforms.Compose([ - EnsureChannelFirstd(keys=img_keys), - #ResampleToMatchd(list(set(img_keys).difference(set(['CT']))), key_dst='CT'), - monai.transforms.ScaleIntensityd(keys=img_keys), - # monai.transforms.ResizeWithPadOrCropd(keys=img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.RandAffined(keys=img_keys), - monai.transforms.RandHistogramShiftd(keys=img_keys), - monai.transforms.RandAdjustContrastd(keys=img_keys), - monai.transforms.RandGaussianNoised(keys=img_keys), - -]) - -val_transform = torchvision.transforms.Compose([ - EnsureChannelFirstd(keys=img_keys), - #ResampleToMatchd(list(set(img_keys).difference(set(['CT']))), key_dst='CT'), - monai.transforms.ScaleIntensityd(img_keys), - # monai.transforms.ResizeWithPadOrCropd(img_keys, spatial_size=config['DATA']['dim']), - monai.transforms.Resized(keys=img_keys, spatial_size=config['DATA']['dim']), -]) - - -## First Connect to XNAT -session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'],password=config['SERVER']['Password']) - - -SubjectList = QuerySubjectList(config, session) -SynchronizeData(config, SubjectList) - -module_dict = nn.ModuleDict() -if config['DATA']['Multichannel']: ## Single-Model Multichannel learning - if config['MODALITY'].keys(): - module_dict['Image'] = Classifier(config, 'Image') -else: - for key in config['MODALITY'].keys():# Multi-Model Single Channel learning - module_dict[key] = Classifier(config, key) - -if 'Records' in config.keys(): - module_dict['Records'] = Linear() - SubjectList, clinical_cols = LoadClinicalData(config, SubjectList) - -else: - clinical_cols = None - -## GeneratePath -for key in config['MODALITY'].keys(): - SubjectList[key+'_Path'] = "" -QuerySubjectInfo(config, SubjectList, session) - -threshold = config['DATA']['threshold'] -ckpt_path = Path('./', total_backbone + '_ckpt') -roc_list = [] -sp_list = [] -sensi_list = [] -acc_list = [] -pre_list = [] - -tprs = [] -roc = torchmetrics.ROC() -auroc = torchmetrics.AUROC() -fig = plt.figure() -base_fpr = np.linspace(0, 1, 39) -cm = ConfusionMatrix(num_classes=2) -prediction_labels_full_list = [] - -for iter in range(0, 1, 1): - # seed_everything(4200) - dataloader = DataModule(SubjectList, - config=config, - keys=config['MODALITY'].keys(), - train_transform=train_transform, - val_transform=val_transform, - clinical_cols=clinical_cols, - inference=False, - session = session) - - model = MixModel(module_dict, config) - #full_ckpt_path = Path(ckpt_path, 'Iter_'+ str(iter) + '.ckpt') - #full_ckpt_path = Path('Classification_4_ckpt', 'Iter_' + str(iter) + '.ckpt') - full_ckpt_path = 'ckpt_test_bitset7_r42/Iter_' + str(iter) + '.ckpt' - model.load_state_dict(torch.load(full_ckpt_path)['state_dict']) - # model.load_state_dict(torch.load(full_ckpt_path, map_location='cpu')['state_dict']) - model.eval() - print('start testing...') - worstCase = 0 - with torch.no_grad(): - outs = [] - for i, data in enumerate(dataloader.test_dataloader()): - truth = data[1] - x = data[0] - output = model.test_step(data, i) - outs.append(output) - - validation_labels_full = torch.cat([out['label'] for i, out in enumerate(outs)], dim=0) - prediction_labels_full = torch.cat([out['prediction'] for i, out in enumerate(outs)], dim=0) - roc_i = auroc(prediction_labels_full, validation_labels_full.int()) - print('roc_'+str(iter), roc_i) - prediction_labels_full_list.append(prediction_labels_full.tolist()) - -prediction_labels = torch.tensor(prediction_labels_full_list).mean(dim=0) -validation_labels = validation_labels_full -roc = auroc(prediction_labels, validation_labels.int()) -bcm = cm(prediction_labels.round(), validation_labels.int()) -tn = bcm[0][0] -tp = bcm[1][1] -fp = bcm[0][1] -fn = bcm[1][0] -acc = bcm.diag().sum() / bcm.sum() -sensitivity = tp / (tp + fn) -precision = tp / (tp + fp) -spec = tn / (tn + fp) - -print('avg_roc', str(roc)) -print('avg_specificity', str(spec)) -print('avg_sensitivity', str(sensitivity)) -print('avg_accuracy', str(acc)) -print('avg_precision', str(precision)) -print('finish test') - diff --git a/Utils/#DicomTools.py# b/Utils/#DicomTools.py# deleted file mode 100644 index c6a625c..0000000 --- a/Utils/#DicomTools.py# +++ /dev/null @@ -1,170 +0,0 @@ -import os -import glob -import cv2 -import SimpleITK as sitk -import pydicom as dicom -from pydicom import dcmread -import numpy as np -import matplotlib -import matplotlib.pyplot as plt -from matplotlib.path import Path -from monai.data import ITKReader, PILReader -import torchio as tio -from sklearn.preprocessing import KBinsDiscretizer - -sitk.ProcessObject_SetGlobalWarningDisplay(False) -from rt_utils import RTStructBuilder -from scipy.ndimage import * - - -def get_bbox_from_mask(mask, img_shape): - pos = np.where(mask) - if pos[0].shape[0] == 0: - bbox = np.zeros((0, 4)) - else: - xmin = np.min(pos[3]) - xmax = np.max(pos[3]) - ymin = np.min(pos[2]) - ymax = np.max(pos[2]) - zmin = np.min(pos[1]) - zmax = np.max(pos[1]) - bbox = [zmin, zmax, ymin, ymax, xmin, xmax] - return bbox - - -def ReadDicom(dicom_path, view_image=False): - Reader = sitk.ImageSeriesReader() - filenames = sorted(glob.glob(dicom_path + '/*.dcm')) - Reader.SetFileNames(sorted(filenames)) - - assert len(filenames) > 0 - Session = Reader.Execute() - return Session - - -def ResamplingITK(Session, Reference, is_label=False, pad_value=0): - resample = sitk.ResampleImageFilter() - resample.SetOutputSpacing(Reference.GetSpacing()) - resample.SetSize(Reference.GetSize()) - resample.SetOutputDirection(Reference.GetDirection()) - resample.SetOutputOrigin(Reference.GetOrigin()) - resample.SetTransform(sitk.Transform()) - resample.SetDefaultPixelValue(Session.GetPixelIDValue()) - - if is_label: - resample.SetInterpolator(sitk.sitkNearestNeighbor) - else: - resample.SetInterpolator(sitk.sitkLinear) - Resampled = resample.Execute(Session) - return Resampled - - -def RStoContour(rs_path, targetROI='PTV'): - rs_file = glob.glob(rs_path + '*.dcm') - ds = dcmread(rs_file[0]) - for item in ds.StructureSetROISequence: - if item.ROIName == targetROI: - ROI = ds.ROIContourSequence[item.ROINumber - 1] - contours = [contour for contour in ROI.ContourSequence] - return contours - - -def poly_to_mask(polygon, img_shape): - x, y = np.meshgrid(np.arange(img_shape[0]), np.arange(img_shape[1])) - x, y = x.flatten(), y.flatten() - points = np.vstack((x, y)).T - path = Path(polygon) - mask = path.contains_points(points) - mask = mask.reshape(img_shape) - - return mask - - -def ViewROI(patient_id, img_array, mask_array, ROIbox, Inputbox): - masked = np.ma.masked_where(mask_array == 0, mask_array) - plt.subplot(1, 3, 1) - plt.title('{} ROI mask'.format(patient_id)) - plt.imshow(img_array, cmap='gray') - plt.imshow(masked, vmin=0, vmax=1, alpha=0.5) - plt.subplot(1, 3, 2) - plt.title('ROI Box') - plt.imshow(ROIbox, cmap='gray') - plt.subplot(1, 3, 3) - plt.title('Input Box') - plt.imshow(Inputbox, cmap='gray') - plt.show() - - -def get_masked_img_voxel(ImageVoxel, mask_voxel): - bbox = get_bbox_from_mask(mask_voxel, np.shape(ImageVoxel)) - assert len(mask_voxel) == ImageVoxel.shape[0] - img_masked = ImageVoxel[:, bbox[0]:bbox[1], bbox[2]:bbox[3], bbox[4]:bbox[5]] - return img_masked - - -def img_train_transform(img_dim): - transform = tio.Compose([ - tio.transforms.ZNormalization(), - tio.RandomAffine(), - tio.RandomFlip(), - tio.RandomNoise(), - tio.RandomMotion(), - tio.transforms.Resize(img_dim), - tio.RescaleIntensity(out_min_max=(0, 1)) - ]) - return transform - - -def img_val_transform(img_dim): - transform = tio.Compose([ - tio.transforms.ZNormalization(), - tio.transforms.Resize(img_dim), - tio.RescaleIntensity(out_min_max=(0, 1)) - ]) - return transform - - -def class_stratify(SubjectList, config): - ptarget = SubjectList['xnat_subjectdata_field_map_' + config['DATA']['target']] - kbins = KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='uniform') - ptarget = np.array(ptarget).reshape((len(ptarget), 1)) - data_trans = kbins.fit_transform(ptarget) - return data_trans - - -def get_RS_masks(slabel, CTPath, mask_imgs, RSfile, mask_names): - #RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSfile) - #roi_names = RS.get_roi_names() - #strList = [x.lower() for x in roi_names] - #for idx, roi in enumerate(mask_names): - # if roi.lower() in strList: - # roi_s = roi_names[strList.index(roi.lower())] - # mask_img = RS.get_roi_mask_by_name(roi_s) - # # mask_img = distance_transform_edt(mask_img) - # mask_imgs = BitSet(mask_imgs, idx * np.ones_like(mask_imgs), mask_img) - # else: - # raise ValueError(slabel + " has no ROI of name " + roi + " found in RTStruct") - # - #return mask_imgs - - RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSfile) - roi_names = RS.get_roi_names() - strList = [x.lower() for x in roi_names] - for idx, roi in enumerate(mask_names): - if roi.lower() in strList: - roi_s = roi_names[strList.index(roi.lower())] - mask_img = RS.get_roi_mask_by_name(roi_s) - mask_img = distance_transform_edt(mask_img) - mask_imgs = mask_imgs + mask_img - else: - raise ValueError(slabel + " has no ROI of name " + roi + " found in RTStruct") - - return mask_imgs - - -def BitSet(n, p, b): - p = p.astype(int) - n = n.astype(int) - mask = 1 << p - bm = b << p - return (n & ~mask) | bm diff --git a/Utils/DicomTools.py b/Utils/DicomTools.py index 5ad03a5..3f62440 100644 --- a/Utils/DicomTools.py +++ b/Utils/DicomTools.py @@ -1,21 +1,19 @@ import os import glob -#import cv2 import SimpleITK as sitk import pydicom as dicom from pydicom import dcmread import numpy as np import matplotlib import matplotlib.pyplot as plt -#from monai.data import ITKReader, PILReader -#import torchio as tio from pathlib import Path from sklearn.preprocessing import KBinsDiscretizer from monai.transforms import LoadImage sitk.ProcessObject_SetGlobalWarningDisplay(False) from rt_utils import RTStructBuilder from scipy.ndimage import * - +from concurrent.futures import ThreadPoolExecutor +import concurrent def get_bbox_from_mask(mask, img_shape): pos = np.where(mask) @@ -101,84 +99,6 @@ def get_masked_img_voxel(ImageVoxel, mask_voxel): img_masked = ImageVoxel[:, bbox[0]:bbox[1], bbox[2]:bbox[3], bbox[4]:bbox[5]] return img_masked - -def img_train_transform(img_dim): - transform = tio.Compose([ - tio.transforms.ZNormalization(), - tio.RandomAffine(), - tio.RandomFlip(), - tio.RandomNoise(), - tio.RandomMotion(), - tio.transforms.Resize(img_dim), - tio.RescaleIntensity(out_min_max=(0, 1)) - ]) - return transform - - -def img_val_transform(img_dim): - transform = tio.Compose([ - tio.transforms.ZNormalization(), - tio.transforms.Resize(img_dim), - tio.RescaleIntensity(out_min_max=(0, 1)) - ]) - return transform - - -def class_stratify(SubjectList, config): - ptarget = SubjectList['xnat_subjectdata_field_map_' + config['DATA']['target']] - kbins = KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='uniform') - ptarget = np.array(ptarget).reshape((len(ptarget), 1)) - data_trans = kbins.fit_transform(ptarget) - return data_trans - - -def get_RS_masks(slabel, CTPath, mask_imgs, RSfile, mask_names): - #RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSfile) - #roi_names = RS.get_roi_names() - #strList = [x.lower() for x in roi_names] - #for idx, roi in enumerate(mask_names): - # if roi.lower() in strList: - # roi_s = roi_names[strList.index(roi.lower())] - # mask_img = RS.get_roi_mask_by_name(roi_s) - # # mask_img = distance_transform_edt(mask_img) - # mask_imgs = BitSet(mask_imgs, idx * np.ones_like(mask_imgs), mask_img) - # else: - # raise ValueError(slabel + " has no ROI of name " + roi + " found in RTStruct") - # - #return mask_imgs - - RS = RTStructBuilder.create_from(dicom_series_path=CTPath, rt_struct_path=RSfile) - roi_names = RS.get_roi_names() - strList = [x.lower() for x in roi_names] - for idx, roi in enumerate(mask_names): - if roi.lower() in strList: - roi_s = roi_names[strList.index(roi.lower())] - mask_img = RS.get_roi_mask_by_name(roi_s) - mask_img = distance_transform_edt(mask_img) - mask_imgs = mask_imgs + mask_img - else: - raise ValueError(slabel + " has no ROI of name " + roi + " found in RTStruct") - - return mask_imgs - -def get_nii_masks(slabel, mask_imgs, MPath, mask_names): - for idx, roi in enumerate(mask_names): - try: - data, meta = LoadImage()(Path(MPath, roi + '.nii.gz')) - except: - raise ValueError(slabel + " has no ROI of name " + roi + " found in RTStruct") - mask_imgs = BitSet(mask_imgs, idx * np.ones_like(mask_imgs), data) - return mask_imgs - #for roi in mask_names: - # try: - # data, meta = LoadImage()(Path(MPath, roi + '.nii.gz')) - # except: - # raise ValueError(slabel + " has no ROI of name " + roi + " found in RTStruct") - # mask_img = distance_transform_edt(data) - # mask_imgs = mask_imgs + mask_img - #return mask_imgs - - def BitSet(n, p, b): p = p.astype(int) n = n.astype(int) @@ -186,3 +106,44 @@ def BitSet(n, p, b): mask = 1 << p bm = b << p return (n & ~mask) | bm + +def QuerySubjectInfo(config, SubjectList, session): + if config['DATA']['Nifty']: + for i in range(len(SubjectList)): + subject_label = SubjectList.loc[i,'subject_label'] + for key in config['MODALITY'].keys(): + SubjectList.loc[i, key + '_Path'] = Path(config['DATA']['DataFolder'], subject_label) + else: + with ThreadPoolExecutor(max_workers=10) as executor: + future_to_url = {executor.submit(get_subject_info, config, session, subjectid) for subjectid in + SubjectList['subjectid']} + executor.shutdown(wait=True) + for future in concurrent.futures.as_completed(future_to_url): + subjectdata = future.result() + subjectid = subjectdata["xnat:Subject"][0]["@ID"] + for key in config['MODALITY'].keys(): + path = GeneratePath(subjectdata, Modality=key, config=config) + if key == 'CT': + SubjectList.loc[SubjectList.subjectid == subjectid, key + '_Path'] = path + else: + spath = glob.glob(path + '/*dcm') + SubjectList.loc[SubjectList.subjectid == subjectid, key + '_Path'] = spath[0] + +def GeneratePath(subjectdata, Modality, config): + subject = subjectdata['xnat:Subject'][0] + subject_label = subject['@label'] + experiments = subject['xnat:experiments'][0]['xnat:experiment'] + + ## Won't work with many experiments yet + for experiment in experiments: + experiment_label = experiment['@label'] + scans = experiment['xnat:scans'][0]['xnat:scan'] + for scan in scans: + if (scan['@type'] in Modality): + scan_label = scan['@ID'] + '-' + scan['@type'] + resources_label = scan['xnat:file'][0]['@label'] + if resources_label == 'SNAPSHOTS': + resources_label = scan['xnat:file'][1]['@label'] + path = os.path.join(config['DATA']['DataFolder'], subject_label, experiment_label, 'scans', + scan_label, 'resources', resources_label, 'files') + return path diff --git a/Utils/GenerateSmoothLabel.py b/Utils/GenerateSmoothLabel.py deleted file mode 100644 index 90af2f6..0000000 --- a/Utils/GenerateSmoothLabel.py +++ /dev/null @@ -1,66 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -from scipy.ndimage import convolve1d -from scipy.ndimage import gaussian_filter1d -from scipy.signal.windows import triang -from sksurv.metrics import cumulative_dynamic_auc -from torch import nn -from Models.Classifier import Classifier -from Models.Linear import Linear - -def get_lds_kernel_window(kernel, ks, sigma): - assert kernel in ['gaussian', 'triang', 'laplace'] - half_ks = (ks - 1) // 2 - if kernel == 'gaussian': - base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks - kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma)) - elif kernel == 'triang': - kernel_window = triang(ks) - else: - laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma) - kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max( - map(laplace, np.arange(-half_ks, half_ks + 1))) - - return kernel_window - - -def get_smoothed_label_distribution(SubjectList, config): - label_all = get_train_label(SubjectList, config) - range_max = np.max(label_all).astype(int) + 1 - range_min = np.min(label_all).astype(int) - - label_range = np.arange(range_min, range_max, 1) - - bin_index_per_label = np.histogram(label_all, bins=label_range) - lds_kernel_window = get_lds_kernel_window(kernel='gaussian', ks=7, sigma=3) - eff_label_dist = convolve1d(np.array(bin_index_per_label[0]), weights=lds_kernel_window, mode='constant') - - eff_num_per_label = [eff_label_dist[bin_idx] for bin_idx in np.arange(eff_label_dist.shape[0])] - weights = [np.float32(1 / x) for x in eff_num_per_label] - - label_mean = np.mean(label_all) - mse = ((label_all - label_mean) ** 2).mean() - return weights, bin_index_per_label[1] - - -def get_train_label(SubjectList, config): - train_label = [] - for patient in SubjectList: - label = patient.fields[config['DATA']['target']] - train_label.append(label) - return train_label - - -def get_module(config): - s_module = config['DATA']['module'] - module_dict = nn.ModuleDict() - if config['MODEL']['Clinical_Backbone']: - Clinical_backbone = Linear() - for i, module in enumerate(s_module): - if module == 'CT' or module == 'Dose' or module == 'PET': - Backbone = Classifier(config) - module_dict[module] = Backbone - else: - module_dict[module] = Clinical_backbone - - return module_dict diff --git a/Utils/PredictionReports.py b/Utils/PredictionReports.py index 714b6ba..8ee7807 100644 --- a/Utils/PredictionReports.py +++ b/Utils/PredictionReports.py @@ -4,9 +4,9 @@ from pytorch_lightning.loggers.base import rank_zero_experiment from torch.utils.tensorboard import SummaryWriter import numpy as np - +from lifelines.utils import concordance_index import matplotlib.pyplot as plt -# plt.switch_backend('agg') +import matplotlib import torchvision from pytorch_lightning.loggers import LightningLoggerBase from sksurv.metrics import cumulative_dynamic_auc @@ -21,6 +21,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from torchmetrics import ConfusionMatrix + class PredictionReports(TensorBoardLogger): def __init__(self, config, save_dir: str, @@ -38,36 +39,6 @@ def __init__(self, config, def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): pass - @property - def version(self): - description = '' - for i, param in enumerate(self.config['CRITERIA'].keys()): - clinical_criteria = str(self.config['CRITERIA'][param]) - if i > 0: - description = description + '_' - description = description + param + '+' + '+'.join(clinical_criteria) - # Return the experiment version, int or str. - - sub_str = description + '_' + 'modalities' + '+' + '+'.join(self.config['MODALITY'].keys()) - self._version = self._get_next_version(sub_str) - description = description + '_' + 'modalities' + '+' + '+'.join(self.config['MODALITY'].keys()) + '_' + str(self._version) - return description - - def _get_next_version(self, sub_str): - root_dir = self.root_dir - listdir_info = self._fs.listdir(root_dir) - existing_versions = [] - for listing in listdir_info: - d = listing["name"] - bn = os.path.basename(d) - if self._fs.isdir(d) and bn.startswith(sub_str): - dir_ver = bn.split("_")[-1] - existing_versions.append(int(dir_ver)) - if len(existing_versions) == 0: - return 0 - - return max(existing_versions) + 1 - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): for k, v in metrics.items(): if isinstance(v, torch.Tensor): @@ -75,24 +46,24 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): self.experiment.add_scalar(k, v, step) def log_image(self, img, text, current_epoch=None): - img_batch = img.view(img.shape[0] * img.shape[1], *[1, img.shape[2], img.shape[3]]) - grid = torchvision.utils.make_grid(img_batch) + img = img.transpose(2, 0) + img_batch = img.view(img.shape[0], *[1, img.shape[1], img.shape[2]]) + grid = torchvision.utils.make_grid(img_batch, normalize=True) self.experiment.add_image(text, grid, current_epoch) return grid def log_text(self) -> None: configurations = 'The modules included are ' + str(self.config['MODALITY'].keys()) - # configurations = 'The img_dim is ' + str(self.config['DATA']['dim']) + ' and the modules included are ' + - # str(self.config['MODALITY'].keys()) self.experiment.add_text('configurations:', configurations) - def regression_matrix(self, prediction, label, prefix): + def regression_matrix(self, prediction, censor_status, label, prefix): ## matrix should be metrics r_out = {} if 'cindex' in self.config['CHECKPOINT']['matrix']: - cindex = c_index(prediction, label) - r_out[prefix + 'cindex'] = cindex[0] + cindex = concordance_index(label.cpu().detach().numpy(), prediction.cpu().detach().numpy(), + censor_status.cpu().detach().numpy()) + r_out[prefix + 'cindex'] = cindex if 'r2' in self.config['CHECKPOINT']['matrix']: - r2 = r2_index(prediction, label) + r2 = r2_index(prediction.detach(), label.detach()) r_out[prefix + 'r2'] = r2 return r_out @@ -109,7 +80,7 @@ def classification_matrix(self, prediction, label, prefix): accuracy = auroc(prediction, label.int()) c_out[prefix + 'roc'] = accuracy if 'Specificity' in self.config['CHECKPOINT']['matrix']: - spec = tn /(tn + fp) + spec = tn / (tn + fp) c_out[prefix + 'specificity'] = spec if 'Sensitivity' in self.config['CHECKPOINT']['matrix']: @@ -125,36 +96,10 @@ def classification_matrix(self, prediction, label, prefix): c_out[prefix + 'precision'] = precision return c_out - def generate_cumulative_dynamic_auc(self, prediction, label, current_epoch, prefix) -> None: - # this function has issues - risk_score = 1 / prediction - va_times = np.arange(int(label.cpu().min()) + 1, label.cpu().max(), 1) - - dtypes = np.dtype([('event', np.bool_), ('time', np.float)]) - construct_test = np.ndarray(shape=(len(label),), dtype=dtypes) - for i in range(len(label)): - construct_test[i] = (True, label[i].cpu().numpy()) - - cph_auc, cph_mean_auc = cumulative_dynamic_auc( - construct_test, construct_test, risk_score.cpu(), va_times - ) - - fig = plt.figure() - plt.plot(va_times, cph_auc, marker="o") - plt.axhline(cph_mean_auc, linestyle="--") - plt.ylim([0, 1]) - plt.xlabel("survival months") - plt.ylabel("time-dependent AUC") - plt.grid(True) - self.experiment.add_figure(prefix + "AUC", fig, current_epoch) - plt.close(fig) - def plot_AUROC(self, prediction, label, prefix, current_epoch=None) -> None: roc = torchmetrics.ROC() fpr, tpr, _ = roc(prediction, label) fig = plt.figure() - # lw = 2 - # plt.plot(fpr.cpu(), tpr.cpu(), color='darkorange', lw=lw) plt.plot(fpr.cpu(), tpr.cpu(), color='darkorange') plt.title(prefix + '_roc_curve') plt.xlabel('False Positive Rate') @@ -165,37 +110,31 @@ def plot_AUROC(self, prediction, label, prefix, current_epoch=None) -> None: def worst_case_show(self, validation_step_outputs, prefix): out = {} worst_AE = 0 + label = '' for i, data in enumerate(validation_step_outputs): loss = data['MAE'] idx = torch.argmax(loss) if loss[idx] > worst_AE: - if 'CT' in self.config['DATA']['target']: - worst_img = data['CT'][idx] - if 'Dose' in self.config['DATA']['target']: - worst_dose = data['dose'][idx] + if 'CT' in self.config['MODALITY'].keys(): + worst_img = data['Image'][idx][0, :, :, :] + if 'Dose' in self.config['MODALITY'].keys(): + worst_dose = data['Image'][idx][1, :, :, + :] ### this index needs to be careful when adding pet image worst_AE = loss[idx] + label = data['slabel'][idx] out[prefix + 'worst_AE'] = worst_AE - if 'CT' in self.config['DATA']['target']: + if 'CT' in self.config['MODALITY'].keys(): out[prefix + 'worst_img'] = worst_img - if 'Dose' in self.config['DATA']['target']: + if 'Dose' in self.config['MODALITY'].keys(): out[prefix + 'worst_dose'] = worst_dose + out[prefix + 'slabel'] = label return out - # def report_step(self, prediction, label, step, prefix) -> None: - # if self.config['MODEL']['Prediction_type'] == 'Regression': - # regression_out = self.regression_matrix(prediction, label, prefix) - # self.log_metrics(regression_out, step) - # if self.config['MODEL']['Prediction_type'] == 'Classification': - # classification_out = self.classification_matrix(prediction.squeeze(), label, prefix) - # self.log_metrics(classification_out, step) - - def report_epoch(self, prediction, label, validation_step_outputs, + def report_epoch(self, prediction, censor_status, label, validation_step_outputs, current_epoch, prefix) -> None: if self.config['MODEL']['Prediction_type'] == 'Regression': - regression_out = self.regression_matrix(prediction, label, prefix) + regression_out = self.regression_matrix(prediction, censor_status, label, prefix) self.log_metrics(regression_out, current_epoch) - if 'AUROC' in self.config['CHECKPOINT']['matrix']: - self.generate_cumulative_dynamic_auc(prediction, label, current_epoch, prefix) if self.config['MODEL']['Prediction_type'] == 'Classification': classification_out = self.classification_matrix(prediction.squeeze(), label, prefix) @@ -206,17 +145,19 @@ def report_epoch(self, prediction, label, validation_step_outputs, if 'WorstCase' in self.config['CHECKPOINT']['matrix']: worst_record = self.worst_case_show(validation_step_outputs, prefix) self.log_metrics({prefix + 'worst_AE': worst_record[prefix + 'worst_AE']}, current_epoch) - if 'CT' in self.config['DATA']['target']: + self.log_metrics('worst_subject: ', str(worst_record[prefix + 'slabel'])) + if 'CT' in self.config['MODALITY'].keys(): text = 'validate_worst_case_img' self.log_image(worst_record[prefix + 'worst_img'], text, current_epoch) - if 'Dose' in self.config['DATA']['target']: + if 'Dose' in self.config['MODALITY'].keys(): text = 'validate_worst_case_dose' self.log_image(worst_record[prefix + 'worst_dose'], text, current_epoch) - def report_test(self, config, outs, model, prediction_labels, validation_labels, prefix): + def report_test(self, config, outs, model, prediction_labels, validation_censor, validation_labels, prefix): if 'WorstCase' in config['CHECKPOINT']['matrix']: worst_record = self.worst_case_show(outs, prefix) self.experiment.add_text('worst_test_AE: ', str(worst_record[prefix + 'worst_AE'])) + self.experiment.add_text('worst_subject: ', str(worst_record[prefix + 'slabel'])) if 'CT' in config['MODALITY'].keys(): text = 'test_worst_case_img' self.log_image(worst_record[prefix + 'worst_img'], text) @@ -226,8 +167,7 @@ def report_test(self, config, outs, model, prediction_labels, validation_labels, if config['MODEL']['Prediction_type'] == 'Regression': self.experiment.add_text('test loss: ', str(model.loss_fcn(prediction_labels, validation_labels))) - self.generate_cumulative_dynamic_auc(prediction_labels, validation_labels, 0, prefix) - regression_out = self.regression_matrix(prediction_labels, validation_labels, prefix) + regression_out = self.regression_matrix(prediction_labels, validation_censor, validation_labels, prefix) self.experiment.add_text('test_cindex: ', str(regression_out[prefix + 'cindex'])) self.experiment.add_text('test_r2: ', str(regression_out[prefix + 'r2'])) return regression_out @@ -251,18 +191,8 @@ def report_test(self, config, outs, model, prediction_labels, validation_labels, def r2_index(prediction, label): - loss = nn.MSELoss() - MSE = loss(prediction, label) - SSres = MSE * label.shape[0] + loss = nn.MSELoss(reduction='sum') + SSres = loss(prediction, label) SStotal = torch.sum(torch.square(label - torch.mean(label))) r2 = 1 - SSres / SStotal return r2 - - -def c_index(prediction, label): - event_indicator = torch.ones(label.shape, dtype=torch.bool) - risk = 1 / prediction.squeeze() - cindex = concordance_index_censored(event_indicator.cpu().detach().numpy(), - event_time=label.cpu().detach().numpy(), - estimate=risk.cpu().detach().numpy()) - return cindex diff --git a/check.py b/check.py deleted file mode 100644 index 2e784f3..0000000 --- a/check.py +++ /dev/null @@ -1,54 +0,0 @@ -import toml -import sys -import glob -import nibabel as nib -import numpy as np -from Utils.DicomTools import * -from pathlib import Path -from DataGenerator.DataGenerator import QuerySubjectList, SynchronizeData, QuerySubjectInfo -import os -from rt_utils import RTStructBuilder -import xnat -from monai.transforms import LoadImage, ResampleToMatchd, EnsureChannelFirstd -from monai.data.image_writer import ITKWriter - -session = xnat.connect('http://128.16.11.124:8080/xnat', user='admin', password='mortavar1977') -config = toml.load(sys.argv[1]) -from pydicom import dcmread - -session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'], - password=config['SERVER']['Password']) -SubjectList = QuerySubjectList(config, session) -print(SubjectList) -SynchronizeData(config, SubjectList) -QuerySubjectInfo(config, SubjectList, session) -for i in range(0,len(SubjectList),1): - print(i) - CTPath = SubjectList['CT_Path'][i].split('/') - scanPath = '/'.join(CTPath[0:8]) - Dosefile = glob.glob(scanPath + '/*-RTDOSE') - data = {} - meta = {} - if len(Dosefile) > 1: - for j in range(len(Dosefile)): - file = glob.glob(Dosefile[j] + '/**/*.dcm', recursive=True) - if j == 0: - info = dcmread(file[0]) - data['Dose'+ str(j)], meta['Dose'+str(j)] = LoadImage()(file[0]) - - data = EnsureChannelFirstd(data.keys())(data) - data = ResampleToMatchd(list(set(data.keys()).difference(set(['Dose0']))), key_dst='Dose0')(data) - - total_array = np.zeros_like(data['Dose0']) - for key in data.keys(): - total_array = total_array + data[key]*np.float64(meta[key]['3004|000e'])/np.float64(meta['Dose0']['3004|000e']) - - total_array = total_array.squeeze() - total_array = np.transpose(total_array, (2, 1, 0)) - total_array = np.uint32(total_array) - - info.PixelData = total_array.tobytes() - if not os.path.isdir(Path('/home/dgs1/data/dose/', CTPath[5])): - os.mkdir(Path('/home/dgs1/data/dose/', CTPath[5])) - dpath = Path('/home/dgs1/data/dose/', CTPath[5], '1-1.dcm') - info.save_as(dpath) diff --git a/renameCT.py b/renameCT.py deleted file mode 100644 index 460d55c..0000000 --- a/renameCT.py +++ /dev/null @@ -1,99 +0,0 @@ -import sys -#sys.path.insert(0, '/home/dgs1/Software/OutcomePrediction/') -import toml -import glob -import nibabel as nib -import numpy as np -from Utils.DicomTools import * -from pathlib import Path -from DataGenerator.DataGenerator import QuerySubjectList, SynchronizeData, QuerySubjectInfo -from Utils.DicomTools import * -import os -from rt_utils import RTStructBuilder -import xnat -from Utils.FixRTSS import * -session = xnat.connect('http://128.16.11.124:8080/xnat', user='admin', password='mortavar1977') -import csv -from monai.transforms import Spacing, LoadImage, EnsureChannelFirst -from monai.data import MetaTensor -import itk -config = toml.load(sys.argv[1]) - -session = xnat.connect(config['SERVER']['Address'], user=config['SERVER']['User'],password=config['SERVER']['Password']) -SubjectList = QuerySubjectList(config, session) -print(SubjectList) -SynchronizeData(config, SubjectList) - -for key in config['MODALITY'].keys(): - SubjectList[key+'_Path'] = "" -QuerySubjectInfo(config, SubjectList, session) -# -# roi_series = ['Heart', 'Oesophagus', 'Spinal Canal', 'Prox bronch tree', 'Proximal trachea', 'Cwall & ribs', 'L Lung', -# 'R Lung', 'Brachial Plexus'] -# roi_name = ['HEART', 'ESOPHAGUS', 'SPINAL_CORD', 'PROX_BRONCH_TREE', 'PROXIMAL_TRACHEA', 'CWALL_RIBS', 'LUNG_LEFT', -# 'LUNG_RIGHT', 'BRACHIAL_PLEXUS'] -# roi_series = ['Heart', 'Esophagus', 'Lung_L', 'Lung_R','SpinalCord'] -# roi_name = ['HEART', 'ESOPHAGUS', 'LUNG_LEFT','LUNG_RIGHT', 'SPINAL_CORD'] - -#roi_series = ['HEART', 'PTV', 'LUNG_LEFT', 'LUNG_RIGHT', 'SPINAL_CORD', 'ESOPHAGUS'] -#roi_name = ['HEART', 'PTV', 'LUNG_LEFT', 'LUNG_RIGHT', 'SPINAL_CORD', 'ESOPHAGUS'] -roi_series = ['PTV'] -roi_name = ['PTV'] - -path = config['DATA']['NiiFolder'] -sPatient = SubjectList -r1 = [x.lower() for x in roi_series] - -for i in range(1, len(SubjectList), 1): - subjectid = sPatient.loc[i, 'subjectid'] - subject_label = sPatient.loc[i,'subject_label'] - CTPath = SubjectList[SubjectList.subjectid == subjectid]['CT_Path'] - CTArray, meta = LoadImage()(CTPath) - ct = EnsureChannelFirst()(CTArray) - ct = Spacing(pixdim=(1, 1, 3))(ct) - - names_generator = itk.GDCMSeriesFileNames.New() - names_generator.SetUseSeriesDetails(True) - names_generator.AddSeriesRestriction("0008|0021") # Series Date - names_generator.SetDirectory(list(CTPath)[0]) - series_uid = names_generator.GetSeriesUIDs() - if len(series_uid) > 1: - print(i) - else: - ct_array = ct.array.squeeze() - # First define the ROI based on target - RSPath = SubjectList[SubjectList.subjectid == subjectid]['Structs_Path'] - try: - RS = RTStructBuilder.create_from(dicom_series_path=list(CTPath)[0], rt_struct_path=list(RSPath)[0]) - #%% - roi_names = RS.get_roi_names() - strList = [x.lower() for x in roi_names] - ni_img = nib.Nifti1Image(np.double(ct_array), affine=ct.affine) - spath = Path(path, subject_label) - - if not os.path.isdir(spath): - os.mkdir(spath) - nib.save(ni_img, Path(spath, 'ct.nii.gz')) - - for i in range(len(r1)): - roi = r1[i] - if roi in strList: - #print(roi) - index = strList.index(roi.lower()) - mask_img = RS.get_roi_mask_by_name(roi_names[index]) - mask_img = np.rot90(mask_img) - mask_img = np.flip(mask_img, 0) - mask = MetaTensor(mask_img.copy(), meta=meta) - mask = EnsureChannelFirst()(mask) - mask = Spacing(pixdim=(1, 1, 3))(mask) - mask_array = mask.array.squeeze() - ni_mask = nib.Nifti1Image(mask_array.astype('int'), affine=mask.affine, dtype='uint8') - nib.save(ni_mask, Path(spath, roi_name[i].lower() + '.nii.gz')) - else: - print(subject_label) - except: - #print('No.{}:'.format(i)) - continue - - -