From 2e3a417eb9f6ec601017b91dfba092b3ab0fc95a Mon Sep 17 00:00:00 2001 From: Abdelrahman Helal Date: Mon, 1 Jun 2026 20:57:16 -0700 Subject: [PATCH 1/4] Add support for PytorchGeometric models --- requirements.txt | 3 + shiftkit/__init__.py | 6 + shiftkit/models/__init__.py | 7 + shiftkit/models/gnn_pyg.py | 254 ++++++++++++++++++++++++++++++++++++ 4 files changed, 270 insertions(+) create mode 100644 shiftkit/models/gnn_pyg.py diff --git a/requirements.txt b/requirements.txt index ab964f7..6a3b783 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,9 @@ matplotlib>=3.7.0 scikit-learn>=1.2.0 tqdm>=4.65.0 +# Optional — required only for shiftkit.models.GNN (PyTorch Geometric) +# torch-geometric>=2.4.0 + # Optional — required only for UMAP projection in diagnostics # umap-learn>=0.5.0 diff --git a/shiftkit/__init__.py b/shiftkit/__init__.py index 3d23bcd..bd5e7d4 100644 --- a/shiftkit/__init__.py +++ b/shiftkit/__init__.py @@ -12,6 +12,10 @@ from .data.datasets import DataManager from .models.networks import MLP, CNN, MLPRegressor from .models.gnn import SimpleGCN +try: + from .models.gnn_pyg import GNN +except ImportError: + GNN = None # torch-geometric not installed from .methods.base import BaseTrainer, TrainerRegistry from .methods.mmd import MMDLoss, MMDTrainer, SourceOnlyTrainer from .methods.lmmd import LMMDLoss, LMMDTrainer @@ -42,3 +46,5 @@ "plot_latent_space", "plot_training_history", "compare_latent_spaces", "plot_confusion_matrix", "plot_roc_curve", ] +if GNN is not None: + __all__.append("GNN") diff --git a/shiftkit/models/__init__.py b/shiftkit/models/__init__.py index 6e6269c..e4ff36a 100644 --- a/shiftkit/models/__init__.py +++ b/shiftkit/models/__init__.py @@ -1,4 +1,11 @@ from .networks import MLP, CNN, MLPRegressor from .gnn import SimpleGCN +try: + from .gnn_pyg import GNN +except ImportError: + GNN = None # torch-geometric not installed + __all__ = ["MLP", "CNN", "MLPRegressor", "SimpleGCN"] +if GNN is not None: + 
__all__.append("GNN") diff --git a/shiftkit/models/gnn_pyg.py b/shiftkit/models/gnn_pyg.py new file mode 100644 index 0000000..4a09f1f --- /dev/null +++ b/shiftkit/models/gnn_pyg.py @@ -0,0 +1,254 @@ +""" +PyTorch Geometric GNN for graph-level domain adaptation. + +GNN +--- +Configurable stack of ``torch_geometric.nn.conv`` layers with graph-level +pooling. Accepts PyG ``Data`` / ``Batch`` objects (use +``torch_geometric.loader.DataLoader``) and exposes the same +``encode`` / ``classify`` / ``regress`` / ``forward`` interface as +:class:`~shiftkit.models.gnn.SimpleGCN` and :class:`~shiftkit.models.networks.MLP`. + +Requires ``torch-geometric`` (optional dependency):: + + pip install torch-geometric +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from torch_geometric.data import Batch + +try: + from torch_geometric.data import Data + from torch_geometric.nn import ( + SAGEConv, + GCNConv, + GATConv, + GINConv, + GraphConv, + global_mean_pool, + global_max_pool, + global_add_pool, + ) +except ImportError as e: + raise ImportError( + "shiftkit.models.GNN requires torch-geometric. 
" + "Install it with: pip install torch-geometric" + ) from e + +PyGData = Union[Data, "Batch"] + +# ─── conv registry ─────────────────────────────────────────────────────────── + +_CONV_ALIASES = { + "SAGE": "SAGE", + "SAGECONV": "SAGE", + "GCN": "GCN", + "GCNCONV": "GCN", + "GAT": "GAT", + "GATCONV": "GAT", + "GIN": "GIN", + "GINCONV": "GIN", + "GRAPHCONV": "GRAPHCONV", + "GRAPH": "GRAPHCONV", +} + +_CONV_CLASSES = { + "SAGE": SAGEConv, + "GCN": GCNConv, + "GAT": GATConv, + "GIN": GINConv, + "GRAPHCONV": GraphConv, +} + +_POOL_FUNCS = { + "mean": global_mean_pool, + "max": global_max_pool, + "sum": global_add_pool, + "add": global_add_pool, +} + + +def _resolve_conv(name: str) -> str: + """Return canonical registry key for *name* (case-insensitive).""" + key = _CONV_ALIASES.get(name.upper().replace("_", "")) + if key is None: + supported = sorted({k for k in _CONV_CLASSES}) + raise ValueError( + f"Unknown GNN model '{name}'. Supported: {supported}" + ) + return key + + +def _gin_mlp(in_channels: int, out_channels: int) -> nn.Sequential: + return nn.Sequential( + nn.Linear(in_channels, out_channels), + nn.ReLU(), + nn.Linear(out_channels, out_channels), + ) + + +def _build_conv( + conv_key: str, + in_channels: int, + out_channels: int, + aggr: str = "mean", +) -> nn.Module: + """Instantiate one convolution layer for the chosen architecture.""" + if conv_key == "SAGE": + return SAGEConv(in_channels, out_channels, aggr=aggr) + if conv_key == "GCN": + return GCNConv(in_channels, out_channels) + if conv_key == "GAT": + return GATConv(in_channels, out_channels, heads=1, concat=False) + if conv_key == "GIN": + return GINConv(_gin_mlp(in_channels, out_channels)) + if conv_key == "GRAPHCONV": + return GraphConv(in_channels, out_channels, aggr=aggr) + raise ValueError(f"Unhandled conv key: {conv_key}") + + +def _graph_batch_vector(data: PyGData) -> torch.Tensor: + """Return per-node batch indices; synthesise zeros for a single graph.""" + if getattr(data, "batch", None) 
is not None: + return data.batch + return torch.zeros(data.num_nodes, dtype=torch.long, device=data.x.device) + + +# ─── GNN model ───────────────────────────────────────────────────────────────── + +class GNN(nn.Module): + """ + Configurable PyG GNN for graph-level classification or regression. + + Parameters + ---------- + data : template ``Data`` object (used for ``num_node_features``) + model_name : conv type — ``SAGE``, ``GCN``, ``GAT``, ``GIN``, ``GraphConv`` + hidden_channels : width of conv layers and graph-level latent size + num_layers : number of message-passing layers (>= 1) + num_classes : output classes for ``classify`` (required if ``regress=False``) + regress : if ``True``, build regression head only; ``forward`` uses ``regress`` + pool : graph readout — ``mean``, ``max``, ``sum``, ``add``, or ``none`` (node-level, no pool) + use_layernorm : apply ``LayerNorm`` after each conv + dropout : dropout probability between conv layers + aggr : aggregation for convs that support it (e.g. SAGE, GraphConv) + """ + + def __init__( + self, + data: Data, + model_name: str, + hidden_channels: int, + num_layers: int, + num_classes: int = 2, + regress: bool = False, + pool: str = "mean", + use_layernorm: bool = True, + dropout: float = 0.0, + aggr: str = "mean", + ): + super().__init__() + + if num_layers < 1: + raise ValueError("num_layers must be >= 1") + + pool_key = pool.lower() + if pool_key not in _POOL_FUNCS and pool_key != "none": + raise ValueError( + f"Unknown pool '{pool}'. 
Choose from: {list(_POOL_FUNCS)} + ['none']" + ) + + if not regress and num_classes < 1: + raise ValueError("num_classes must be >= 1 when regress=False") + + conv_key = _resolve_conv(model_name) + in_channels = data.num_node_features + + self.is_regression = regress + self.latent_dim = hidden_channels + self.hidden_channels = hidden_channels + self.num_layers = num_layers + self.dropout = dropout + self.pool = pool_key + self._pool_fn = None if pool_key == "none" else _POOL_FUNCS[pool_key] + self._conv_key = conv_key + + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + + self.convs.append(_build_conv(conv_key, in_channels, hidden_channels, aggr)) + self.norms.append( + nn.LayerNorm(hidden_channels) if use_layernorm else nn.Identity() + ) + + for _ in range(num_layers - 1): + self.convs.append( + _build_conv(conv_key, hidden_channels, hidden_channels, aggr) + ) + self.norms.append( + nn.LayerNorm(hidden_channels) if use_layernorm else nn.Identity() + ) + + if regress: + self.regressor = nn.Linear(hidden_channels, 1) + self.classifier = None + else: + self.classifier = nn.Linear(hidden_channels, num_classes) + self.regressor = None + + def encode(self, data: PyGData) -> torch.Tensor: + """ + Message passing, then optional global pooling. 
+ + Parameters + ---------- + data : PyG ``Data`` or ``Batch`` + + Returns + ------- + z : (num_nodes, hidden_channels) if ``pool='none'``, else (num_graphs, hidden_channels) + """ + x, edge_index = data.x, data.edge_index + + h = x + for conv, norm in zip(self.convs, self.norms): + h = conv(h, edge_index) + h = norm(h) + h = F.relu(h) + if self.dropout > 0.0: + h = F.dropout(h, p=self.dropout, training=self.training) + + if self._pool_fn is None: + return h + batch = _graph_batch_vector(data) + return self._pool_fn(h, batch) + + def classify(self, z: torch.Tensor) -> torch.Tensor: + """Linear classification head on graph-level features.""" + if self.classifier is None: + raise RuntimeError( + "GNN was built with regress=True; classify() is not available." + ) + return self.classifier(z) + + def regress(self, z: torch.Tensor) -> torch.Tensor: + """Linear regression head (scalar output per graph).""" + if self.regressor is None: + raise RuntimeError( + "GNN was built with regress=False; regress() is not available. " + "Use classify() or reconstruct with regress=True." 
+ ) + return self.regressor(z) + + def forward(self, data: PyGData) -> torch.Tensor: + z = self.encode(data) + if self.is_regression: + return self.regress(z) + return self.classify(z) From e351cebf282425b790052b483defe80cd0045600 Mon Sep 17 00:00:00 2001 From: Abdelrahman Helal Date: Tue, 2 Jun 2026 16:13:57 -0700 Subject: [PATCH 2/4] Add support for external PyG Data for both graph-level and node-level regression and classification --- shiftkit/data/__init__.py | 9 + shiftkit/data/datasets.py | 43 ++++ shiftkit/data/pyg_utils.py | 354 +++++++++++++++++++++++++++++++++ shiftkit/methods/mmd.py | 93 +++++---- shiftkit/methods/node_batch.py | 64 ++++++ shiftkit/methods/regression.py | 67 +++++-- 6 files changed, 579 insertions(+), 51 deletions(-) create mode 100644 shiftkit/data/pyg_utils.py create mode 100644 shiftkit/methods/node_batch.py diff --git a/shiftkit/data/__init__.py b/shiftkit/data/__init__.py index 2ae3fb6..16f093f 100644 --- a/shiftkit/data/__init__.py +++ b/shiftkit/data/__init__.py @@ -1,3 +1,12 @@ from .datasets import DataManager, SineWaveDataset, CaliforniaHousingDataset +try: + from .pyg_utils import NodeGraphBatch, ensure_masks, build_pyg_domain_loaders +except ImportError: + NodeGraphBatch = None + ensure_masks = None + build_pyg_domain_loaders = None + __all__ = ["DataManager", "SineWaveDataset", "CaliforniaHousingDataset"] +if NodeGraphBatch is not None: + __all__ += ["NodeGraphBatch", "ensure_masks", "build_pyg_domain_loaders"] diff --git a/shiftkit/data/datasets.py b/shiftkit/data/datasets.py index 640ce74..4ce364a 100644 --- a/shiftkit/data/datasets.py +++ b/shiftkit/data/datasets.py @@ -428,6 +428,49 @@ def _california_housing(root, batch_size, train, num_workers, **kw): _REGISTRY["california_housing"] = _california_housing + def _pyg_domains(root, batch_size, train, num_workers, **kw): + """ + PyG source/target domain pair (graph-level or node-level). 
+ + Required kwargs + --------------- + source, target : PyG ``Data`` (node-level) or list/dataset of ``Data`` (graph-level) + task_level : ``"node"`` or ``"graph"`` (default ``"node"``) + + Optional kwargs + ----------------- + train_ratio, val_ratio, split_seed, split_mode (``"stratified"`` | ``"random"``) + """ + from .pyg_utils import build_pyg_domain_loaders + + source = kw.get("source") + target = kw.get("target") + if source is None or target is None: + raise ValueError( + "pyg_domains requires 'source' and 'target' PyG Data object(s). " + "Example: dm.load('pyg_domains', source=src_data, target=tgt_data, ...)" + ) + task_level = kw.get("task_level", "node") + train_ratio = float(kw.get("train_ratio", 0.6)) + val_ratio = float(kw.get("val_ratio", 0.2)) + split_seed = int(kw.get("split_seed", 42)) + split_mode = kw.get("split_mode", "stratified") + + return build_pyg_domain_loaders( + task_level=task_level, + source=source, + target=target, + train=train, + batch_size=batch_size, + num_workers=num_workers, + train_ratio=train_ratio, + val_ratio=val_ratio, + split_seed=split_seed, + split_mode=split_mode, + ) + + _REGISTRY["pyg_domains"] = _pyg_domains + _register_defaults() diff --git a/shiftkit/data/pyg_utils.py b/shiftkit/data/pyg_utils.py new file mode 100644 index 0000000..d5c4f5f --- /dev/null +++ b/shiftkit/data/pyg_utils.py @@ -0,0 +1,354 @@ +""" +PyTorch Geometric utilities for ShiftKit DataManager. + +Supports graph-level (many graphs per domain) and node-level (one graph per +domain) domain adaptation with stratified train/val/test splits. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +try: + from torch_geometric.data import Data, Batch +except ImportError as e: + raise ImportError( + "shiftkit.data.pyg_utils requires torch-geometric. 
" + "Install it with: pip install torch-geometric" + ) from e + + +# ─── batch container for node-level single-graph domains ───────────────────── + +@dataclass +class NodeGraphBatch: + """One full graph plus node indices and labels for the current step.""" + + graph: Data + node_idx: torch.Tensor + y: torch.Tensor + + +def is_node_graph_batch(x) -> bool: + return isinstance(x, NodeGraphBatch) + + +def move_node_graph_batch(batch: NodeGraphBatch, device: torch.device) -> NodeGraphBatch: + return NodeGraphBatch( + graph=batch.graph.to(device), + node_idx=batch.node_idx.to(device), + y=batch.y.to(device), + ) + + +# ─── mask / split helpers ──────────────────────────────────────────────────── + +def _has_masks(data: Data) -> bool: + return ( + hasattr(data, "train_mask") + and data.train_mask is not None + and data.train_mask.any() + ) + + +def _labels_numpy(data: Data) -> np.ndarray: + y = data.y + if y is None: + raise ValueError("PyG Data must have a 'y' attribute for stratified splitting.") + return y.detach().cpu().numpy().reshape(-1) + + +def _is_discrete_labels(y: np.ndarray) -> bool: + if y.dtype.kind in ("i", "u", "b"): + return True + uniq = np.unique(y) + if len(uniq) <= 20 and np.allclose(uniq, np.round(uniq)): + return True + return False + + +def _stratified_indices( + y: np.ndarray, + train_ratio: float, + val_ratio: float, + seed: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + rng = np.random.RandomState(seed) + n = len(y) + train_mask = np.zeros(n, dtype=bool) + val_mask = np.zeros(n, dtype=bool) + test_mask = np.zeros(n, dtype=bool) + + for cls in np.unique(y): + idx = np.where(y == cls)[0] + rng.shuffle(idx) + n_cls = len(idx) + n_train = max(1, int(n_cls * train_ratio)) if n_cls > 2 else max(0, n_cls - 1) + n_val = max(0, int(n_cls * val_ratio)) + if n_train + n_val >= n_cls: + n_train = max(1, n_cls - 1) + n_val = 0 + train_mask[idx[:n_train]] = True + val_mask[idx[n_train : n_train + n_val]] = True + test_mask[idx[n_train + n_val :]] = 
True + + unassigned = ~(train_mask | val_mask | test_mask) + if unassigned.any(): + idx_rest = np.where(unassigned)[0] + rng.shuffle(idx_rest) + test_mask[idx_rest] = True + + return train_mask, val_mask, test_mask + + +def _random_indices( + n: int, + train_ratio: float, + val_ratio: float, + seed: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + rng = np.random.RandomState(seed) + perm = rng.permutation(n) + n_train = max(1, int(n * train_ratio)) + n_val = max(0, int(n * val_ratio)) + if n_train + n_val >= n: + n_train = max(1, n - 1) + n_val = 0 + train_idx = perm[:n_train] + val_idx = perm[n_train : n_train + n_val] + test_idx = perm[n_train + n_val :] + train_mask = np.zeros(n, dtype=bool) + val_mask = np.zeros(n, dtype=bool) + test_mask = np.zeros(n, dtype=bool) + train_mask[train_idx] = True + val_mask[val_idx] = True + test_mask[test_idx] = True + if not test_mask.any(): + test_mask[perm[-1]] = True + train_mask[perm[-1]] = False + return train_mask, val_mask, test_mask + + +def ensure_masks( + data: Data, + train_ratio: float = 0.6, + val_ratio: float = 0.2, + seed: int = 42, + split_mode: str = "stratified", +) -> Data: + """ + Assign ``train_mask``, ``val_mask``, and ``test_mask`` on *data* in-place. + + Skips splitting if ``train_mask`` is already present and non-empty. 
+ """ + if _has_masks(data): + return data + + y_np = _labels_numpy(data) + n = data.num_nodes + use_stratified = split_mode == "stratified" and _is_discrete_labels(y_np) + + if use_stratified: + tr, va, te = _stratified_indices(y_np, train_ratio, val_ratio, seed) + else: + tr, va, te = _random_indices(n, train_ratio, val_ratio, seed) + + device = data.y.device if data.y is not None else "cpu" + data.train_mask = torch.tensor(tr, dtype=torch.bool, device=device) + data.val_mask = torch.tensor(va, dtype=torch.bool, device=device) + data.test_mask = torch.tensor(te, dtype=torch.bool, device=device) + return data + + +def _normalize_graph_list( + graphs: Union[Data, Sequence[Data], Dataset], +) -> List[Data]: + if isinstance(graphs, Data): + return [graphs] + if isinstance(graphs, Dataset): + return [graphs[i] for i in range(len(graphs))] + return list(graphs) + + +def split_graph_list( + graphs: Union[Sequence[Data], Data, Dataset], + train: bool, + train_ratio: float = 0.6, + val_ratio: float = 0.2, + seed: int = 42, + split_mode: str = "stratified", +) -> List[Data]: + """ + Split a list of graphs into train+val (``train=True``) or test (``train=False``). 
+ """ + graph_list = _normalize_graph_list(graphs) + n = len(graph_list) + if n == 0: + raise ValueError("Graph list is empty.") + + labels = [] + for g in graph_list: + if g.y is None: + labels.append(0) + elif g.y.numel() == 1: + labels.append(int(g.y.item()) if g.y.dtype in (torch.int64, torch.int32) else float(g.y.item())) + else: + labels.append(int(g.y.view(-1)[0].item())) + y_np = np.array(labels) + + use_stratified = split_mode == "stratified" and _is_discrete_labels( + y_np.astype(float) if y_np.dtype.kind == "f" else y_np + ) + + if use_stratified: + tr, va, te = _stratified_indices(y_np, train_ratio, val_ratio, seed) + else: + tr, va, te = _random_indices(n, train_ratio, val_ratio, seed) + + fit_mask = tr | va + chosen = fit_mask if train else te + return [g for g, m in zip(graph_list, chosen) if m] + + +# ─── datasets / loaders ────────────────────────────────────────────────────── + +class _GraphListDataset(Dataset): + def __init__(self, graphs: List[Data]): + self.graphs = graphs + + def __len__(self): + return len(self.graphs) + + def __getitem__(self, idx): + g = self.graphs[idx] + y = g.y + if y is None: + raise ValueError("Graph-level task requires graph label g.y") + if y.dim() > 0: + y = y.view(-1)[0] + return g, y.long() if y.dtype in (torch.int64, torch.int32) else y.float() + + +def _graph_collate(items): + graphs = [item[0] for item in items] + ys = torch.stack([item[1] for item in items]) + return Batch.from_data_list(graphs), ys + + +def build_graph_loaders( + source_graphs: List[Data], + target_graphs: List[Data], + batch_size: int, + num_workers: int, + shuffle: bool, +) -> Tuple[DataLoader, DataLoader]: + src_loader = DataLoader( + _GraphListDataset(source_graphs), + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=_graph_collate, + ) + tgt_loader = DataLoader( + _GraphListDataset(target_graphs), + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=_graph_collate, + ) + 
return src_loader, tgt_loader + + +class _SingleGraphNodeDataset(Dataset): + """One batch per epoch: full graph + train or test node indices.""" + + def __init__(self, data: Data, eval_split: bool = False): + self.data = data + self.eval_split = eval_split + + def __len__(self): + return 1 + + def __getitem__(self, idx): + mask = self.data.test_mask if self.eval_split else self.data.train_mask + node_idx = mask.nonzero(as_tuple=False).view(-1) + y = self.data.y[node_idx] + if y.dtype in (torch.int64, torch.int32): + y = y.long().view(-1) + else: + y = y.float().view(-1) + batch = NodeGraphBatch(graph=self.data, node_idx=node_idx, y=y) + return batch, y + + +def _node_collate(items): + return items[0] + + +def build_node_loaders( + source: Data, + target: Data, + batch_size: int, + num_workers: int, + train: bool, +) -> Tuple[DataLoader, DataLoader]: + eval_split = not train + src_loader = DataLoader( + _SingleGraphNodeDataset(source, eval_split=eval_split), + batch_size=1, + shuffle=False, + num_workers=num_workers, + collate_fn=_node_collate, + ) + tgt_loader = DataLoader( + _SingleGraphNodeDataset(target, eval_split=eval_split), + batch_size=1, + shuffle=False, + num_workers=num_workers, + collate_fn=_node_collate, + ) + return src_loader, tgt_loader + + +def build_pyg_domain_loaders( + task_level: str, + source, + target, + train: bool, + batch_size: int, + num_workers: int, + train_ratio: float, + val_ratio: float, + split_seed: int, + split_mode: str, +) -> Tuple[DataLoader, DataLoader]: + """ + Build (source_loader, target_loader) for ``pyg_domains`` factory. + """ + task_level = task_level.lower() + if task_level not in ("node", "graph"): + raise ValueError("task_level must be 'node' or 'graph'") + + if task_level == "node": + if isinstance(source, list) or isinstance(target, list): + raise ValueError( + "task_level='node' expects a single PyG Data object per domain, not a list." 
+ ) + source = ensure_masks(source, train_ratio, val_ratio, split_seed, split_mode) + target = ensure_masks(target, train_ratio, val_ratio, split_seed + 1, split_mode) + return build_node_loaders(source, target, batch_size, num_workers, train) + + src_graphs = split_graph_list( + source, train, train_ratio, val_ratio, split_seed, split_mode + ) + tgt_graphs = split_graph_list( + target, train, train_ratio, val_ratio, split_seed + 1, split_mode + ) + return build_graph_loaders( + src_graphs, tgt_graphs, batch_size, num_workers, shuffle=train + ) diff --git a/shiftkit/methods/mmd.py b/shiftkit/methods/mmd.py index e52c4b4..2a2a536 100644 --- a/shiftkit/methods/mmd.py +++ b/shiftkit/methods/mmd.py @@ -27,6 +27,16 @@ from tqdm import tqdm from typing import Optional, List +from .node_batch import ( + is_node_graph_batch, + move_node_graph_batch, + unpack_batch, + batch_accuracy, + node_classify_logits, + node_latent_vectors, + node_classification_correct, +) + # ─── MMD Loss ──────────────────────────────────────────────────────────────── @@ -72,12 +82,6 @@ def _auto_device() -> torch.device: return torch.device("cpu") -@torch.no_grad() -def _batch_accuracy(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> int: - """Return number of correct predictions for a single batch (no grad).""" - return (model(x).argmax(1) == y).sum().item() - - # ─── MMD Trainer ───────────────────────────────────────────────────────────── class MMDTrainer: @@ -153,22 +157,33 @@ def _train_epoch(self, epoch: int, total_epochs: int, warmup: bool = False) -> d loader, total=n_batches, desc=f"Epoch {epoch}/{total_epochs}", leave=False ): - x_src, y_src = x_src.to(self.device), y_src.to(self.device) - x_tgt, y_tgt = x_tgt.to(self.device), y_tgt.to(self.device) - - z_src = self.model.encode(x_src) - - logits = self.model.classify(z_src) - ce = self.ce_loss(logits, y_src) - - if warmup: - loss = ce - mmd_val = 0.0 + x_src, y_src, node_batch = unpack_batch(x_src, y_src, self.device) + x_tgt, 
y_tgt, _ = unpack_batch(x_tgt, y_tgt, self.device) + + if node_batch: + z_src = node_latent_vectors(self.model, x_src) + logits = self.model.classify(z_src) + ce = self.ce_loss(logits, y_src) + if warmup: + loss = ce + mmd_val = 0.0 + else: + z_tgt = node_latent_vectors(self.model, x_tgt) + mmd = self.mmd_loss(z_src, z_tgt) + loss = ce + self.mmd_weight * mmd + mmd_val = mmd.item() else: - z_tgt = self.model.encode(x_tgt) - mmd = self.mmd_loss(z_src, z_tgt) - loss = ce + self.mmd_weight * mmd - mmd_val = mmd.item() + z_src = self.model.encode(x_src) + logits = self.model.classify(z_src) + ce = self.ce_loss(logits, y_src) + if warmup: + loss = ce + mmd_val = 0.0 + else: + z_tgt = self.model.encode(x_tgt) + mmd = self.mmd_loss(z_src, z_tgt) + loss = ce + self.mmd_weight * mmd + mmd_val = mmd.item() self.optimizer.zero_grad() loss.backward() @@ -180,8 +195,7 @@ def _train_epoch(self, epoch: int, total_epochs: int, warmup: bool = False) -> d src_correct += (logits.argmax(1) == y_src).sum().item() n_src += y_src.size(0) - # target accuracy (no grad, reuses current weights) - tgt_correct += _batch_accuracy(self.model, x_tgt, y_tgt) + tgt_correct += batch_accuracy(self.model, x_tgt, y_tgt) n_tgt += y_tgt.size(0) return { @@ -199,9 +213,13 @@ def evaluate(self, loader: DataLoader, domain: str = "source") -> dict: self.model.eval() correct = total = 0 for x, y in loader: - x, y = x.to(self.device), y.to(self.device) - correct += (self.model(x).argmax(1) == y).sum().item() - total += y.size(0) + if is_node_graph_batch(x): + x = move_node_graph_batch(x, self.device) + correct += node_classification_correct(self.model, x) + else: + x, y = x.to(self.device), y.to(self.device) + correct += (self.model(x).argmax(1) == y).sum().item() + total += y.size(0) return {"domain": domain, "accuracy": correct / total, "n_samples": total} @@ -270,11 +288,14 @@ def _train_epoch(self, epoch: int, total_epochs: int) -> dict: loader, total=n_batches, desc=f"Epoch {epoch}/{total_epochs}", 
leave=False ): - x_src, y_src = x_src.to(self.device), y_src.to(self.device) - x_tgt, y_tgt = x_tgt.to(self.device), y_tgt.to(self.device) + x_src, y_src, node_batch = unpack_batch(x_src, y_src, self.device) + x_tgt, y_tgt, _ = unpack_batch(x_tgt, y_tgt, self.device) - logits = self.model(x_src) - loss = self.ce_loss(logits, y_src) + if node_batch: + logits = node_classify_logits(self.model, x_src) + else: + logits = self.model(x_src) + loss = self.ce_loss(logits, y_src) self.optimizer.zero_grad() loss.backward() @@ -284,7 +305,7 @@ def _train_epoch(self, epoch: int, total_epochs: int) -> dict: src_correct += (logits.argmax(1) == y_src).sum().item() n_src += y_src.size(0) - tgt_correct += _batch_accuracy(self.model, x_tgt, y_tgt) + tgt_correct += batch_accuracy(self.model, x_tgt, y_tgt) n_tgt += y_tgt.size(0) return { @@ -302,7 +323,11 @@ def evaluate(self, loader: DataLoader, domain: str = "source") -> dict: self.model.eval() correct = total = 0 for x, y in loader: - x, y = x.to(self.device), y.to(self.device) - correct += (self.model(x).argmax(1) == y).sum().item() - total += y.size(0) + if is_node_graph_batch(x): + x = move_node_graph_batch(x, self.device) + correct += node_classification_correct(self.model, x) + else: + x, y = x.to(self.device), y.to(self.device) + correct += (self.model(x).argmax(1) == y).sum().item() + total += y.size(0) return {"domain": domain, "accuracy": correct / total, "n_samples": total} diff --git a/shiftkit/methods/node_batch.py b/shiftkit/methods/node_batch.py new file mode 100644 index 0000000..b558a60 --- /dev/null +++ b/shiftkit/methods/node_batch.py @@ -0,0 +1,64 @@ +""" +Shared helpers for node-level PyG batches (single graph per domain). 
+"""
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+try:
+    from shiftkit.data.pyg_utils import (
+        NodeGraphBatch,
+        is_node_graph_batch,
+        move_node_graph_batch,
+    )
+except ImportError:
+    # torch-geometric is an optional dependency and pyg_utils raises
+    # ImportError without it. The tensor-only trainers import this module
+    # unconditionally, so degrade to stubs instead of failing at import time.
+    NodeGraphBatch = None
+
+    def is_node_graph_batch(x) -> bool:
+        """Return ``False``: no node-graph batch can exist without torch-geometric."""
+        return False
+
+    def move_node_graph_batch(batch, device):
+        """Fail loudly: moving a node-graph batch needs torch-geometric."""
+        raise ImportError(
+            "Node-level PyG batches require torch-geometric. "
+            "Install it with: pip install torch-geometric"
+        )
+
+
+def unpack_batch(
+    x, y: torch.Tensor, device: torch.device
+) -> Tuple[Union[torch.Tensor, NodeGraphBatch], torch.Tensor, bool]:
+    """
+    Return ``(x, y, is_node_batch)``; move tensors/batches to *device*.
+
+    For a ``NodeGraphBatch`` the labels returned are the batch's own ``y``
+    (the *y* argument is ignored); otherwise both *x* and *y* are moved
+    with ``.to(device)``.
+    """
+    if is_node_graph_batch(x):
+        batch = move_node_graph_batch(x, device)
+        return batch, batch.y, True
+    return x.to(device), y.to(device), False
+
+
+@torch.no_grad()
+def batch_accuracy(model: nn.Module, x, y: torch.Tensor) -> int:
+    """
+    Return number of correct predictions for a single batch (no grad).
+
+    Node-graph batches are moved to the device of the model's first
+    parameter and scored against their own labels; plain inputs are scored
+    as ``model(x).argmax(1) == y``.
+    """
+    if is_node_graph_batch(x):
+        x = move_node_graph_batch(x, next(model.parameters()).device)
+        return node_classification_correct(model, x)
+    return (model(x).argmax(1) == y).sum().item()
+
+
+def node_latent_vectors(model: nn.Module, batch: NodeGraphBatch) -> torch.Tensor:
+    """
+    Encode the full graph and keep only the rows selected by ``batch.node_idx``.
+
+    Assumes ``model.encode`` returns one row per node (e.g. ``pool='none'``)
+    -- TODO confirm for other models.
+    """
+    z = model.encode(batch.graph)
+    return z[batch.node_idx]
+
+
+def node_classify_logits(model: nn.Module, batch: NodeGraphBatch) -> torch.Tensor:
+    """Return ``model.classify`` outputs for the nodes in ``batch.node_idx``."""
+    return model.classify(node_latent_vectors(model, batch))
+
+
+def node_regress_preds(model: nn.Module, batch: NodeGraphBatch) -> torch.Tensor:
+    """
+    Return ``model.regress`` outputs for the nodes in ``batch.node_idx``.
+
+    A trailing singleton column ``(N, 1)`` is flattened to ``(N,)``; any
+    other shape is returned unchanged.
+    """
+    preds = model.regress(node_latent_vectors(model, batch))
+    # Only flatten the (N, 1) case. Squeezing a 1-D result of length 1 would
+    # give a 0-dim tensor that no longer lines up with the (N,) label vector
+    # and cannot be concatenated by the trainers' evaluate() loops.
+    if preds.dim() > 1 and preds.size(-1) == 1:
+        return preds.view(-1)
+    return preds
+
+
+@torch.no_grad()
+def node_classification_correct(model: nn.Module, batch: NodeGraphBatch) -> int:
+    """Return how many selected nodes are classified correctly (no grad)."""
+    logits = node_classify_logits(model, batch)
+    return (logits.argmax(1) == batch.y).sum().item()
+
+
+__all__ = [
+    "NodeGraphBatch",
+    "is_node_graph_batch",
+    "move_node_graph_batch",
+    "unpack_batch",
+    "batch_accuracy",
+    "node_classify_logits",
+    "node_regress_preds",
+    "node_latent_vectors",
+    "node_classification_correct",
+]
diff --git a/shiftkit/methods/regression.py b/shiftkit/methods/regression.py
index ce5015f..7ff6c5d 100644
--- 
a/shiftkit/methods/regression.py +++ b/shiftkit/methods/regression.py @@ -25,6 +25,13 @@ from typing import Optional, List from .mmd import MMDLoss, _auto_device +from .node_batch import ( + is_node_graph_batch, + move_node_graph_batch, + unpack_batch, + node_latent_vectors, + node_regress_preds, +) # ─── Source-Only Regression Baseline ───────────────────────────────────────── @@ -85,10 +92,13 @@ def _train_epoch(self, epoch: int, total_epochs: int) -> dict: zip(self.source_loader, self.target_loader), total=n_batches, desc=f"Epoch {epoch}/{total_epochs}", leave=False ): - x_src, y_src = x_src.to(self.device), y_src.to(self.device) - x_tgt, y_tgt = x_tgt.to(self.device), y_tgt.to(self.device) + x_src, y_src, node_batch = unpack_batch(x_src, y_src, self.device) + x_tgt, y_tgt, _ = unpack_batch(x_tgt, y_tgt, self.device) - pred_src = self.model(x_src) + if node_batch: + pred_src = node_regress_preds(self.model, x_src) + else: + pred_src = self.model(x_src).view_as(y_src) loss = self.mse_loss(pred_src, y_src) self.optimizer.zero_grad() @@ -100,7 +110,10 @@ def _train_epoch(self, epoch: int, total_epochs: int) -> dict: n_src += y_src.size(0) with torch.no_grad(): - pred_tgt = self.model(x_tgt) + if is_node_graph_batch(x_tgt): + pred_tgt = node_regress_preds(self.model, x_tgt) + else: + pred_tgt = self.model(x_tgt).view_as(y_tgt) tgt_se += ((pred_tgt - y_tgt) ** 2).sum().item() n_tgt += y_tgt.size(0) @@ -119,9 +132,14 @@ def evaluate(self, loader: DataLoader, domain: str = "source") -> dict: self.model.eval() ys, preds = [], [] for x, y in loader: - x, y = x.to(self.device), y.to(self.device) - preds.append(self.model(x)) - ys.append(y) + if is_node_graph_batch(x): + x = move_node_graph_batch(x, self.device) + preds.append(node_regress_preds(self.model, x)) + ys.append(x.y) + else: + x, y = x.to(self.device), y.to(self.device) + preds.append(self.model(x).view_as(y)) + ys.append(y) ys = torch.cat(ys) preds = torch.cat(preds) mse = ((preds - ys) ** 2).mean().item() @@ 
-211,18 +229,25 @@ def _train_epoch(self, epoch: int, total_epochs: int, warmup: bool = False) -> d zip(self.source_loader, self.target_loader), total=n_batches, desc=f"Epoch {epoch}/{total_epochs}", leave=False ): - x_src, y_src = x_src.to(self.device), y_src.to(self.device) - x_tgt, y_tgt = x_tgt.to(self.device), y_tgt.to(self.device) + x_src, y_src, node_batch = unpack_batch(x_src, y_src, self.device) + x_tgt, y_tgt, _ = unpack_batch(x_tgt, y_tgt, self.device) - z_src = self.model.encode(x_src) - pred_src = self.model.regress(z_src) - mse = self.mse_loss(pred_src, y_src) + if node_batch: + z_src = node_latent_vectors(self.model, x_src) + pred_src = self.model.regress(z_src).view_as(y_src) + else: + z_src = self.model.encode(x_src) + pred_src = self.model.regress(z_src).view_as(y_src) + mse = self.mse_loss(pred_src, y_src) if warmup: loss = mse mmd_val = 0.0 else: - z_tgt = self.model.encode(x_tgt) + if node_batch: + z_tgt = node_latent_vectors(self.model, x_tgt) + else: + z_tgt = self.model.encode(x_tgt) mmd = self.mmd_loss(z_src, z_tgt) loss = mse + self.mmd_weight * mmd mmd_val = mmd.item() @@ -238,7 +263,10 @@ def _train_epoch(self, epoch: int, total_epochs: int, warmup: bool = False) -> d n_src += y_src.size(0) with torch.no_grad(): - pred_tgt = self.model.regress(self.model.encode(x_tgt)) + if is_node_graph_batch(x_tgt): + pred_tgt = node_regress_preds(self.model, x_tgt) + else: + pred_tgt = self.model.regress(self.model.encode(x_tgt)).view_as(y_tgt) tgt_se += ((pred_tgt - y_tgt) ** 2).sum().item() n_tgt += y_tgt.size(0) @@ -257,9 +285,14 @@ def evaluate(self, loader: DataLoader, domain: str = "source") -> dict: self.model.eval() ys, preds = [], [] for x, y in loader: - x, y = x.to(self.device), y.to(self.device) - preds.append(self.model(x)) - ys.append(y) + if is_node_graph_batch(x): + x = move_node_graph_batch(x, self.device) + preds.append(node_regress_preds(self.model, x)) + ys.append(x.y) + else: + x, y = x.to(self.device), y.to(self.device) + 
preds.append(self.model(x).view_as(y)) + ys.append(y) ys = torch.cat(ys) preds = torch.cat(preds) mse = ((preds - ys) ** 2).mean().item() From c93179f3b018a9dbb6cded8b62d2a0c4c1293a6f Mon Sep 17 00:00:00 2001 From: Abdelrahman Helal Date: Tue, 2 Jun 2026 16:15:05 -0700 Subject: [PATCH 3/4] Add example for using the PyG data and model --- docs/data.md | 59 ++++++++++++++++++++++++++ examples/pyg_node_mmd.py | 92 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 examples/pyg_node_mmd.py diff --git a/docs/data.md b/docs/data.md index 8c66f90..5397025 100644 --- a/docs/data.md +++ b/docs/data.md @@ -163,3 +163,62 @@ test_src, test_tgt = dm.load("synthetic_graphs", train=False) |-----|--------|--------|-------------| | `"mnist_noisy_mnist"` | `torchvision.MNIST` | `NoisyMNIST` | `noise_std` (default `0.3`) | | `"synthetic_graphs"` | `SyntheticGraphDataset` (noise=0.1) | `SyntheticGraphDataset` (noise=0.5, flip=0.05) | `n_graphs`, `n_nodes`, `feat_dim`, `feature_noise_src`, `feature_noise_tgt`, `edge_flip_prob` | +| `"pyg_domains"` | User-supplied PyG `Data` | User-supplied PyG `Data` | `source`, `target`, `task_level`, `train_ratio`, `val_ratio`, `split_seed`, `split_mode` | + +--- + +## PyG domains (`pyg_domains`) + +Requires `torch-geometric`. Loads external PyG graphs for domain adaptation with automatic **stratified** train/val/test splits when masks are not already on the `Data` objects. + +### Node-level (one graph per domain) + +Use when labels live on **nodes** (transductive node classification). Each domain is a single `torch_geometric.data.Data` object. Training runs message passing on the full graph; loss and MMD use **train** nodes only; evaluation uses **test** nodes. 
+ +```python +from shiftkit.data import DataManager +from shiftkit.models import GNN +from shiftkit.methods import MMDTrainer + +dm = DataManager(batch_size=1, num_workers=0) +train_src, train_tgt = dm.load( + "pyg_domains", + train=True, + task_level="node", + source=source_graph, + target=target_graph, + train_ratio=0.6, + val_ratio=0.2, + split_seed=42, + split_mode="stratified", +) +test_src, test_tgt = dm.load("pyg_domains", train=False, ...) + +model = GNN(source_graph, "SAGE", hidden_channels=64, num_layers=2, + num_classes=10, pool="none") +trainer = MMDTrainer(model, train_src, train_tgt, mmd_weight=1.0) +``` + +Pair with `shiftkit.models.GNN(..., pool="none")` so `encode()` returns per-node embeddings. + +### Graph-level (many graphs per domain) + +Pass a **list** of `Data` objects per domain. Splits are by graph index (stratified on graph labels when discrete). Use default `pool="mean"` on `GNN`. + +```python +train_src, train_tgt = dm.load( + "pyg_domains", + train=True, + task_level="graph", + source=list_of_src_graphs, + target=list_of_tgt_graphs, + train_ratio=0.6, + val_ratio=0.2, +) +``` + +### Existing masks + +If `data.train_mask` is already set, automatic splitting is skipped. Loaders use `train_mask` for `train=True` and `test_mask` for `train=False`. + +See `examples/pyg_node_mmd.py` for a full node-level example. diff --git a/examples/pyg_node_mmd.py b/examples/pyg_node_mmd.py new file mode 100644 index 0000000..3ed025e --- /dev/null +++ b/examples/pyg_node_mmd.py @@ -0,0 +1,92 @@ +""" +Example: node-level domain adaptation on two PyG graphs (one per domain). + +Uses DataManager.load("pyg_domains") with stratified node masks and +shiftkit.models.GNN with pool="none". 
+ +Run from repo root: + python examples/pyg_node_mmd.py +""" + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from torch_geometric.data import Data + +from shiftkit.data import DataManager +from shiftkit.models import GNN +from shiftkit.methods import MMDTrainer, SourceOnlyTrainer + + +def make_domain_graph(n_nodes: int, feat_dim: int, n_classes: int, seed: int, shift: float = 0.0) -> Data: + torch.manual_seed(seed) + x = torch.randn(n_nodes, feat_dim) + shift + row = torch.arange(n_nodes - 1) + edge_index = torch.stack([row, row + 1], dim=0) + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) + y = torch.randint(0, n_classes, (n_nodes,)) + return Data(x=x, edge_index=edge_index, y=y) + + +if __name__ == "__main__": + N_NODES = 300 + FEAT = 8 + NUM_CLASSES = 3 + EPOCHS = 30 + + source_graph = make_domain_graph(N_NODES, FEAT, NUM_CLASSES, seed=0, shift=0.0) + target_graph = make_domain_graph(N_NODES, FEAT, NUM_CLASSES, seed=1, shift=1.5) + + dm = DataManager(batch_size=1, num_workers=0) + train_src, train_tgt = dm.load( + "pyg_domains", + train=True, + task_level="node", + source=source_graph, + target=target_graph, + train_ratio=0.6, + val_ratio=0.2, + split_seed=42, + split_mode="stratified", + ) + test_src, test_tgt = dm.load( + "pyg_domains", + train=False, + task_level="node", + source=source_graph, + target=target_graph, + train_ratio=0.6, + val_ratio=0.2, + split_seed=42, + split_mode="stratified", + ) + + model_so = GNN( + source_graph, "SAGE", hidden_channels=32, num_layers=2, + num_classes=NUM_CLASSES, pool="none", + ) + model_mmd = GNN( + source_graph, "SAGE", hidden_channels=32, num_layers=2, + num_classes=NUM_CLASSES, pool="none", + ) + + print("Training Source-Only...") + so = SourceOnlyTrainer(model_so, train_src, train_tgt, lr=1e-3, device="cpu") + so.fit(epochs=EPOCHS) + + print("Training MMD...") + mmd = MMDTrainer(model_mmd, train_src, train_tgt, mmd_weight=0.5, lr=1e-3, 
device="cpu") + mmd.fit(epochs=EPOCHS) + + for name, trainer in [("Source-Only", so), ("MMD", mmd)]: + r_src = trainer.evaluate(test_src, domain="source-test") + r_tgt = trainer.evaluate(test_tgt, domain="target-test") + print( + f"{name:12s} src acc={r_src['accuracy']*100:.1f}% " + f"tgt acc={r_tgt['accuracy']*100:.1f}%" + ) + + print("Done.") From 99d486ef0d824ff11d94cadadfe5505dd1065297 Mon Sep 17 00:00:00 2001 From: Abdelrahman Helal Date: Tue, 2 Jun 2026 16:15:53 -0700 Subject: [PATCH 4/4] Add domain-only colouring for regression tasks --- shiftkit/__init__.py | 7 +- shiftkit/diagnostics/__init__.py | 2 + shiftkit/diagnostics/plots.py | 106 +++++++++++++++++++++++++++++-- 3 files changed, 108 insertions(+), 7 deletions(-) diff --git a/shiftkit/__init__.py b/shiftkit/__init__.py index bd5e7d4..3b24c4b 100644 --- a/shiftkit/__init__.py +++ b/shiftkit/__init__.py @@ -26,8 +26,8 @@ from .methods.kliep import KLIEPWeightEstimator, KLIEPTrainer from .data.datasets import SineWaveDataset, CaliforniaHousingDataset from .diagnostics.plots import ( - plot_latent_space, plot_training_history, compare_latent_spaces, - plot_confusion_matrix, plot_roc_curve, + plot_latent_space, plot_latent_space_domains, plot_training_history, + compare_latent_spaces, plot_confusion_matrix, plot_roc_curve, ) __version__ = "0.1.0" @@ -43,7 +43,8 @@ "SIDDATrainer", "SourceOnlyRegressionTrainer", "MMDRegressionTrainer", "KLIEPWeightEstimator", "KLIEPTrainer", - "plot_latent_space", "plot_training_history", "compare_latent_spaces", + "plot_latent_space", "plot_latent_space_domains", "plot_training_history", + "compare_latent_spaces", "plot_confusion_matrix", "plot_roc_curve", ] if GNN is not None: diff --git a/shiftkit/diagnostics/__init__.py b/shiftkit/diagnostics/__init__.py index 9c7cc18..2ee5b36 100644 --- a/shiftkit/diagnostics/__init__.py +++ b/shiftkit/diagnostics/__init__.py @@ -1,5 +1,6 @@ from .plots import ( plot_latent_space, + plot_latent_space_domains, plot_training_history, 
compare_latent_spaces, plot_confusion_matrix, @@ -8,6 +9,7 @@ __all__ = [ "plot_latent_space", + "plot_latent_space_domains", "plot_training_history", "compare_latent_spaces", "plot_confusion_matrix", diff --git a/shiftkit/diagnostics/plots.py b/shiftkit/diagnostics/plots.py index 1580f96..d7c1417 100644 --- a/shiftkit/diagnostics/plots.py +++ b/shiftkit/diagnostics/plots.py @@ -106,10 +106,14 @@ def _run_projection( if method == "tsne": from sklearn.manifold import TSNE - reducer = TSNE( - n_components=2, perplexity=perplexity, n_iter=n_iter, + tsne_kw = dict( + n_components=2, perplexity=perplexity, random_state=42, init="pca", learning_rate="auto", ) + try: + reducer = TSNE(**tsne_kw, n_iter=n_iter) + except TypeError: + reducer = TSNE(**tsne_kw, max_iter=n_iter) elif method == "isomap": from sklearn.manifold import Isomap reducer = Isomap(n_components=2, n_neighbors=n_neighbors) @@ -129,9 +133,13 @@ def _run_projection( return reducer.fit_transform(z_all), domain_labels -def _draw_domain_panel(ax, z2d, domain_labels, title): +def _draw_domain_panel( + ax, z2d, domain_labels, title, + domain_names: Optional[tuple] = None, +): palette = ["#4C72B0", "#DD8452"] - for d, (label, color) in enumerate(zip(["Source", "Target"], palette)): + names = domain_names if domain_names is not None else ("Source", "Target") + for d, (label, color) in enumerate(zip(names, palette)): mask = domain_labels == d ax.scatter(z2d[mask, 0], z2d[mask, 1], c=color, label=label, s=8, alpha=0.6, linewidths=0) @@ -140,6 +148,34 @@ def _draw_domain_panel(ax, z2d, domain_labels, title): ax.set_xticks([]); ax.set_yticks([]) +@torch.no_grad() +def _collect_node_embeddings( + model: nn.Module, + data, + mask_attr: str, + device: torch.device, + max_samples: int, +) -> np.ndarray: + """ + Encode a single PyG graph and return latent vectors for masked nodes. + + For node-level models (``pool='none'``), ``encode`` returns shape ``(N, D)``. 
+ """ + mask = getattr(data, mask_attr, None) + if mask is None: + raise AttributeError(f"Graph has no mask attribute '{mask_attr}'") + idx = mask.nonzero(as_tuple=False).view(-1) + if idx.numel() == 0: + raise ValueError(f"No nodes selected by {mask_attr}") + + model.eval() + z = model.encode(data.to(device)) + if idx.numel() > max_samples: + perm = torch.randperm(idx.numel(), device=idx.device)[:max_samples] + idx = idx[perm] + return z[idx].cpu().numpy() + + def _draw_class_panel(ax, z2d, y_src, y_tgt, title, class_names): class_labels = np.concatenate([y_src, y_tgt]) unique = sorted(np.unique(class_labels)) @@ -212,6 +248,68 @@ def plot_latent_space( return fig +def plot_latent_space_domains( + model: nn.Module, + source_graph, + target_graph, + max_samples_per_domain: int = 2000, + node_mask: str = "test_mask", + projection: str = "tsne", + perplexity: float = 30.0, + n_iter: int = 1000, + n_neighbors: int = 15, + min_dist: float = 0.1, + domain_names: tuple = ("Source", "Target"), + title: str = "Latent space by domain", + save_path: Optional[str] = None, + show: bool = True, +) -> plt.Figure: + """ + Plot a single 2-D latent projection coloured by domain only. + + Intended for node-level PyG graphs (one graph per domain), e.g. FIREbox vs + TNG300. Pass ``domain_names`` to label the legend (default Source / Target). + + Parameters + ---------- + source_graph, target_graph : PyG ``Data`` objects (one graph per domain) + max_samples_per_domain : cap on nodes sampled per graph + node_mask : mask attribute on ``Data`` (e.g. 
``test_mask``) + projection : ``tsne``, ``isomap``, or ``umap`` + domain_names : legend labels for domain 0 and 1 + """ + device = _device_of(model) + print(f"Collecting {domain_names[0]} node embeddings …") + z_src = _collect_node_embeddings( + model, source_graph, node_mask, device, max_samples_per_domain + ) + print(f"Collecting {domain_names[1]} node embeddings …") + z_tgt = _collect_node_embeddings( + model, target_graph, node_mask, device, max_samples_per_domain + ) + + z2d, domain_labels = _run_projection( + z_src, z_tgt, projection, perplexity, n_iter, n_neighbors, min_dist + ) + xlabel, ylabel = _PROJ_AXIS_LABELS[projection.lower()] + + fig, ax = plt.subplots(1, 1, figsize=(7, 6)) + fig.suptitle(title, fontsize=13, fontweight="bold") + _draw_domain_panel(ax, z2d, domain_labels, "By domain", domain_names=domain_names) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + plt.tight_layout() + if save_path: + from pathlib import Path + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f"Figure saved to {save_path}") + if show: + plt.show() + return fig + + def compare_latent_spaces( models: dict, source_loader: DataLoader,