From d2a55c0fbf950d4b5937426642e85bbe6dee44e2 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Fri, 22 May 2026 21:35:21 -0700 Subject: [PATCH 1/4] V2 Encoder moves banddef to input. Draft impl. --- spectf/model.py | 150 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 149 insertions(+), 1 deletion(-) diff --git a/spectf/model.py b/spectf/model.py index e41a118..cc32869 100644 --- a/spectf/model.py +++ b/spectf/model.py @@ -56,6 +56,55 @@ def forward(self, spectra: torch.Tensor): return encoded +class BandConcatV2(nn.Module): + """Module to concatenate band wavelength information to spectra (Dynamic). + + This serves as the positional encoding for the transformer, and replaces the + traditional additive sinusoidal encoding. Band wavelengths are passed during + the forward pass and normalized to a fixed mean and standard deviation. + Default mean (1440) and std (600) are set based on the EMIT spectral range. + + Attributes: + mean (int): Predefined mean of the band center wavelengths. + std (int): Predefined stddev of the band center wavelengths. + """ + + def __init__(self, mean: int = 1440, std: int = 600): + """Initialize BandConcatV2 module. + + Args: + mean (int): Mean of the band center wavelengths. Default 1440. + std (int): Stddev of the band center wavelengths. Default 600. + """ + super().__init__() + self.mean = mean + self.std = std + + def forward(self, spectra: torch.Tensor, banddef: torch.Tensor): + """BandConcatV2 forward pass. + + Args: + spectra (torch.Tensor): tensor of shape (b, s, 1) + banddef (torch.Tensor): tensor of shape (b, s) or (s,) + + Returns: + torch.Tensor: concatenated tensor of shape (b, s, 2) + """ + # Ensure banddef is (b, s, 1) + if banddef.dim() == 1: + # (s,) -> (1, s, 1) + banddef = banddef.unsqueeze(-1).unsqueeze(0) + elif banddef.dim() == 2: + # (b, s) -> (b, s, 1) + banddef = banddef.unsqueeze(-1) + + # Normalize band wavelengths + banddef_norm = (banddef - self.mean) / self.std + + encoded = torch.cat((spectra, banddef_norm.expand_as(spectra)), dim=-1) + return encoded + + class SpectralEmbed(nn.Module): """Module to embed spectra per-band using a linear layer. @@ -177,7 +226,8 @@ def __init__(self, dim_model: int, num_heads: int, dropout: float = 0.1, Args: dim_model (int): Dimension of the input and output tensors. - num_heads (int): Number of attention heads. + num_heads (int): Number of attention heads. Must be a divisor of + dim_model. dropout (float): Dropout rate. Default 0.1. use_residual (bool): Whether to use residual connections. Default False. @@ -433,3 +483,101 @@ def initialize_weights(self): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + + +class SpecTfEncoderV2(nn.Module): + """Encoder based Spectral Transformer model (Dynamic Banddef). + + This is a version of the Spectral Transformer architecture that allows the + wavelength grid (banddef) to be passed as an input during the forward pass, + enabling each input spectrum to have a different wavelength band. + + Model weights are initialized using Xavier initialization and model biases + are initialized to zero with self.initialize_weights(). + + Attributes: + band_concat: BandConcatV2 module + spectral_embed: SpectralEmbed module + layers: List of EncoderLayer modules + aggregate: Aggregation method ('mean', 'max') + head: Linear layer for classification or regression + """ + def __init__(self, + dim_output: int = 2, + num_heads: int = 8, + dim_proj: int = 64, + dim_ff: int = 64, + dropout: float = 0.1, + agg: str = 'max', + use_residual: bool = False, + num_layers: int = 1): + """Initialize SpecTfEncoderV2 module. + + Args: + dim_output (int): Output dimension of the model. Default 2. + num_heads (int): Number of attention heads. Must be a divisor of + dim_proj. Default 8. + dim_proj (int): Dimension of the projected tensors. Default 64. + dim_ff (int): Dimension of the intermediate tensors. Default 64. + dropout (float): Dropout rate. Default 0.1. + agg (str): Aggregation method ('mean', 'max'). + Default 'max'. + use_residual (bool): Whether to use residual connections. + Default False. + num_layers (int): Number of encoder layers. Default 1. + """ + super().__init__() + + # Embedding + self.band_concat = BandConcatV2() + self.spectral_embed = SpectralEmbed(n_filters=dim_proj) + + # Attention + self.layers = nn.ModuleList([ + EncoderLayer(dim_proj, num_heads, dim_ff, dropout, use_residual) + for _ in range(num_layers) + ]) + + # Head + self.agg = agg + self.head = nn.Linear(dim_proj, dim_output) + + self.initialize_weights() + + def aggregate(self, x): + """Performs the selected aggregation method. Needs to be broken out here for PyTorch's JiT""" + if self.agg == 'mean': + return torch.mean(x, dim=1) + elif self.agg == 'max': + return torch.max(x, dim=1)[0] + else: + raise ValueError(f'Aggregation method {self.agg} is not implemented.') + + def forward(self, x: torch.Tensor, banddef: torch.Tensor): + """SpecTfEncoderV2 forward pass. + + Args: + x (torch.Tensor): Input tensor of shape (b, s, 1) + banddef (torch.Tensor): Band center wavelengths of shape (b, s) or (s,) + + Returns: + torch.Tensor: Output tensor of shape (b, num_classes) + """ + x = self.band_concat(x, banddef) + x = self.spectral_embed(x) + + for layer in self.layers: + x = layer(x) + + x = self.aggregate(x) + x = self.head(x) + + return x + + def initialize_weights(self): + """Initialize weights for the model.""" + for module in self.modules(): + if isinstance(module, (nn.Linear, nn.Conv1d)): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) From c380fbb073443b5e18fecc059bac989059263fe3 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Fri, 22 May 2026 21:43:03 -0700 Subject: [PATCH 2/4] Key padding mask implementation for var length inputs --- spectf/model.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/spectf/model.py b/spectf/model.py index cc32869..9144998 100644 --- a/spectf/model.py +++ b/spectf/model.py @@ -243,19 +243,21 @@ def __init__(self, dim_model: int, num_heads: int, dropout: float = 0.1, self.use_residual = use_residual def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor): + value: torch.Tensor, key_padding_mask: torch.Tensor = None): """AttentionBlock forward pass. Args: query (torch.Tensor): Query tensor of shape (b, s, dim_model) key (torch.Tensor): Key tensor of shape (b, s, dim_model) value (torch.Tensor): Value tensor of shape (b, s, dim_model) + key_padding_mask (torch.Tensor): Optional mask of shape (b, s), + True for padded elements. Returns: torch.Tensor: Output tensor of shape (b, s, dim_model) """ residual = query - x = self.attention(query, key, value)[0] + x = self.attention(query, key, value, key_padding_mask=key_padding_mask)[0] x = self.dropout(x) if self.use_residual: x = x + residual @@ -300,16 +302,17 @@ def __init__(self, dim_model: int, num_heads: int, dim_ff: int, self.norm1 = nn.LayerNorm(dim_model) self.norm2 = nn.LayerNorm(dim_model) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, mask: torch.Tensor = None): """EncoderLayer forward pass. Args: x (torch.Tensor): Input tensor of shape (b, s, dim_model) + mask (torch.Tensor): Optional mask of shape (b, s), True for padded elements. Returns: torch.Tensor: Output tensor of shape (b, s, dim_model) """ - x = self.attention(self.norm1(x), self.norm1(x), self.norm1(x)) + x = self.attention(self.norm1(x), self.norm1(x), self.norm1(x), key_padding_mask=mask) x = self.ff(self.norm2(x)) return x @@ -544,21 +547,33 @@ def __init__(self, self.initialize_weights() - def aggregate(self, x): + def aggregate(self, x, mask: torch.Tensor = None): """Performs the selected aggregation method. Needs to be broken out here for PyTorch's JiT""" if self.agg == 'mean': + if mask is not None: + # mask: (b, s), True for padded. Create valid mask (True for valid). + valid_mask = ~mask.unsqueeze(-1).bool() + sum_x = torch.sum(x * valid_mask.float(), dim=1) + count = torch.sum(valid_mask.float(), dim=1) + return sum_x / count return torch.mean(x, dim=1) elif self.agg == 'max': + if mask is not None: + # mask: (b, s), True for padded. + mask_expanded = mask.unsqueeze(-1).bool() + x_masked = x.masked_fill(mask_expanded, float('-inf')) + return torch.max(x_masked, dim=1)[0] return torch.max(x, dim=1)[0] else: raise ValueError(f'Aggregation method {self.agg} is not implemented.') - def forward(self, x: torch.Tensor, banddef: torch.Tensor): + def forward(self, x: torch.Tensor, banddef: torch.Tensor, mask: torch.Tensor = None): """SpecTfEncoderV2 forward pass. Args: x (torch.Tensor): Input tensor of shape (b, s, 1) banddef (torch.Tensor): Band center wavelengths of shape (b, s) or (s,) + mask (torch.Tensor): Optional mask of shape (b, s), True for padded elements. Returns: torch.Tensor: Output tensor of shape (b, num_classes) @@ -567,9 +582,9 @@ def forward(self, x: torch.Tensor, banddef: torch.Tensor): x = self.spectral_embed(x) for layer in self.layers: - x = layer(x) + x = layer(x, mask=mask) - x = self.aggregate(x) + x = self.aggregate(x, mask=mask) x = self.head(x) return x From 5d715de81fa349e1772601f2801399dc0ddb9f77 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Fri, 29 May 2026 22:38:19 -0700 Subject: [PATCH 3/4] Add new dataset version for AV3 --- spectf/dataset.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/spectf/dataset.py b/spectf/dataset.py index e1d27b1..451b89e 100644 --- a/spectf/dataset.py +++ b/spectf/dataset.py @@ -136,3 +136,51 @@ def __getitem__(self, idx): 'spectra': torch.unsqueeze(out_spec, -1), 'label': self.labels[idx] } + + +class SpectraDatasetV2(Dataset): + """A PyTorch dataset class for access of ML-ready HDF5 spectral data (Dynamic Banddef). + + Attributes: + spectra (ndarray): The spectral data. + labels (ndarray): The corresponding labels for the spectral data. + wavelengths (ndarray): The corresponding wavelength band definitions. + transform (callable): Transformations or normalizations + for each spectral data point. + device (str): The device to load the data onto (e.g., 'cpu', 'cuda:0'). + """ + + def __init__(self, spectra: np.ndarray, labels: np.ndarray, wavelengths: np.ndarray, + transform: bool = None, device: str = 'cpu'): + """ Initialize the SpectraDatasetV2 object. + + Args: + spectra (np.ndarray): The spectral data. + labels (np.ndarray): The corresponding labels for the spectral data. + wavelengths (np.ndarray): The corresponding wavelength band definitions. + transform (callable): Optional transform to be applied to + each spectral data point. Default None. + device (str): The device to load the data onto. Default 'cpu'. + """ + super().__init__() + self.spectra = torch.tensor(spectra, dtype=torch.float32).to(device) + self.labels = torch.tensor(labels).to(device) + self.labels[self.labels==2] = 0 # shadow considered clear + self.wavelengths = torch.tensor(wavelengths, dtype=torch.float32).to(device) + + self.transform = transform + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + + out_spec = self.spectra[idx] + if self.transform is not None: + out_spec = self.transform(out_spec) + + return { + 'spectra': torch.unsqueeze(out_spec, -1), + 'label': self.labels[idx], + 'banddef': self.wavelengths[idx] + } From 66975085dcba72154195658aacb163db9837a2d9 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Mon, 1 Jun 2026 02:22:29 -0700 Subject: [PATCH 4/4] v2 jit fix --- spectf/model.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/spectf/model.py b/spectf/model.py index 9144998..704e91d 100644 --- a/spectf/model.py +++ b/spectf/model.py @@ -8,6 +8,7 @@ import torch from torch import nn +from typing import Optional class BandConcat(nn.Module): @@ -243,7 +244,7 @@ def __init__(self, dim_model: int, num_heads: int, dropout: float = 0.1, self.use_residual = use_residual def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, key_padding_mask: torch.Tensor = None): + value: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None): """AttentionBlock forward pass. Args: @@ -302,7 +303,7 @@ def __init__(self, dim_model: int, num_heads: int, dim_ff: int, self.norm1 = nn.LayerNorm(dim_model) self.norm2 = nn.LayerNorm(dim_model) - def forward(self, x: torch.Tensor, mask: torch.Tensor = None): + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): """EncoderLayer forward pass. Args: @@ -547,27 +548,34 @@ def __init__(self, self.initialize_weights() - def aggregate(self, x, mask: torch.Tensor = None): + def aggregate(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Performs the selected aggregation method. Needs to be broken out here for PyTorch's JiT""" if self.agg == 'mean': if mask is not None: # mask: (b, s), True for padded. Create valid mask (True for valid). - valid_mask = ~mask.unsqueeze(-1).bool() - sum_x = torch.sum(x * valid_mask.float(), dim=1) - count = torch.sum(valid_mask.float(), dim=1) - return sum_x / count + valid_mask = ~mask.unsqueeze(-1).to(torch.bool) + valid_mask_f = valid_mask.to(x.dtype) + sum_x = torch.sum(x * valid_mask_f, dim=1) + count = torch.sum(valid_mask_f, dim=1) + + # Use clamp to prevent division by zero if an entire sequence is masked + return sum_x / count.clamp(min=1e-9) + return torch.mean(x, dim=1) + elif self.agg == 'max': if mask is not None: # mask: (b, s), True for padded. - mask_expanded = mask.unsqueeze(-1).bool() + mask_expanded = mask.unsqueeze(-1).to(torch.bool) x_masked = x.masked_fill(mask_expanded, float('-inf')) return torch.max(x_masked, dim=1)[0] + return torch.max(x, dim=1)[0] + else: - raise ValueError(f'Aggregation method {self.agg} is not implemented.') + raise ValueError(f"Aggregation method {self.agg} is not implemented.") - def forward(self, x: torch.Tensor, banddef: torch.Tensor, mask: torch.Tensor = None): + def forward(self, x: torch.Tensor, banddef: torch.Tensor, mask: Optional[torch.Tensor] = None): """SpecTfEncoderV2 forward pass. Args: