Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions spectf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,51 @@ def __getitem__(self, idx):
'spectra': torch.unsqueeze(out_spec, -1),
'label': self.labels[idx]
}


class SpectraDatasetV2(Dataset):
    """A PyTorch dataset class for access of ML-ready HDF5 spectral data (Dynamic Banddef).

    Each item is returned together with its band definition so that spectra
    sampled on different wavelength grids can be served by the same dataset.

    Attributes:
        spectra (torch.Tensor): The spectral data, stored as float32.
        labels (torch.Tensor): The corresponding labels for the spectral data.
            Label 2 (shadow) is remapped to 0 (clear) at construction time.
        wavelengths (torch.Tensor): The wavelength band definitions, stored as
            float32. Either one row per spectrum or a single 1-D grid shared
            by all spectra.
        transform (callable): Transformations or normalizations
                              for each spectral data point.
    """

    # The transform annotation is quoted so that it does not require this
    # module to import names from ``typing`` at definition time.
    def __init__(self, spectra: np.ndarray, labels: np.ndarray, wavelengths: np.ndarray,
                 transform: "Optional[Callable[[torch.Tensor], torch.Tensor]]" = None,
                 device: str = 'cpu'):
        """ Initialize the SpectraDatasetV2 object.

        Args:
            spectra (np.ndarray): The spectral data.
            labels (np.ndarray): The corresponding labels for the spectral data.
            wavelengths (np.ndarray): The corresponding wavelength band
                definitions, one row per spectrum, or a single 1-D grid that
                is shared by every spectrum.
            transform (callable): Optional transform to be applied to
                                  each spectral data point. Default None.
            device (str): The device to load the data onto. Default 'cpu'.
        """
        super().__init__()
        self.spectra = torch.tensor(spectra, dtype=torch.float32).to(device)
        # torch.tensor copies, so the in-place relabel below never touches
        # the caller's array.
        self.labels = torch.tensor(labels).to(device)
        self.labels[self.labels == 2] = 0  # shadow considered clear
        self.wavelengths = torch.tensor(wavelengths, dtype=torch.float32).to(device)

        self.transform = transform

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Args:
            idx: Index into the dataset.

        Returns:
            dict: ``'spectra'`` (the spectrum with a trailing singleton
            dimension), ``'label'`` and ``'banddef'`` (the wavelengths that
            belong to this spectrum).
        """
        out_spec = self.spectra[idx]
        if self.transform is not None:
            out_spec = self.transform(out_spec)

        # A 1-D band definition is a grid shared by every spectrum; indexing
        # it with a sample index would return a single wavelength (or fail).
        if self.wavelengths.dim() == 1:
            banddef = self.wavelengths
        else:
            banddef = self.wavelengths[idx]

        return {
            'spectra': torch.unsqueeze(out_spec, -1),
            'label': self.labels[idx],
            'banddef': banddef
        }
181 changes: 176 additions & 5 deletions spectf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from torch import nn
from typing import Optional


class BandConcat(nn.Module):
Expand Down Expand Up @@ -56,6 +57,55 @@ def forward(self, spectra: torch.Tensor):
return encoded


class BandConcatV2(nn.Module):
    """Module to concatenate band wavelength information to spectra (Dynamic).

    This serves as the positional encoding for the transformer, and replaces the
    traditional additive sinusoidal encoding. Band wavelengths are passed during
    the forward pass and normalized to a fixed mean and standard deviation.
    Default mean (1440) and std (600) are set based on the EMIT spectral range.

    Attributes:
        mean (float): Predefined mean of the band center wavelengths.
        std (float): Predefined stddev of the band center wavelengths.
    """

    def __init__(self, mean: float = 1440, std: float = 600):
        """Initialize BandConcatV2 module.

        Args:
            mean (float): Mean of the band center wavelengths. Default 1440.
            std (float): Stddev of the band center wavelengths. Must be
                non-zero. Default 600.

        Raises:
            ValueError: If ``std`` is zero.
        """
        super().__init__()
        if std == 0:
            # A zero std would turn every normalized wavelength into inf/NaN.
            raise ValueError("std must be non-zero to normalize band wavelengths.")
        self.mean = mean
        self.std = std

    def forward(self, spectra: torch.Tensor, banddef: torch.Tensor):
        """BandConcatV2 forward pass.

        Args:
            spectra (torch.Tensor): tensor of shape (b, s, 1)
            banddef (torch.Tensor): tensor of shape (b, s) or (s,)

        Returns:
            torch.Tensor: concatenated tensor of shape (b, s, 2)
        """
        # Ensure banddef is (b, s, 1)
        if banddef.dim() == 1:
            # (s,) -> (1, s, 1)
            banddef = banddef.unsqueeze(-1).unsqueeze(0)
        elif banddef.dim() == 2:
            # (b, s) -> (b, s, 1)
            banddef = banddef.unsqueeze(-1)

        # Normalize band wavelengths
        banddef_norm = (banddef - self.mean) / self.std

        # Band definitions often arrive as float64 or integer tensors.
        # Concatenating those with float32 spectra would promote the result,
        # so match the dtype of the spectra that downstream layers expect.
        if spectra.is_floating_point():
            banddef_norm = banddef_norm.to(spectra.dtype)

        encoded = torch.cat((spectra, banddef_norm.expand_as(spectra)), dim=-1)
        return encoded


class SpectralEmbed(nn.Module):
"""Module to embed spectra per-band using a linear layer.

Expand Down Expand Up @@ -177,7 +227,8 @@ def __init__(self, dim_model: int, num_heads: int, dropout: float = 0.1,

Args:
dim_model (int): Dimension of the input and output tensors.
num_heads (int): Number of attention heads.
num_heads (int): Number of attention heads. Must be a divisor of
dim_model.
dropout (float): Dropout rate. Default 0.1.
use_residual (bool): Whether to use residual connections.
Default False.
Expand All @@ -193,19 +244,21 @@ def __init__(self, dim_model: int, num_heads: int, dropout: float = 0.1,
self.use_residual = use_residual

def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor):
value: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
"""AttentionBlock forward pass.

Args:
query (torch.Tensor): Query tensor of shape (b, s, dim_model)
key (torch.Tensor): Key tensor of shape (b, s, dim_model)
value (torch.Tensor): Value tensor of shape (b, s, dim_model)
key_padding_mask (torch.Tensor): Optional mask of shape (b, s),
True for padded elements.

Returns:
torch.Tensor: Output tensor of shape (b, s, dim_model)
"""
residual = query
x = self.attention(query, key, value)[0]
x = self.attention(query, key, value, key_padding_mask=key_padding_mask)[0]
x = self.dropout(x)
if self.use_residual:
x = x + residual
Expand Down Expand Up @@ -250,16 +303,17 @@ def __init__(self, dim_model: int, num_heads: int, dim_ff: int,
self.norm1 = nn.LayerNorm(dim_model)
self.norm2 = nn.LayerNorm(dim_model)

def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
    """EncoderLayer forward pass.

    Applies pre-normalized self-attention followed by the pre-normalized
    feed-forward block.

    Args:
        x (torch.Tensor): Input tensor of shape (b, s, dim_model)
        mask (torch.Tensor): Optional mask of shape (b, s), True for padded elements.

    Returns:
        torch.Tensor: Output tensor of shape (b, s, dim_model)
    """
    # Query, key and value are the same normalized input, so normalize once
    # and reuse the result instead of running norm1 three times.
    normed = self.norm1(x)
    x = self.attention(normed, normed, normed, key_padding_mask=mask)
    x = self.ff(self.norm2(x))
    return x

Expand Down Expand Up @@ -433,3 +487,120 @@ def initialize_weights(self):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)


class SpecTfEncoderV2(nn.Module):
    """Encoder based Spectral Transformer model (Dynamic Banddef).

    This is a version of the Spectral Transformer architecture that allows the
    wavelength grid (banddef) to be passed as an input during the forward pass,
    enabling each input spectrum to have a different wavelength band.

    Model weights are initialized using Xavier initialization and model biases
    are initialized to zero with self.initialize_weights().

    Attributes:
        band_concat: BandConcatV2 module
        spectral_embed: SpectralEmbed module
        layers: List of EncoderLayer modules
        agg: Name of the aggregation method ('mean', 'max')
        head: Linear layer for classification or regression
    """
    def __init__(self,
                 dim_output: int = 2,
                 num_heads: int = 8,
                 dim_proj: int = 64,
                 dim_ff: int = 64,
                 dropout: float = 0.1,
                 agg: str = 'max',
                 use_residual: bool = False,
                 num_layers: int = 1):
        """Initialize SpecTfEncoderV2 module.

        Args:
            dim_output (int): Output dimension of the model. Default 2.
            num_heads (int): Number of attention heads. Must be a divisor of
                             dim_proj. Default 8.
            dim_proj (int): Dimension of the projected tensors. Default 64.
            dim_ff (int): Dimension of the intermediate tensors. Default 64.
            dropout (float): Dropout rate. Default 0.1.
            agg (str): Aggregation method ('mean', 'max').
                       Default 'max'.
            use_residual (bool): Whether to use residual connections.
                                 Default False.
            num_layers (int): Number of encoder layers. Default 1.

        Raises:
            ValueError: If ``agg`` is not a supported aggregation method.
        """
        super().__init__()

        # Fail at construction time rather than on the first forward pass.
        if agg not in ('mean', 'max'):
            raise ValueError(f"Aggregation method {agg} is not implemented.")

        # Embedding
        self.band_concat = BandConcatV2()
        self.spectral_embed = SpectralEmbed(n_filters=dim_proj)

        # Attention
        self.layers = nn.ModuleList([
            EncoderLayer(dim_proj, num_heads, dim_ff, dropout, use_residual)
            for _ in range(num_layers)
        ])

        # Head
        self.agg = agg
        self.head = nn.Linear(dim_proj, dim_output)

        self.initialize_weights()

    def aggregate(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Performs the selected aggregation method. Needs to be broken out here for PyTorch's JiT

        Args:
            x (torch.Tensor): Tensor of shape (b, s, d) to reduce over dim 1.
            mask (torch.Tensor): Optional mask of shape (b, s), True (non-zero)
                for padded elements, which are excluded from the reduction.

        Returns:
            torch.Tensor: Aggregated tensor of shape (b, d).

        Raises:
            ValueError: If ``self.agg`` is not 'mean' or 'max'.
        """
        if self.agg == 'mean':
            if mask is not None:
                # mask: (b, s), True for padded.
                padded = mask.unsqueeze(-1).to(torch.bool)
                # masked_fill, unlike multiplying by a 0/1 mask, also drops
                # NaN/inf values sitting in padded slots (NaN * 0 is NaN).
                sum_x = torch.sum(x.masked_fill(padded, 0.0), dim=1)
                count = torch.sum((~padded).to(x.dtype), dim=1)

                # count is a whole number, so clamping to 1 only affects
                # fully masked sequences (yielding 0) and, unlike a tiny
                # epsilon, cannot underflow in reduced precision.
                return sum_x / count.clamp(min=1.0)

            return torch.mean(x, dim=1)

        elif self.agg == 'max':
            if mask is not None:
                # mask: (b, s), True for padded.
                mask_expanded = mask.unsqueeze(-1).to(torch.bool)
                x_masked = x.masked_fill(mask_expanded, float('-inf'))
                return torch.max(x_masked, dim=1)[0]

            return torch.max(x, dim=1)[0]

        else:
            raise ValueError(f"Aggregation method {self.agg} is not implemented.")

    def forward(self, x: torch.Tensor, banddef: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """SpecTfEncoderV2 forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (b, s, 1)
            banddef (torch.Tensor): Band center wavelengths of shape (b, s) or (s,)
            mask (torch.Tensor): Optional mask of shape (b, s), True for padded elements.

        Returns:
            torch.Tensor: Output tensor of shape (b, dim_output)
        """
        if mask is not None:
            # Use one boolean mask everywhere: attention layers treat a
            # floating-point padding mask as additive, whereas aggregation
            # treats any non-zero entry as padded.
            mask = mask.to(torch.bool)

        x = self.band_concat(x, banddef)
        x = self.spectral_embed(x)

        for layer in self.layers:
            x = layer(x, mask=mask)

        x = self.aggregate(x, mask=mask)
        x = self.head(x)

        return x

    def initialize_weights(self):
        """Initialize weights for the model.

        Linear and Conv1d weights get Xavier-uniform initialization and their
        biases, when present, are set to zero.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
Loading