diff --git a/Grid_based_strong_lensing_for_unsupervised_super_resolution_Anirudh_Shankar/data.py b/Grid_based_strong_lensing_for_unsupervised_super_resolution_Anirudh_Shankar/data.py index 6d394ff..3659915 100644 --- a/Grid_based_strong_lensing_for_unsupervised_super_resolution_Anirudh_Shankar/data.py +++ b/Grid_based_strong_lensing_for_unsupervised_super_resolution_Anirudh_Shankar/data.py @@ -1,6 +1,7 @@ +import os import torch import numpy as np - + class LensingDataset(torch.utils.data.Dataset): def __init__(self, directory, classes, num_samples): """ @@ -14,11 +15,12 @@ def __init__(self, directory, classes, num_samples): self.directory = directory self.classes = classes self.num_samples = num_samples + def __len__(self): """ :return: Returns the length of the dataset """ - return self.num_samples*len(self.classes) + return self.num_samples * len(self.classes) def __getitem__(self, index): """ @@ -27,8 +29,20 @@ def __getitem__(self, index): :param index: Index in the dataset to look for :return: LR image, min-max normalized """ - selected_class = self.classes[index//self.num_samples] - class_index = index%self.num_samples - image = torch.tensor(np.array([np.load(self.directory+selected_class+'/sim_%d.npy'%(class_index))])) - image = (image - torch.min(image))/(torch.max(image)-torch.min(image)) - return image \ No newline at end of file + # Determine class and specific image index + selected_class = self.classes[index // self.num_samples] + class_index = index % self.num_samples + + # Safely construct the file path using f-strings and os.path.join + file_path = os.path.join(self.directory, selected_class, f'sim_{class_index}.npy') + + # Load efficiently and add the channel dimension (1, H, W) + np_img = np.load(file_path) + image = torch.from_numpy(np_img).float().unsqueeze(0) + + # Safe min-max normalization (adding 1e-8 prevents division by zero) + img_min = torch.min(image) + img_max = torch.max(image) + image = (image - img_min) / (img_max - img_min + 1e-8) + + return image