Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import numpy as np

class LensingDataset(torch.utils.data.Dataset):
def __init__(self, directory, classes, num_samples):
"""
Expand All @@ -14,11 +15,12 @@ def __init__(self, directory, classes, num_samples):
self.directory = directory
self.classes = classes
self.num_samples = num_samples

def __len__(self):
"""
:return: Returns the length of the dataset
"""
return self.num_samples*len(self.classes)
return self.num_samples * len(self.classes)

def __getitem__(self, index):
"""
Expand All @@ -27,8 +29,20 @@ def __getitem__(self, index):
:param index: Index in the dataset to look for
:return: LR image, min-max normalized
"""
selected_class = self.classes[index//self.num_samples]
class_index = index%self.num_samples
image = torch.tensor(np.array([np.load(self.directory+selected_class+'/sim_%d.npy'%(class_index))]))
image = (image - torch.min(image))/(torch.max(image)-torch.min(image))
return image
# Determine class and specific image index
selected_class = self.classes[index // self.num_samples]
class_index = index % self.num_samples

# Safely construct the file path using f-strings and os.path.join
file_path = os.path.join(self.directory, selected_class, f'sim_{class_index}.npy')

# Load efficiently and add the channel dimension (1, H, W)
np_img = np.load(file_path)
image = torch.from_numpy(np_img).float().unsqueeze(0)

# Safe min-max normalization (adding 1e-8 prevents division by zero)
img_min = torch.min(image)
img_max = torch.max(image)
image = (image - img_min) / (img_max - img_min + 1e-8)

return image