Skip to content
Merged
129 changes: 104 additions & 25 deletions spectf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os
from collections.abc import Callable
from typing import List
from typing import List, Optional

import numpy as np
import torch
Expand All @@ -18,9 +18,10 @@
from spectf.utils import drop_bands, envi_header


class RasterDatasetTOA(Dataset):
"""A PyTorch dataset class for pixelwise access of top-of-atmosphere (TOA)
reflectance data derived from L1b rdn.
class ToaDataset(Dataset):
"""
A Parent class for PyTorch datasets for pixelwise access of top-of-atmosphere (TOA)
reflectance data.

Attributes:
shape (tuple): Shape of the L1b rdn raster.
Expand All @@ -29,6 +30,75 @@ class RasterDatasetTOA(Dataset):
metadata (dict): Metadata of the original raster image.
transform (callable, optional): Transformations for each pixel spectra.

"""

# Class attributes
# toa_arr - the main top of atmosphere reflectance data array, should be in shape (rows, cols, bands)
toa_arr: np.ndarray = None

# Optional transform callable method to be applied to each spectral data point
transform: Optional[Callable] = None

# Metadata dictionary of the original raster image - not used in the dataset class, but stored
metadata: dict = None

# NOTE this is the original dataset shape - this should be in order of (rows, cols, bands) or (cols, rows, bands)
shape: tuple = None

# array of band wavelengths corresponding to the indices of the third dimension of toa_arr
# used for band dropping when requested
banddef: np.ndarray = None

def init_class_data(self,
rm_bands: List[List[int]]=None,
transform: Callable = None,
dtype: torch.dtype = torch.float,
device: torch.device = None):
"""
Helper method to initialize class data common to all TOA datasets AFTER toa_arr has been set.

Args:
rm_bands: List[List[int]]=None: if None - keep all bands on non-EMIT data, else drop specified bands.
transform: Callable = None: Optional transform to be applied to each spectral data point.
dtype: torch.dtype = torch.float: data type to load data as
device: torch.device = None - if provided torch device to load data onto

Returns:

"""
if self.toa_arr is None:
raise ValueError("toa_arr must be set before calling init_class_data.")

self.shape = self.toa_arr.shape
self.toa_arr = self.toa_arr.reshape((self.shape[0] * self.shape[1],
self.shape[2]))
if rm_bands is not None:
self.toa_arr, self.banddef = drop_bands(self.toa_arr,
self.banddef,
rm_bands,
nan=False)
self.transform = transform

self.toa_arr = torch.tensor(self.toa_arr, dtype=dtype)
self.toa_arr = torch.unsqueeze(self.toa_arr, -1)
if device is not None:
self.toa_arr = self.toa_arr.to(device)

def __len__(self):
return len(self.toa_arr)

def __getitem__(self, idx):
out_spec = self.toa_arr[idx]
if self.transform is not None:
out_spec = self.transform(out_spec)

return out_spec


class RasterDatasetTOA(ToaDataset):
"""A PyTorch dataset class for pixelwise access of top-of-atmosphere (TOA)
reflectance data derived from L1b rdn.

Relies on the `l1b_to_toa_arr` function to process input data files and generate
TOA reflectance data.
"""
Expand Down Expand Up @@ -64,30 +134,39 @@ def __init__(
assert os.path.exists(irrfp), f"Irradiance file {irrfp} does not exist."

self.toa_arr, self.banddef, self.metadata = l1b_to_toa_arr(self.rdnhdr, self.obshdr, irrfp)
self.shape = self.toa_arr.shape
self.toa_arr = self.toa_arr.reshape((self.shape[0] * self.shape[1],
self.shape[2]))
if not keep_bands:
self.toa_arr, self.banddef = drop_bands(self.toa_arr,
self.banddef,
rm_bands,
nan=False)
self.transform = transform
self.init_class_data(rm_bands, transform, dtype, device)

self.toa_arr = torch.tensor(self.toa_arr, dtype=dtype)
self.toa_arr = torch.unsqueeze(self.toa_arr, -1)
if device is not None:
self.toa_arr = self.toa_arr.to(device)

def __len__(self):
return len(self.toa_arr)
class ArrayDatasetTOA(ToaDataset):
"""
A PyTorch dataset class for pixelwise access of top-of-atmosphere (TOA)
reflectance data derived from toa numpy array.
"""
def __init__(
self,
toa: np.ndarray,
banddef: np.ndarray,
rm_bands: List[List[int]] = None,
transform: Callable = None,
dtype: torch.dtype = torch.float,
device: torch.device = None,
):
"""
Initialize the ArrayDatasetTOA Dataset object.

def __getitem__(self, idx):
out_spec = self.toa_arr[idx]
if self.transform is not None:
out_spec = self.transform(out_spec)
Args:
toa: numpy array of top of atmosphere reflectance, should be in shape (rows, cols, bands)
banddef: numpy array of band wavelengths corresponding to the indices of the third dimension of toa_arr
used for band dropping when requested
rm_bands (List[List[int]] | None): if None - keep all bands on non-EMIT data, else drop specified bands.
transform: Callable = None: Optional transform to be applied to each spectral data point.
dtype: torch.dtype = torch.float: data type to load data as
device: torch.device = None - if provided torch device to load data onto
"""
super().__init__()

return out_spec
self.toa_arr = toa
self.banddef = banddef
self.init_class_data(rm_bands, transform, dtype, device)


class SpectraDataset(Dataset):
Expand Down
Loading
Loading