+from __future__ import annotations
+
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLP(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 64, dropout: float = 0.0):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(in_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
+ nn.Linear(hidden_dim, out_dim),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.layers(x)
+
+
+class KAN(nn.Module):
+ """
+ Lightweight KAN-style harmonic basis encoder for node features.
+ """
+
+ def __init__(self, in_dim: int, harmonics: int = 5, hidden_dim: int = 64):
+ super().__init__()
+ self.in_dim = in_dim
+ self.harmonics = harmonics
+ self.feature_proj = nn.ModuleList(
+ [nn.Linear(2 * harmonics + 1, hidden_dim) for _ in range(in_dim)]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, N, F)
+ outputs = []
+ for i in range(self.in_dim):
+ xi = x[:, :, i].unsqueeze(-1)
+ basis = [torch.ones_like(xi)]
+ for k in range(1, self.harmonics + 1):
+ basis.append(torch.sin(k * xi))
+ basis.append(torch.cos(k * xi))
+ basis = torch.cat(basis, dim=-1)
+ outputs.append(self.feature_proj[i](basis))
+ return torch.stack(outputs, dim=0).sum(dim=0)
+
+
+class GNBlock(nn.Module):
+ """
+ Message-passing block with residual edge and node updates.
+ """
+
+ def __init__(self, hidden_dim: int, dropout: float = 0.0):
+ super().__init__()
+ self.edge_mlp = MLP(3 * hidden_dim, hidden_dim, hidden_dim, dropout=dropout)
+ self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, dropout=dropout)
+
+ def forward(
+ self,
+ node: torch.Tensor,
+ edge: torch.Tensor,
+ senders: torch.Tensor,
+ receivers: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ sender_feat = node[:, senders, :]
+ receiver_feat = node[:, receivers, :]
+
+ edge_input = torch.cat([edge, sender_feat, receiver_feat], dim=-1)
+ edge = edge + self.edge_mlp(edge_input)
+
+ agg = torch.zeros_like(node)
+ agg.index_add_(1, receivers, edge)
+
+ # Degree-normalized aggregation improves stability when graph density changes.
+ deg = torch.zeros(node.size(1), device=node.device, dtype=node.dtype)
+ deg.index_add_(0, receivers, torch.ones_like(receivers, dtype=node.dtype))
+ agg = agg / deg.clamp(min=1.0).view(1, -1, 1)
+
+ node_input = torch.cat([node, agg], dim=-1)
+ node = node + self.node_mlp(node_input)
+ return node, edge
+
+
+
+
[docs]
+
class HydroGraphNet(nn.Module):
+
"""
+
PhysicsNeMo-inspired HydroGraphNet:
+
encoder -> message-passing processor -> residual delta-state decoder.
+
+
Supports one-step forward prediction and autoregressive rollout.
+
"""
+
+
def __init__(
+
self,
+
node_in_dim: int,
+
edge_in_dim: int,
+
out_dim: int,
+
hidden_dim: int = 64,
+
harmonics: int = 5,
+
num_gn_blocks: int = 5,
+
state_dim: Optional[int] = None,
+
rollout_steps: int = 1,
+
enforce_nonnegative: bool = False,
+
dropout: float = 0.0,
+
):
+
super().__init__()
+
self.node_in_dim = int(node_in_dim)
+
self.edge_in_dim = int(edge_in_dim)
+
self.out_dim = int(out_dim)
+
self.state_dim = int(state_dim) if state_dim is not None else min(2, self.node_in_dim)
+
self.state_dim = max(1, min(self.state_dim, self.node_in_dim))
+
if self.out_dim > self.state_dim:
+
raise ValueError(
+
f"out_dim={self.out_dim} cannot exceed residual state_dim={self.state_dim}."
+
)
+
self.rollout_steps = max(1, int(rollout_steps))
+
self.enforce_nonnegative = bool(enforce_nonnegative)
+
+
# Encoder
+
self.node_encoder = KAN(
+
in_dim=self.node_in_dim,
+
hidden_dim=hidden_dim,
+
harmonics=harmonics,
+
)
+
self.edge_encoder = MLP(
+
in_dim=self.edge_in_dim,
+
out_dim=hidden_dim,
+
hidden_dim=hidden_dim,
+
dropout=dropout,
+
)
+
+
# Processor
+
self.processor = nn.ModuleList(
+
[GNBlock(hidden_dim=hidden_dim, dropout=dropout) for _ in range(num_gn_blocks)]
+
)
+
+
# Decoder predicts delta of physically meaningful states.
+
self.decoder = MLP(
+
in_dim=hidden_dim,
+
out_dim=self.state_dim,
+
hidden_dim=hidden_dim,
+
dropout=dropout,
+
)
+
+
+
[docs]
+
def _edge_index(self, adj: torch.Tensor, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
+
if adj.dim() == 2:
+
a = adj
+
elif adj.dim() == 3:
+
if adj.size(0) != batch_size:
+
raise ValueError(f"adj batch size mismatch: got {adj.size(0)}, expected {batch_size}")
+
a = adj[0]
+
for i in range(1, batch_size):
+
if not torch.allclose(adj[i], a):
+
raise ValueError(
+
"Per-sample varying adjacency is not supported yet. "
+
"Provide a shared (N, N) adjacency or identical (B, N, N) adjacency."
+
)
+
else:
+
raise ValueError("adj must be shaped (N, N) or (B, N, N).")
+
+
a = (a > 0).to(dtype=torch.bool)
+
a.fill_diagonal_(True)
+
return a.nonzero(as_tuple=True)
+
+
+
+
[docs]
+
def _match_edge_dim(self, edge_feat: torch.Tensor) -> torch.Tensor:
+
# edge_feat: (B, E, F_edge_raw)
+
f_raw = edge_feat.size(-1)
+
if f_raw == self.edge_in_dim:
+
return edge_feat
+
if f_raw > self.edge_in_dim:
+
return edge_feat[..., : self.edge_in_dim]
+
pad = torch.zeros(
+
edge_feat.size(0),
+
edge_feat.size(1),
+
self.edge_in_dim - f_raw,
+
device=edge_feat.device,
+
dtype=edge_feat.dtype,
+
)
+
return torch.cat([edge_feat, pad], dim=-1)
+
+
+
+
+
+
+
[docs]
+
def _one_step(
+
self,
+
node_x: torch.Tensor,
+
batch: Dict[str, torch.Tensor],
+
) -> torch.Tensor:
+
# node_x: (B, N, F)
+
if node_x.ndim != 3:
+
raise ValueError(f"Expected node_x with shape (B,N,F), got {tuple(node_x.shape)}")
+
if node_x.size(-1) < self.state_dim:
+
raise ValueError(
+
f"Input feature dim {node_x.size(-1)} is smaller than state_dim {self.state_dim}."
+
)
+
+
adj = batch.get("adj")
+
if adj is None:
+
raise ValueError("HydroGraphNet requires `adj` in the batch.")
+
adj = adj.to(device=node_x.device)
+
+
senders, receivers = self._edge_index(adj, batch_size=node_x.size(0))
+
+
# ---- encoder ----
+
node = self.node_encoder(node_x)
+
edge_in = self._prepare_edge_inputs(
+
batch=batch,
+
senders=senders,
+
receivers=receivers,
+
batch_size=node.size(0),
+
device=node.device,
+
dtype=node.dtype,
+
)
+
edge = self.edge_encoder(edge_in)
+
+
# ---- processor ----
+
for gn in self.processor:
+
node, edge = gn(node, edge, senders, receivers)
+
+
# ---- decoder: residual state update ----
+
delta_state = self.decoder(node) # (B, N, state_dim)
+
prev_state = node_x[..., : self.state_dim]
+
next_state = prev_state + delta_state
+
if self.enforce_nonnegative:
+
next_state = next_state.clamp_min(0.0)
+
+
# Return requested targets from the evolved state.
+
return next_state[..., : self.out_dim]
+
+
+
+
[docs]
+
def rollout(self, batch: Dict[str, torch.Tensor], predict_steps: int) -> torch.Tensor:
+
batch_roll = dict(batch)
+
batch_roll["predict_steps"] = int(predict_steps)
+
out = self.forward(batch_roll)
+
if out.ndim != 4:
+
raise RuntimeError("rollout expected stacked output with shape (B, S, N, out_dim).")
+
return out
+
+
+
+
[docs]
+
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+
# batch["x"]: (B, T, N, F)
+
x = batch["x"]
+
if x.ndim != 4:
+
raise ValueError(
+
f"HydroGraphNet expects x shaped (B, T, N, F), got {tuple(x.shape)}"
+
)
+
+
predict_steps = int(batch.get("predict_steps", self.rollout_steps))
+
predict_steps = max(1, predict_steps)
+
+
history = x
+
preds = []
+
for _ in range(predict_steps):
+
node_x = history[:, -1]
+
y_next = self._one_step(node_x=node_x, batch=batch) # (B, N, out_dim)
+
preds.append(y_next)
+
+
if predict_steps > 1:
+
next_frame = history[:, -1].clone()
+
next_frame[..., : self.out_dim] = y_next
+
history = torch.cat([history[:, 1:], next_frame.unsqueeze(1)], dim=1)
+
+
if predict_steps == 1:
+
return preds[0]
+
return torch.stack(preds, dim=1)
+
+
+
+
+
+
[docs]
+
class HydroGraphNetLoss(nn.Module):
+
"""
+
Supervised regression loss with optional continuity regularization.
+
"""
+
+
def __init__(self, supervised_weight: float = 1.0, continuity_weight: float = 0.0):
+
super().__init__()
+
self.supervised_weight = float(supervised_weight)
+
self.continuity_weight = float(continuity_weight)
+
+
+
[docs]
+
def forward(
+
self,
+
preds: torch.Tensor,
+
targets: torch.Tensor,
+
prev_state: Optional[torch.Tensor] = None,
+
cell_area: Optional[torch.Tensor] = None,
+
) -> Tuple[torch.Tensor, Dict[str, float]]:
+
supervised = F.mse_loss(preds, targets)
+
total = self.supervised_weight * supervised
+
metrics: Dict[str, float] = {"mse": float(supervised.detach().cpu())}
+
+
if (
+
self.continuity_weight > 0
+
and prev_state is not None
+
and cell_area is not None
+
and preds.size(-1) >= 2
+
and prev_state.size(-1) >= 2
+
):
+
# Approximate local continuity: depth-change * area ~= volume-change
+
depth_delta = preds[..., 0] - prev_state[..., 0]
+
volume_delta = preds[..., 1] - prev_state[..., 1]
+
area = cell_area.to(device=preds.device, dtype=preds.dtype)
+
if area.dim() == 1:
+
area = area.unsqueeze(0)
+
continuity = F.mse_loss(depth_delta * area, volume_delta)
+
total = total + self.continuity_weight * continuity
+
metrics["continuity"] = float(continuity.detach().cpu())
+
+
metrics["total"] = float(total.detach().cpu())
+
return total, metrics
+
+
+
+
+
+
[docs]
+
def hydrographnet_builder(
+
task: str,
+
node_in_dim: int,
+
edge_in_dim: int,
+
out_dim: int,
+
**kwargs,
+
) -> HydroGraphNet:
+
if task != "regression":
+
raise ValueError("HydroGraphNet only supports regression")
+
+
return HydroGraphNet(
+
node_in_dim=node_in_dim,
+
edge_in_dim=edge_in_dim,
+
out_dim=out_dim,
+
hidden_dim=kwargs.get("hidden_dim", 64),
+
harmonics=kwargs.get("harmonics", 5),
+
num_gn_blocks=kwargs.get("num_gn_blocks", 5),
+
state_dim=kwargs.get("state_dim"),
+
rollout_steps=kwargs.get("rollout_steps", 1),
+
enforce_nonnegative=kwargs.get("enforce_nonnegative", False),
+
dropout=kwargs.get("dropout", 0.0),
+
)
+
+
+
+__all__ = ["HydroGraphNet", "HydroGraphNetLoss", "hydrographnet_builder"]
+
+