From 935cb48e107a79dbf03803eabb88fec2782b1464 Mon Sep 17 00:00:00 2001
From: Castorp <50649074+ShinDongWoon@users.noreply.github.com>
Date: Thu, 25 Sep 2025 16:34:12 +0900
Subject: [PATCH] Optimize per-series dataset slicing and log train workload

---
 configs/default.yaml                      |   2 +-
 src/timesnet_forecast/data/dataset.py     |  33 +++-
 src/timesnet_forecast/losses.py           |  48 +++--
 src/timesnet_forecast/models/timesnet.py  |  46 +++--
 src/timesnet_forecast/train.py            | 227 ++++++++++++++++-------
 tests/test_dataset_pmax.py                |  56 +++++-
 tests/test_fft_period_selector.py         |  16 ++
 tests/test_negative_binomial_nll.py       |  49 +++--
 tests/test_timesnet_forward.py            |  23 +++
 9 files changed, 373 insertions(+), 127 deletions(-)

diff --git a/configs/default.yaml b/configs/default.yaml
index 55c6a44..bc47161 100644
--- a/configs/default.yaml
+++ b/configs/default.yaml
@@ -68,7 +68,7 @@ model:
   d_ff: 1024
   n_layers: 3
   dropout: 0.1
-  k_periods: 3
+  k_periods: 2
   min_period_threshold: 7  # minimum period lower bound
   kernel_set:
     - [3, 3]
diff --git a/src/timesnet_forecast/data/dataset.py b/src/timesnet_forecast/data/dataset.py
index 8a6eede..6d64831 100644
--- a/src/timesnet_forecast/data/dataset.py
+++ b/src/timesnet_forecast/data/dataset.py
@@ -61,6 +61,8 @@ def __init__(
         else:
             self.M = valid_mask.astype(np.float32)
         self.T, self.N = self.X.shape
+        if self.N <= 0:
+            raise ValueError("wide_values must contain at least one series column")
         self.L = int(input_len)
         if mode == "direct":
             self.H = int(pred_len)
@@ -157,32 +159,45 @@ def __init__(
         else:
             self.series_ids = None
 
+        self._windows_per_series = int(len(self.idxs))
+
     def __len__(self) -> int:
-        return int(len(self.idxs))
+        return int(self._windows_per_series * self.N)
 
     def __getitem__(self, idx: int) -> tuple[object, ...]:
-        s = int(self.idxs[idx])
+        if self._windows_per_series <= 0:
+            raise IndexError("SlidingWindowDataset is empty")
+        window_idx = int(idx // self.N)
+        series_idx = int(idx % self.N)
+        if window_idx >= self._windows_per_series:
+            raise IndexError("index out of range for sliding windows")
+        s = int(self.idxs[window_idx])
         if self.time_shift > 0:
             delta = np.random.randint(-self.time_shift, self.time_shift + 1)
             s = int(np.clip(s + delta, 0, self.T - self.L - self.H))
         e = s + self.L
-        x_tensor = self._X_tensor[s:e, :].clone()
+        x_slice = self._X_tensor[s:e, series_idx]
         if self.add_noise_std > 0:
+            x_tensor = x_slice.clone().unsqueeze(-1)
             noise = torch.randn_like(x_tensor) * self.add_noise_std
            x_tensor = x_tensor + noise
-        y_tensor = self._X_tensor[e : e + self.H, :].clone()
-        mask_tensor = self._M_tensor[e : e + self.H, :].clone()
+        else:
+            x_tensor = x_slice.unsqueeze(-1)
+        y_tensor = self._X_tensor[e : e + self.H, series_idx].unsqueeze(-1)
+        mask_tensor = self._M_tensor[e : e + self.H, series_idx].unsqueeze(-1)
         if self.time_marks is not None:
-            x_mark = self.time_marks[s:e, :].clone()
-            y_mark = self.time_marks[e : e + self.H, :].clone()
+            x_mark = self.time_marks[s:e, :]
+            y_mark = self.time_marks[e : e + self.H, :]
         else:
             x_mark = self._empty_time_mark
             y_mark = self._empty_time_mark
         items: list[object] = [x_tensor, y_tensor, mask_tensor, x_mark, y_mark]
         if self.series_static is not None:
-            items.append(self.series_static)
+            static_slice = self.series_static[series_idx : series_idx + 1, :]
+            items.append(static_slice)
         if self.series_ids is not None:
-            items.append(self.series_ids)
+            id_slice = self.series_ids[series_idx : series_idx + 1]
+            items.append(id_slice)
         return tuple(items)
 
     @staticmethod
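With this change each dataset sample is a single (window, series) pair rather than an [L, N] slab, so __len__ grows to windows_per_series * N and __getitem__ has to decompose the flat index. A minimal standalone sketch of that arithmetic (the counts below are made-up example values, not taken from the repository):

    # consecutive indices walk across series first, then advance the window,
    # mirroring window_idx = idx // N and series_idx = idx % N above
    windows_per_series, n_series = 5, 3

    def decompose(idx: int) -> tuple[int, int]:
        return idx // n_series, idx % n_series

    assert decompose(0) == (0, 0)   # first window, first series
    assert decompose(2) == (0, 2)   # first window, last series
    assert decompose(3) == (1, 0)   # next index wraps back to series 0, window 1
    assert windows_per_series * n_series == 15   # what the new __len__ reports

One practical consequence is that a DataLoader batch of size B now stacks B single-series windows of shape [L, 1], which is what the updated tests later in this patch assert.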
diff --git a/src/timesnet_forecast/losses.py b/src/timesnet_forecast/losses.py
index c8b59fc..3011f60 100644
--- a/src/timesnet_forecast/losses.py
+++ b/src/timesnet_forecast/losses.py
@@ -3,6 +3,23 @@
 
 import torch
 
+def negative_binomial_mask(
+    y: torch.Tensor,
+    rate: torch.Tensor,
+    dispersion: torch.Tensor,
+    mask: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Compute a boolean mask for valid NB likelihood elements."""
+
+    finite_mask = torch.isfinite(y) & torch.isfinite(rate) & torch.isfinite(dispersion)
+    if mask is not None:
+        mask_bool = mask.to(dtype=torch.bool)
+        if mask_bool.shape != finite_mask.shape:
+            mask_bool = mask_bool.expand_as(finite_mask)
+        finite_mask = finite_mask & mask_bool
+    return finite_mask
+
+
 def negative_binomial_nll(
     y: torch.Tensor,
     rate: torch.Tensor,
@@ -13,26 +30,25 @@ def negative_binomial_nll(
     """Negative binomial negative log-likelihood averaged over valid elements."""
 
     dtype = torch.float32
-    y = y.to(dtype)
+    y = torch.clamp(y.to(dtype), min=0.0)
     rate = rate.to(dtype)
     dispersion = dispersion.to(dtype)
 
     alpha = torch.clamp(dispersion, min=eps)
     mu = torch.clamp(rate, min=eps)
-    r = 1.0 / alpha
-    log_p = torch.log(r) - torch.log(r + mu)
-    log1m_p = torch.log(mu) - torch.log(r + mu)
-    log_prob = (
-        torch.lgamma(y + r)
-        - torch.lgamma(r)
+    log1p_alpha_mu = torch.log1p(alpha * mu)
+    log_alpha = torch.log(alpha)
+    log_mu = torch.log(mu)
+    inv_alpha = torch.reciprocal(alpha)
+    ll = (
+        torch.lgamma(y + inv_alpha)
+        - torch.lgamma(inv_alpha)
         - torch.lgamma(y + 1.0)
-        + r * log_p
-        + y * log1m_p
+        + inv_alpha * (-log1p_alpha_mu)
+        + y * (log_alpha + log_mu - log1p_alpha_mu)
     )
-    if mask is not None:
-        mask = mask.to(dtype)
-        log_prob = log_prob * mask
-        denom = torch.clamp(mask.sum(), min=1.0)
-    else:
-        denom = log_prob.numel()
-    return -(log_prob.sum() / denom)
+
+    valid_mask = negative_binomial_mask(y, mu, alpha, mask)
+    weight = valid_mask.to(dtype)
+    denom = torch.clamp(weight.sum(), min=1.0)
+    return -(ll * weight).sum() / denom
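For reference, the rewritten likelihood uses the mean/dispersion parametrization (mu = rate, alpha = dispersion, so r = 1/alpha and p = 1/(1 + alpha*mu)); the per-element log-likelihood accumulated in ll above works out to

    log p(y | mu, alpha) = lgamma(y + 1/alpha) - lgamma(1/alpha) - lgamma(y + 1)
                           - (1/alpha) * log(1 + alpha*mu)
                           + y * (log(alpha) + log(mu) - log(1 + alpha*mu))

which is algebraically identical to the previous r/p form, since r*log(r/(r + mu)) = -(1/alpha)*log1p(alpha*mu) and y*log(mu/(r + mu)) = y*(log(alpha) + log(mu) - log1p(alpha*mu)), but the shared term is routed through log1p for better numerical behaviour when alpha*mu is small. The returned loss is the negative mean of this expression over the elements kept by negative_binomial_mask.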
diff --git a/src/timesnet_forecast/models/timesnet.py b/src/timesnet_forecast/models/timesnet.py
index 4c296bb..eef8592 100644
--- a/src/timesnet_forecast/models/timesnet.py
+++ b/src/timesnet_forecast/models/timesnet.py
@@ -107,9 +107,9 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
             empty_amp = torch.zeros(B, 0, dtype=amp_samples.dtype, device=device)
             return empty_idx, empty_amp.to(dtype)
 
-        freq_indices = torch.arange(amp_mean.numel(), device=device, dtype=dtype)
-        tie_break = freq_indices * torch.finfo(dtype).eps
-        scores = amp_mean - tie_break
+        freq_indices = torch.arange(amp_mean.numel(), device=device, dtype=torch.long)
+        log_indices = torch.log1p(freq_indices.to(torch.float32))
+        scores = amp_mean - 1e-8 * log_indices.to(dtype)
         _, indices = torch.topk(scores, k=k, largest=True)
         safe_indices = indices.to(device=device, dtype=torch.long).clamp_min(1)
         sample_values = amp_samples.gather(
@@ -117,12 +117,24 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         )
 
         L_t = torch.tensor(L, dtype=torch.long, device=device)
+        upper_bound = min(self.pmax, max(L - 1, self.min_period_threshold))
+        if upper_bound < self.min_period_threshold:
+            empty_idx = torch.zeros(0, dtype=torch.long, device=device)
+            empty_amp = torch.zeros(B, 0, dtype=amp_samples.dtype, device=device)
+            return empty_idx, empty_amp.to(dtype)
+
         periods = (L_t + safe_indices - 1) // safe_indices
-        periods = torch.clamp(
-            periods,
-            min=self.min_period_threshold,
-            max=self.pmax,
-        )
+        periods = torch.clamp(periods, min=self.min_period_threshold, max=upper_bound)
+
+        cycles = (L_t + periods - 1) // periods
+        valid_mask = cycles >= 2
+        if not torch.any(valid_mask):
+            empty_idx = torch.zeros(0, dtype=torch.long, device=device)
+            empty_amp = torch.zeros(B, 0, dtype=amp_samples.dtype, device=device)
+            return empty_idx, empty_amp.to(dtype)
+
+        periods = periods[valid_mask]
+        sample_values = sample_values[:, valid_mask]
 
         return periods, sample_values.to(dtype)
 
@@ -280,6 +292,7 @@ def __init__(
         # ``period_selector`` is injected from ``TimesNet`` after instantiation to
         # avoid registering the shared selector multiple times.
         self.period_selector: FFTPeriodSelector | None = None
+        self._period_calls: int = 0
 
     def _build_layers(self, channels: int, device: torch.device, dtype: torch.dtype) -> None:
         if channels <= 0:
@@ -322,6 +335,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         if self.period_selector is None:
             raise RuntimeError("TimesBlock.period_selector has not been set")
+        self._period_calls = getattr(self, "_period_calls", 0) + 1
         if self.inception is None:
             if self._configured_d_model is not None and x.size(-1) != self._configured_d_model:
                 raise ValueError(
@@ -1019,14 +1033,14 @@ def _ensure_embedding(
             )
             self.pre_embedding_dropout = self.pre_embedding_dropout.to(device=x.device)
 
-        if (
-            isinstance(self.min_sigma_vector, torch.Tensor)
-            and self.min_sigma_vector.numel() > 0
-            and self.min_sigma_vector.shape[-1] != c_in
-        ):
-            raise ValueError(
-                "min_sigma_vector length does not match number of series"
-            )
+        if isinstance(self.min_sigma_vector, torch.Tensor) and self.min_sigma_vector.numel() > 0:
+            current = int(self.min_sigma_vector.shape[-1])
+            if current < c_in:
+                raise ValueError(
+                    "min_sigma_vector length does not match number of series"
+                )
+            if current != c_in:
+                self.min_sigma_vector = self.min_sigma_vector[..., :c_in]
 
         if self.embedding_time_features is not None and self.embedding_time_features != time_dim:
             raise ValueError("Temporal feature dimension changed between calls")
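The selector now derives an upper bound for candidate periods from the window length and drops any period that cannot complete at least two cycles inside the input window. A small illustrative sketch of that filtering with made-up numbers (the window length, thresholds, and frequency bins below are examples, not values from the config):

    L = 28                                    # input window length
    min_period_threshold, pmax = 7, 24
    upper = min(pmax, max(L - 1, min_period_threshold))   # 24
    for k in [1, 2, 4, 13]:                   # hypothetical top-k FFT bin indices (bin 0 excluded)
        period = min(max((L + k - 1) // k, min_period_threshold), upper)   # ceil(L / k), clamped
        cycles = (L + period - 1) // period                                # ceil(L / period)
        print(k, period, cycles, cycles >= 2)
    # k=1 -> period 24, 2 cycles; k=2 -> 14, 2; k=4 -> 7, 4; k=13 -> 7, 4  (all kept)

When the window is shorter than the minimum period (for example L = 6 with min_period_threshold = 7), the clamped period exceeds L, cycles drops to 1, and the selector returns empty index/amplitude tensors instead of reshaping with a single cycle.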
diff --git a/src/timesnet_forecast/train.py b/src/timesnet_forecast/train.py
index c298c71..9925fbb 100644
--- a/src/timesnet_forecast/train.py
+++ b/src/timesnet_forecast/train.py
@@ -15,7 +15,7 @@
 from .config import Config, save_yaml
 from .utils.logging import console, print_config
 from .utils.seed import seed_everything
-from .losses import negative_binomial_nll
+from .losses import negative_binomial_mask, negative_binomial_nll
 from .utils.torch_opt import (
     amp_autocast,
     maybe_compile,
@@ -192,6 +192,30 @@ def _normalize_optional(value):
     return value
 
 
+def _stack_series_columns(
+    per_id_values: Dict[int, List[np.ndarray]], n_ids: int
+) -> np.ndarray:
+    if n_ids <= 0:
+        return np.zeros((0, 0), dtype=np.float32)
+    columns: List[np.ndarray] = []
+    expected_len: int | None = None
+    for sid in range(n_ids):
+        series_list = per_id_values.get(sid, [])
+        if series_list:
+            flat_values = [np.asarray(v, dtype=np.float32).reshape(-1) for v in series_list]
+            col = np.concatenate(flat_values, axis=0)
+        else:
+            col = np.zeros(0, dtype=np.float32)
+        if expected_len is None:
+            expected_len = int(col.shape[0])
+        elif int(col.shape[0]) != expected_len:
+            raise ValueError("Mismatched series lengths detected during evaluation")
+        columns.append(col.reshape(-1, 1))
+    if expected_len is None:
+        return np.zeros((0, n_ids), dtype=np.float32)
+    return np.concatenate(columns, axis=1)
+
+
 def _unpack_batch(
     batch,
 ) -> tuple[
@@ -442,8 +466,8 @@
     use_loss_mask: bool = False,
 ) -> float:
     model.eval()
-    ys: List[np.ndarray] = []
-    preds: List[np.ndarray] = []
+    per_id_targets: Dict[int, List[np.ndarray]] = {i: [] for i in range(len(ids))}
+    per_id_preds: Dict[int, List[np.ndarray]] = {i: [] for i in range(len(ids))}
     default_series_ids = torch.arange(len(ids), dtype=torch.long, device=device)
     with torch.inference_mode(), amp_autocast(True if device.type == "cuda" else False):
         for batch in loader:
@@ -451,7 +475,10 @@
             xb = move_to_device(xb, device)  # [B, L, N]
             yb = move_to_device(yb, device)  # [B, H_or_1, N]
             mask_dev = move_to_device(mask, device)
-            loss_mask = mask_dev.to(yb.dtype) if use_loss_mask else None
+            if use_loss_mask:
+                base_mask = mask_dev > 0.0
+            else:
+                base_mask = torch.ones_like(yb, dtype=torch.bool, device=device)
             if x_mark is not None:
                 x_mark = x_mark.to(device=device, non_blocking=True)
             if y_mark is not None:
@@ -481,16 +508,29 @@
                 series_ids=series_idx,
             )
             rate = rate[:, : yb.shape[1], :]
-            if loss_mask is not None:
-                yb_eval = yb * loss_mask
-                rate_eval = rate * loss_mask
-            else:
-                yb_eval = yb
-                rate_eval = rate
-            ys.append(yb_eval.detach().float().cpu().numpy())
-            preds.append(rate_eval.detach().float().cpu().numpy())
-    Y = np.concatenate(ys, axis=0).reshape(-1, len(ids))
-    P = np.concatenate(preds, axis=0).reshape(-1, len(ids))
+            nb_mask = negative_binomial_mask(
+                yb,
+                rate,
+                torch.ones_like(rate, dtype=rate.dtype, device=rate.device),
+                base_mask,
+            )
+            mask_for_eval = nb_mask.to(yb.dtype)
+            yb_eval = yb * mask_for_eval
+            rate_eval = rate * mask_for_eval
+
+            series_idx = series_idx if series_idx is not None else default_series_ids
+            if series_idx.dim() == 1:
+                series_idx = series_idx.unsqueeze(0).expand(yb_eval.size(0), -1)
+            series_idx_cpu = series_idx.detach().cpu()
+            y_cpu = yb_eval.detach().float().cpu()
+            rate_cpu = rate_eval.detach().float().cpu()
+            for b in range(y_cpu.size(0)):
+                for n in range(y_cpu.size(2)):
+                    sid = int(series_idx_cpu[b, n].item())
+                    per_id_targets.setdefault(sid, []).append(y_cpu[b, :, n].numpy().reshape(-1))
+                    per_id_preds.setdefault(sid, []).append(rate_cpu[b, :, n].numpy().reshape(-1))
+    Y = _stack_series_columns(per_id_targets, len(ids))
+    P = _stack_series_columns(per_id_preds, len(ids))
     return wsmape_grouped(Y, P, ids=ids, weights=None)
 
 
@@ -506,8 +546,8 @@
     min_sigma: float | torch.Tensor = 0.0,
 ) -> Dict[str, float]:
     model.eval()
-    ys: List[np.ndarray] = []
-    preds: List[np.ndarray] = []
+    per_id_targets: Dict[int, List[np.ndarray]] = {i: [] for i in range(len(ids))}
+    per_id_preds: Dict[int, List[np.ndarray]] = {i: [] for i in range(len(ids))}
     nll_num = 0.0
     nll_den = 0.0
     default_series_ids = torch.arange(len(ids), dtype=torch.long, device=device)
@@ -517,7 +557,10 @@
         xb = move_to_device(xb, device)  # [B, L, N]
         yb = move_to_device(yb, device)  # [B, H_or_1, N]
         mask_dev = move_to_device(mask, device)
-        loss_mask = mask_dev.to(yb.dtype) if use_loss_mask else None
+        if use_loss_mask:
+            base_mask = mask_dev > 0.0
+        else:
+            base_mask = torch.ones_like(yb, dtype=torch.bool, device=device)
         if x_mark is not None:
             x_mark = x_mark.to(device=device, non_blocking=True)
         if y_mark is not None:
@@ -551,29 +594,34 @@
            )
         rate = rate[:, : yb.shape[1], :]
         dispersion = dispersion[:, : yb.shape[1], :]
-        if loss_mask is not None:
-            mask_for_loss = loss_mask.to(yb.dtype)
-            yb_eval = yb * mask_for_loss
-            rate_eval = rate * mask_for_loss
-        else:
-            mask_for_loss = torch.ones_like(yb, dtype=yb.dtype, device=yb.device)
-            yb_eval = yb
-            rate_eval = rate
+        nb_mask = negative_binomial_mask(yb, rate, dispersion, base_mask)
+        mask_for_loss = nb_mask.to(yb.dtype)
+        yb_eval = yb
* mask_for_loss + rate_eval = rate * mask_for_loss nb_loss = negative_binomial_nll( y=yb, rate=rate, dispersion=dispersion, - mask=mask_for_loss, + mask=nb_mask, ) mask_total = float(mask_for_loss.sum().item()) if mask_total <= 0.0: mask_total = float(yb.numel()) if yb.numel() > 0 else 1.0 nll_num += float(nb_loss.item()) * mask_total nll_den += mask_total - ys.append(yb_eval.detach().float().cpu().numpy()) - preds.append(rate_eval.detach().float().cpu().numpy()) - Y = np.concatenate(ys, axis=0).reshape(-1, len(ids)) - P = np.concatenate(preds, axis=0).reshape(-1, len(ids)) + series_idx = series_idx if series_idx is not None else default_series_ids + if series_idx.dim() == 1: + series_idx = series_idx.unsqueeze(0).expand(yb_eval.size(0), -1) + series_idx_cpu = series_idx.detach().cpu() + y_cpu = yb_eval.detach().float().cpu() + rate_cpu = rate_eval.detach().float().cpu() + for b in range(y_cpu.size(0)): + for n in range(y_cpu.size(2)): + sid = int(series_idx_cpu[b, n].item()) + per_id_targets.setdefault(sid, []).append(y_cpu[b, :, n].numpy().reshape(-1)) + per_id_preds.setdefault(sid, []).append(rate_cpu[b, :, n].numpy().reshape(-1)) + Y = _stack_series_columns(per_id_targets, len(ids)) + P = _stack_series_columns(per_id_preds, len(ids)) smape = smape_mean(Y, P) denom = nll_den if nll_den > 0 else 1.0 return {"nll": nll_num / denom, "smape": smape} @@ -733,6 +781,19 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: ) if len(dl_val.dataset) == 0: raise ValueError("Validation split has no windows; increase train.val.holdout_days or adjust model.input_len/pred_len.") + train_series_count = len(ids) + train_sample_count = len(dl_train.dataset) + train_batches_per_epoch = len(dl_train) if hasattr(dl_train, "__len__") else "?" + approx_windows_per_series = ( + train_sample_count // train_series_count if train_series_count > 0 else 0 + ) + console().print( + ( + "[cyan]Train dataset: " + f"{train_series_count} series × ~{approx_windows_per_series} windows = {train_sample_count} samples; " + f"batches/epoch={train_batches_per_epoch} at batch_size={cfg['train']['batch_size']}[/cyan]" + ) + ) time_feature_dim = _time_feature_dim_from_dataset(dl_train.dataset) dataset_freq = _time_frequency_from_dataset(dl_train.dataset) base_index = wide.index if isinstance(wide.index, pd.DatetimeIndex) else None @@ -753,6 +814,11 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: warmup_series_static = torch.from_numpy(series_static_np).to( device=device, dtype=torch.float32 ) + warmup_series_static_single: torch.Tensor | None + if warmup_series_static.numel() > 0: + warmup_series_static_single = warmup_series_static[:1, :] + else: + warmup_series_static_single = None series_ids_default = torch.from_numpy(series_id_array).to( device=device, dtype=torch.long ) @@ -834,9 +900,14 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: ).to(device) # Lazily build model parameters so that downstream utilities see them + warmup_ids_single: torch.Tensor | None + if series_ids_default.numel() > 0: + warmup_ids_single = series_ids_default[:1] + else: + warmup_ids_single = None warmup_kwargs = { - "series_static": warmup_series_static, - "series_ids": series_ids_default, + "series_static": warmup_series_static_single, + "series_ids": warmup_ids_single, } warmup_kwargs = {k: v for k, v in warmup_kwargs.items() if v is not None} if time_features_enabled and time_feature_dim > 0: @@ -847,8 +918,16 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: device=device, dtype=torch.float32, ) + original_min_sigma_buffer: torch.Tensor | 
None = None + if ( + isinstance(model.min_sigma_vector, torch.Tensor) + and model.min_sigma_vector.numel() > 0 + ): + original_min_sigma_buffer = model.min_sigma_vector + model.min_sigma_vector = model.min_sigma_vector[..., :1] + with torch.no_grad(): - dummy = torch.zeros(1, input_len, len(ids), device=device) + dummy = torch.zeros(1, input_len, 1, device=device) model(dummy, **warmup_kwargs) if cfg["train"]["channels_last"]: model.to(memory_format=torch.channels_last) @@ -862,6 +941,9 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: warmup_kwargs=warmup_kwargs, ) + if original_min_sigma_buffer is not None: + model.min_sigma_vector = original_min_sigma_buffer + if isinstance(getattr(model, "min_sigma_vector", None), torch.Tensor) and model.min_sigma_vector.numel() > 0: min_sigma: float | torch.Tensor = model.min_sigma_vector else: @@ -909,7 +991,7 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: sched_cfg = cfg["train"].get("lr_scheduler", {}) scheduler: torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau | None = None - sched_type = sched_cfg.get("type") + sched_type = sched_cfg.get("type") or "cosine" if sched_type == "ReduceLROnPlateau": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optim, @@ -966,24 +1048,16 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: cfg["train"]["lr_warmup_steps_effective"] = 0 cfg["train"]["lr_warmup_epochs_effective"] = 0 cfg["train"]["lr_warmup_start_factor_effective"] = 1.0 - - warmup_active = ( - cfg["train"].get("lr_warmup_epochs_effective", 0) > 0 - and scheduler is not None - and hasattr(scheduler, "base_lrs") - ) - if warmup_active: - initial_lrs: List[float] = [] - for base_lr, param_group in zip(scheduler.base_lrs, optim.param_groups): - warmup_lr = base_lr * warmup_start_factor + elif sched_type == "cosine" and warmup_epochs > 0: + warmup_lr = float(cfg["train"]["lr"]) * warmup_start_factor + for param_group in optim.param_groups: param_group["lr"] = warmup_lr - initial_lrs.append(warmup_lr) - if hasattr(scheduler, "_last_lr"): - scheduler._last_lr = initial_lrs + if scheduler is not None and hasattr(scheduler, "_last_lr"): + scheduler._last_lr = [warmup_lr for _ in scheduler._last_lr] if isinstance(scheduler, torch.optim.lr_scheduler.SequentialLR): for sub_scheduler in scheduler._schedulers: if hasattr(sub_scheduler, "_last_lr"): - sub_scheduler._last_lr = initial_lrs + sub_scheduler._last_lr = [warmup_lr for _ in sub_scheduler._last_lr] try: grad_scaler = torch.amp.GradScaler( @@ -1021,9 +1095,9 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: yb_w = yb_w.to(device, non_blocking=True) if use_loss_masking: mask_w = mask_w.to(device, non_blocking=True) - mb_w = mask_w.to(yb_w.dtype) + base_mask_w = mask_w > 0.0 else: - mb_w = torch.ones_like(yb_w, dtype=yb_w.dtype, device=device) + base_mask_w = torch.ones_like(yb_w, dtype=torch.bool, device=device) if x_mark_w is not None: x_mark_w = x_mark_w.to(device=device, non_blocking=True) if y_mark_w is not None: @@ -1046,12 +1120,13 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: series_static=static_w, series_ids=series_ids_w, ) + nb_mask_w = negative_binomial_mask(yb_w, rate_w, dispersion_w, base_mask_w) loss_w = ( negative_binomial_nll( y=yb_w, rate=rate_w, dispersion=dispersion_w, - mask=mb_w, + mask=nb_mask_w, ) / accum_steps ) @@ -1074,9 +1149,9 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: yb0 = yb0.to(device, non_blocking=True) if use_loss_masking: mask0 = mask0.to(device, non_blocking=True) - mb0 = mask0.to(yb0.dtype) + base_mask0 = 
mask0 > 0.0 else: - mb0 = torch.ones_like(yb0, dtype=yb0.dtype, device=device) + base_mask0 = torch.ones_like(yb0, dtype=torch.bool, device=device) if x_mark0 is not None: x_mark0 = x_mark0.to(device=device, non_blocking=True) if y_mark0 is not None: @@ -1093,10 +1168,10 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: xb0 = xb0.to(memory_format=torch.channels_last) static_x = torch.empty_like(xb0) static_y = torch.empty_like(yb0) - static_m = torch.empty_like(mb0) + static_m = torch.empty_like(base_mask0) static_x.copy_(xb0) static_y.copy_(yb0) - static_m.copy_(mb0) + static_m.copy_(base_mask0) if x_mark0 is not None: static_mark_buf = torch.empty_like(x_mark0) static_mark_buf.copy_(x_mark0) @@ -1110,6 +1185,7 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: static_ids_buf.copy_(series_ids0) capture_stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() + mask_stats_buf = torch.zeros(1, dtype=torch.float64, device=device) optim.zero_grad(set_to_none=True) model.eval() with torch.cuda.stream(capture_stream): @@ -1122,12 +1198,16 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: series_static=static_series_buf, series_ids=static_ids_buf, ) + static_nb_mask = negative_binomial_mask( + static_y, static_rate, static_dispersion, static_m + ) static_loss = negative_binomial_nll( y=static_y, rate=static_rate, dispersion=static_dispersion, - mask=static_m, + mask=static_nb_mask, ) + mask_stats_buf[0] = static_nb_mask.sum().to(mask_stats_buf.dtype) static_scaled = static_loss / accum_steps grad_scaler.scale(static_scaled).backward() graph.capture_end() @@ -1141,14 +1221,14 @@ def train_once(cfg: Dict) -> Tuple[float, Dict]: def graph_step( xb: torch.Tensor, yb: torch.Tensor, - mb: torch.Tensor, + base_mask: torch.Tensor, x_mark: torch.Tensor | None, static_feat: torch.Tensor | None, series_idx: torch.Tensor, - ) -> float: + ) -> tuple[float, float, float]: static_x.copy_(xb) static_y.copy_(yb) - static_m.copy_(mb) + static_m.copy_(base_mask) if static_mark_buf is not None: if x_mark is None: raise RuntimeError( @@ -1172,7 +1252,9 @@ def graph_step( if static_ids_buf is not None: static_ids_buf.copy_(series_idx) graph.replay() - return float(static_loss.item()) + mask_true = float(mask_stats_buf[0].item()) + mask_total = float(static_y.numel()) + return float(static_loss.item()), mask_true, mask_total print_config(cfg, current_lr=optim.param_groups[0]["lr"]) for ep in range(1, epochs + 1): @@ -1182,6 +1264,8 @@ def graph_step( num_batches = len(dl_train) copy_time_total = 0.0 iter_time_total = 0.0 + mask_true_total = 0.0 + mask_total = 0.0 for i, batch in enumerate(tqdm(dl_train, desc=f"Epoch {ep}/{epochs}", leave=False)): iter_start = time.perf_counter() xb, yb, mask, x_mark, y_mark, static_feat, series_idx = _unpack_batch(batch) @@ -1189,9 +1273,9 @@ def graph_step( yb = yb.to(device, non_blocking=True) if use_loss_masking: mask = mask.to(device, non_blocking=True) - mb = mask.to(yb.dtype) + base_mask_batch = mask > 0.0 else: - mb = torch.ones_like(yb, dtype=yb.dtype, device=device) + base_mask_batch = torch.ones_like(yb, dtype=torch.bool, device=device) if x_mark is not None: x_mark = x_mark.to(device=device, non_blocking=True) if y_mark is not None: @@ -1209,7 +1293,11 @@ def graph_step( _assert_min_len(xb, _model_input_len(model)) after_copy = time.perf_counter() if use_graphs: - loss_val = graph_step(xb, yb, mb, x_mark, static_feat, series_idx) + loss_val, mask_true_inc, mask_total_inc = graph_step( + xb, yb, base_mask_batch, x_mark, static_feat, series_idx + ) + mask_true_total 
+= mask_true_inc + mask_total += mask_total_inc else: with amp_autocast(cfg["train"]["amp"] and device.type == "cuda"): rate, dispersion = model( @@ -1218,15 +1306,20 @@ def graph_step( series_static=static_feat, series_ids=series_idx, ) + nb_mask_batch = negative_binomial_mask( + yb, rate, dispersion, base_mask_batch + ) loss_value = negative_binomial_nll( y=yb, rate=rate, dispersion=dispersion, - mask=mb, + mask=nb_mask_batch, ) loss = loss_value / accum_steps grad_scaler.scale(loss).backward() loss_val = float(loss_value.item()) + mask_true_total += float(nb_mask_batch.sum().item()) + mask_total += float(nb_mask_batch.numel()) iter_end = time.perf_counter() copy_time_total += after_copy - iter_start iter_time_total += max(iter_end - iter_start, 1e-12) @@ -1248,6 +1341,12 @@ def graph_step( ) ) + if mask_total > 0.0: + coverage = mask_true_total / mask_total + else: + coverage = 0.0 + console().print(f"[blue]Epoch {ep} loss mask coverage: {coverage:.4f}[/blue]") + eval_metrics = _eval_metrics( model, dl_val, diff --git a/tests/test_dataset_pmax.py b/tests/test_dataset_pmax.py index 49f2063..1f6abcf 100644 --- a/tests/test_dataset_pmax.py +++ b/tests/test_dataset_pmax.py @@ -78,19 +78,67 @@ def test_sliding_window_static_features_collate_shape(): assert m0.dtype == torch.float32 assert s0.dtype == torch.float32 assert ids0.dtype == torch.long - assert s0.shape == static.shape - assert ids0.shape == series_ids.shape + assert s0.shape == (1, static.shape[1]) + assert ids0.shape == (1,) assert isinstance(x_mark, torch.Tensor) and x_mark.numel() == 0 assert isinstance(y_mark, torch.Tensor) and y_mark.numel() == 0 loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) xb, yb, mb, xmb, ymb, sb, idb = next(iter(loader)) - assert sb.shape == (2, static.shape[0], static.shape[1]) - assert idb.shape == (2, series_ids.shape[0]) + assert xb.shape == (2, 3, 1) + assert yb.shape == (2, 2, 1) + assert sb.shape == (2, 1, static.shape[1]) + assert idb.shape == (2, 1) assert isinstance(xmb, torch.Tensor) and xmb.numel() == 0 assert isinstance(ymb, torch.Tensor) and ymb.numel() == 0 +def test_sliding_window_series_are_isolated_per_sample(): + values = np.stack( + [ + np.linspace(1.0, 6.0, num=6, dtype=np.float32), + np.linspace(10.0, 60.0, num=6, dtype=np.float32), + ], + axis=1, + ) + ds = SlidingWindowDataset( + values, + input_len=2, + pred_len=1, + mode="direct", + augment=None, + ) + + x_first, y_first, *_ = ds[0] + x_second, y_second, *_ = ds[1] + np.testing.assert_allclose(x_first.squeeze(-1).numpy(), [1.0, 2.0]) + np.testing.assert_allclose(y_first.squeeze(-1).numpy(), [3.0]) + np.testing.assert_allclose(x_second.squeeze(-1).numpy(), [10.0, 20.0]) + np.testing.assert_allclose(y_second.squeeze(-1).numpy(), [30.0]) + + loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) + xb, _, _, _, _ = next(iter(loader))[:5] + assert xb.shape == (2, 2, 1) + np.testing.assert_allclose(xb[0].squeeze(-1).numpy(), [1.0, 2.0]) + np.testing.assert_allclose(xb[1].squeeze(-1).numpy(), [10.0, 20.0]) + + +def test_dataframe_loader_preserves_single_series_shape(): + dates = pd.date_range("2024-01-01", periods=4, freq="D") + records = [] + for name, values in {"A": [1, 2, 3, 4], "B": [10, 11, 12, 13]}.items(): + for d, v in zip(dates, values): + records.append({"date": d, "id": name, "target": v}) + df = pd.DataFrame(records) + wide = df.pivot(index="date", columns="id", values="target").sort_index(axis=1) + values = wide.to_numpy(dtype=np.float32) + ds = SlidingWindowDataset(values, 
input_len=2, pred_len=1, mode="direct") + + x_sample, y_sample, *_ = ds[0] + assert x_sample.shape == (2, 1) + assert y_sample.shape == (1, 1) + + def test_sliding_window_time_features_marks(): idx = pd.date_range("2023-01-01", periods=8, freq="D") values = np.arange(8, dtype=np.float32).reshape(-1, 1) diff --git a/tests/test_fft_period_selector.py b/tests/test_fft_period_selector.py index 3afa020..e41197b 100644 --- a/tests/test_fft_period_selector.py +++ b/tests/test_fft_period_selector.py @@ -100,3 +100,19 @@ def test_fft_period_selector_amp_non_power_of_two_sequence(): assert torch.allclose(amplitudes_amp, amplitudes_ref, atol=1e-3, rtol=1e-3) assert x_amp.grad is not None assert torch.allclose(x_amp.grad, grad_ref, atol=1e-3, rtol=1e-3) + + +def test_fft_period_selector_enforces_min_cycles(): + L = 28 + t = torch.arange(L, dtype=torch.float32) + weekly = torch.sin(2 * math.pi * t / 7) + x = weekly.view(1, L, 1) + + selector = FFTPeriodSelector(k_periods=3, pmax=L) + periods, amplitudes = selector(x) + + assert periods.numel() > 0 + assert torch.all(periods < L) + cycles = (L + periods - 1) // periods + assert torch.any(cycles >= 2) + assert amplitudes.shape[0] == 1 diff --git a/tests/test_negative_binomial_nll.py b/tests/test_negative_binomial_nll.py index 8e0a711..4a57ddd 100644 --- a/tests/test_negative_binomial_nll.py +++ b/tests/test_negative_binomial_nll.py @@ -8,7 +8,7 @@ sys.path.append(str(Path(__file__).resolve().parents[1] / "src")) -from timesnet_forecast.losses import negative_binomial_nll +from timesnet_forecast.losses import negative_binomial_mask, negative_binomial_nll def test_negative_binomial_nll_matches_manual(): @@ -20,17 +20,18 @@ def test_negative_binomial_nll_matches_manual(): alpha = torch.clamp(dispersion, min=1e-8) mu = torch.clamp(rate, min=1e-8) - r = 1.0 / alpha - log_p = torch.log(r) - torch.log(r + mu) - log1m_p = torch.log(mu) - torch.log(r + mu) - manual = -( - torch.lgamma(y + r) - - torch.lgamma(r) + log1p_alpha_mu = torch.log1p(alpha * mu) + log_alpha = torch.log(alpha) + log_mu = torch.log(mu) + inv_alpha = 1.0 / alpha + manual_ll = ( + torch.lgamma(y + inv_alpha) + - torch.lgamma(inv_alpha) - torch.lgamma(y + 1.0) - + r * log_p - + y * log1m_p + + inv_alpha * (-log1p_alpha_mu) + + y * (log_alpha + log_mu - log1p_alpha_mu) ) - assert torch.allclose(loss, manual.mean()) + assert torch.allclose(loss, -manual_ll.mean()) def test_negative_binomial_nll_respects_mask(): @@ -45,21 +46,35 @@ def test_negative_binomial_nll_respects_mask(): # Manual computation using only the unmasked elements alpha = torch.clamp(dispersion, min=1e-8) mu = torch.clamp(rate, min=1e-8) - r = 1.0 / alpha - log_p = torch.log(r) - torch.log(r + mu) - log1m_p = torch.log(mu) - torch.log(r + mu) + inv_alpha = 1.0 / alpha + log1p_alpha_mu = torch.log1p(alpha * mu) + log_alpha = torch.log(alpha) + log_mu = torch.log(mu) log_prob = ( - torch.lgamma(y + r) - - torch.lgamma(r) + torch.lgamma(y + inv_alpha) + - torch.lgamma(inv_alpha) - torch.lgamma(y + 1.0) - + r * log_p - + y * log1m_p + + inv_alpha * (-log1p_alpha_mu) + + y * (log_alpha + log_mu - log1p_alpha_mu) ) manual_masked = -(log_prob * mask).sum() / mask.sum() assert torch.allclose(masked_loss, manual_masked) +def test_negative_binomial_mask_ignores_zeros_but_masks_nans(): + y = torch.tensor([[[0.0, float("nan")]]]) + rate = torch.tensor([[[1.0, 2.0]]]) + dispersion = torch.tensor([[[0.5, 0.5]]]) + base_mask = torch.ones_like(y) + + valid_mask = negative_binomial_mask(y, rate, dispersion, base_mask) + assert valid_mask.dtype 
== torch.bool + assert valid_mask.shape == y.shape + assert valid_mask[0, 0, 0] + assert not valid_mask[0, 0, 1] + + def test_negative_binomial_nll_autocast_stability(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device.type == "cuda": diff --git a/tests/test_timesnet_forward.py b/tests/test_timesnet_forward.py index f3e95cf..671e219 100644 --- a/tests/test_timesnet_forward.py +++ b/tests/test_timesnet_forward.py @@ -47,6 +47,29 @@ def test_forward_shape_and_head_processing(): assert dispersion_long.shape == dispersion_head.shape == (B, H, N) +def test_timesnet_blocks_track_period_calls(): + torch.manual_seed(0) + B, L, H, N = 1, 12, 3, 1 + model = TimesNet( + input_len=L, + pred_len=H, + d_model=8, + d_ff=16, + n_layers=2, + k_periods=2, + kernel_set=[(3, 3)], + dropout=0.0, + activation="gelu", + mode="direct", + ) + x = torch.randn(B, L, N) + with torch.no_grad(): + model(x) + period_counts = [getattr(block, "_period_calls", 0) for block in model.blocks] + assert len(period_counts) == 2 + assert all(count > 0 for count in period_counts) + + def test_timesnet_pre_embedding_norm_adapts_to_feature_count(): torch.manual_seed(0) B, L, H, N = 2, 12, 4, 1