diff --git a/README.md b/README.md
index d6bad67..b9d3a40 100644
--- a/README.md
+++ b/README.md
@@ -1,320 +1,121 @@
-# PyGIP
-PyGIP is a Python library designed for experimenting with graph-based model extraction attacks and defenses. It provides
-a modular framework to implement and test attack and defense strategies on graph datasets.
+**QueryBasedVerification** is a defense module implemented under the **PyGIP** framework that replicates the core defense proposed in the paper _"Securing Graph Neural Networks in MLaaS: A Comprehensive Realization of Query-based Integrity Verification"_ (Wu et al., 2023).
-## Installation
-To get started with PyGIP, set up your environment by installing the required dependencies:
+## Experimental Parameters
-```bash
-pip install -r reqs.txt
-```
+#### Common Parameters
+| Parameter              | Value Used                               | Paper Value                                | Notes                                                            |
+| ---------------------- | ---------------------------------------- | ------------------------------------------ | ---------------------------------------------------------------- |
+| `attack_node_fraction` | `0.1`                                    | `0.3`                                      | Lowered to reduce impact and runtime                             |
+| `k` (num fingerprints) | `5`                                      | `10`                                       | Halved to reduce query overhead while maintaining effectiveness  |
+| `attack_trial_map`     | `bitflip: 20`, `random: 5`, `mettack: 5` | Paper uses 400 trials for BFA              | Reduced for faster experimentation                               |
+| `bit_position`         | `30`                                     | Unspecified, but paper flips exponent bits | Matches intent of BFA attack                                     |
-Ensure you have Python installed (version 3.8 or higher recommended) along with the necessary libraries listed
-in `reqs.txt`.
+#### Bit Flip Attack (BFA) Specific Parameters
-Specifically, using following command to install `dgl 2.2.1` and ensure your `pytorch==2.3.0`.
+| Parameter     | Value Used              | Paper Value | Notes                          |
+| ------------- | ----------------------- | ----------- | ------------------------------ |
+| `num_trials`  | `20`                    | `400`       | Downsampled to speed up runs   |
-```shell
-pip install dgl==2.2.1 -f https://data.dgl.ai/wheels/torch-2.3/repo.html
-```
+#### Random Poisoning Attack
-## Quick Start
+| Parameter     | Value Used | Paper Value               | Notes                      |
+| ------------- | ---------- | ------------------------- | -------------------------- |
+| `num_trials`  | `5`        | Not directly specified    | Chosen for time-efficiency |
-Here’s a simple example to launch a model extraction attack using PyGIP:
+#### Mettack Poisoning Attack
-```python
-from datasets import Cora
-from models.attack import ModelExtractionAttack0
+| Parameter               | Value Used | Paper Value                    | Notes                                           |
+| ----------------------- | ---------- | ------------------------------ | ----------------------------------------------- |
+| `poison_frac`           | `0.005`    | ~0.01 (for 100 perturbations)  | Halved to ~50 perturbations for faster runtime  |
+| `epochs`                | `30`       | `200`                          | Reduced to speed up training                    |
+| `surrogate_epochs`      | `20`       | `200`                          | Reduced for surrogate model efficiency          |
+| `candidate_sample_size` | `50`       | `100` (default)                | Smaller pool for runtime reasons                |
-# Load the Cora dataset
-dataset = Cora()
-# Initialize the attack with a sampling ratio of 0.25
-mea = ModelExtractionAttack0(dataset, 0.25)
+## Results
-# Execute the attack
-mea.attack()
-```
+**Cora Dataset**
-This code loads the Cora dataset, initializes a basic model extraction attack (`ModelExtractionAttack0`), and runs the
-attack with a specified sampling ratio.
-Here’s an expanded and detailed version of the "Contribute to Code" section for your README.md, incorporating the
-specifics of `BaseAttack` and `Dataset` you provided. 
This version is thorough, clear, and tailored for contributors: +### Transductive-F Detection Rate Comparison +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.69 | 0.711 | +| BFA-F | 0.67 | 0.96 | +| BFA-L | 0.7 | 0.5 | +| random | 0.52 | 0.647 | +| mettack | 0.84 | 0.588 | -## Implementation +### Transductive-L Detection Rate Comparison -PyGIP is built to be modular and extensible, allowing contributors to implement their own attack and defense strategies. -Below, we detail how to extend the framework by implementing custom attack and defense classes, with a focus on how to -leverage the provided dataset structure. +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.63 | 0.982 | +| BFA-F | 0.77 | 0.81 | +| BFA-L | 0.74 | 1.0 | +| random | 0.72 | 0.353 | +| mettack | 0.88 | 0.598 | -### Implementing Attack +### Inductive-F Detection Rate Comparison -To create a custom attack, you need to extend the abstract base class `BaseAttack`. Here’s the structure -of `BaseAttack`: +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.58 | 0.667 | +| BFA-F | 0.7 | 1.0 | +| BFA-L | 0.66 | 0.382 | +| random | 0.68 | 1.0 | +| mettack | 1.0 | 1.0 | -```python -class BaseAttack(ABC): - def __init__(self, dataset: Dataset, attack_node_fraction: float, model_path: str = None): - """Base class for all attack implementations.""" - self.dataset = dataset - self.graph = dataset.graph # Access the DGL-based graph directly - # Additional initialization can go here +### Inductive-L Detection Rate Comparison - @abstractmethod - def attack(self): - raise NotImplementedError +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.73 | 0.688 | +| BFA-F | 0.62 | 0.989 | +| BFA-L | 0.63 | 0.348 | +| random | 0.44 | 1.0 | +| mettack | 0.68 | 1.0 | - def _train_target_model(self): - raise NotImplementedError - def _train_attack_model(self): - raise NotImplementedError +**Citeseer Dataset** - def _load_model(self, model_path): - raise NotImplementedError -``` +### Transductive-F Detection Rate Comparison -To implement your own attack: +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.63 | 0.586 | +| BFA-F | 0.70 | 0.430 | +| BFA-L | 0.56 | 0.529 | +| random | 0.60 | 0.412 | +| mettack | 0.68 | 0.353 | -1. **Inherit from `BaseAttack`**: - Create a new class that inherits from `BaseAttack`. You’ll need to provide the following required parameters in the - constructor: +### Transductive-L Detection Rate Comparison -- `dataset`: An instance of the `Dataset` class (see below for details). -- `attack_node_fraction`: A float between 0 and 1 representing the fraction of nodes to attack. -- `model_path` (optional): A string specifying the path to a pre-trained model (defaults to `None`). - -You need to implement following methods: - -- `attack()`: Add main attack logic here. If multiple attack types are supported, define the attack type as an optional - argument to this function. - For each specific attack type, implement a corresponding helper function such as `_attack_type1()` - or `_attack_type2()`, - and call the appropriate helper inside `attack()` based on the given method name. -- `_load_model()`: Load victim model. -- `_train_target_model()`: Train victim model. 
-- `_train_attack_model()`: Train attack model. -- `_helper_func()`(optional): Add your helper functions based on your needs, but keep the methods private. - -2. **Implement the `attack()` Method**: - Override the abstract `attack()` method with your attack logic, and return a dict of results. For example: - -```python -class MyCustomAttack(BaseAttack): - def __init__(self, dataset: Dataset, attack_node_fraction: float, model_path: str = None): - super().__init__(dataset, attack_node_fraction, model_path) - # Additional initialization if needed - - def attack(self): - # Example: Access the graph and perform an attack - print(f"Attacking {self.attack_node_fraction * 100}% of nodes") - num_nodes = self.graph.num_nodes() - print(f"Graph has {num_nodes} nodes") - # Add your attack logic here - return { - 'metric1': 'metric1 here', - 'metric2': 'metric2 here' - } - - def _load_model(self): - # add your logic here - pass - - def _train_target_model(self): - # add your logic here - pass - - def _train_attack_model(self): - # add your logic here - pass -``` - -### Implementing Defense - -To create a custom defense, you need to extend the abstract base class `BaseDefense`. Here’s the structure -of `BaseDefense`: - -```python -class BaseDefense(ABC): - def __init__(self, dataset: Dataset, attack_node_fraction: float): - """Base class for all defense implementations.""" - # add initialization here - - @abstractmethod - def defend(self): - raise NotImplementedError - - def _load_model(self): - raise NotImplementedError - - def _train_target_model(self): - raise NotImplementedError - - def _train_defense_model(self): - raise NotImplementedError - - def _train_surrogate_model(self): - raise NotImplementedError -``` - -To implement your own defense: - -1. **Inherit from `BaseDefense`**: - Create a new class that inherits from `BaseDefense`. You’ll need to provide the following required parameters in the - constructor: - -- `dataset`: An instance of the `Dataset` class (see below for details). -- `attack_node_fraction`: A float between 0 and 1 representing the fraction of nodes to attack. -- `model_path` (optional): A string specifying the path to a pre-trained model (defaults to `None`). - -You need to implement following methods: - -- `defense()`: Add main defense logic here. If multiple defense types are supported, define the defense type as an - optional argument to this function. - For each specific defense type, implement a corresponding helper function such as `_defense_type1()` - or `_defense_type2()`, - and call the appropriate helper inside `defense()` based on the given method name. -- `_load_model()`: Load victim model. -- `_train_target_model()`: Train victim model. -- `_train_defense_model()`: Train defense model. -- `_train_surrogate_model()`: Train attack model. -- `_helper_func()`(optional): Add your helper functions based on your needs, but keep the methods private. - - -2. **Implement the `defense()` Method**: - Override the abstract `defense()` method with your defense logic, and return a dict of results. 
For example: - -```python -class MyCustomDefense(BaseDefense): - def defend(self): - # Step 1: Train target model - target_model = self._train_target_model() - # Step 2: Attack target model - attack = MyCustomAttack(self.dataset, attack_node_fraction=0.3) - attack.attack(target_model) - # Step 3: Train defense model - defense_model = self._train_defense_model() - # Step 4: Test defense against attack - attack = MyCustomAttack(self.dataset, attack_node_fraction=0.3) - attack.attack(defense_model) - # Print performance metrics - - def _load_model(self): - # add your logic here - pass - - def _train_target_model(self): - # add your logic here - pass - - def _train_defense_model(self): - # add your logic here - pass - - def _train_surrogate_model(self): - # add your logic here - pass -``` - -### Understanding the Dataset Class - -The `Dataset` class standardizes the data format across PyGIP. Here’s its structure: - -```python -class Dataset(object): - def __init__(self, api_type='pyg', path='./downloads/'): - self.api_type = api_type # Set to 'pyg' for torch_geometric-based graphs - self.path = path # Directory for dataset storage - self.dataset_name = "" # Name of the dataset (e.g., "Cora") - - # Graph properties - self.node_number = 0 # Number of nodes - self.feature_number = 0 # Number of features per node - self.label_number = 0 # Number of label classes - - # Core data - self.graph = None # PyG graph object - self.features = None # Node features - self.labels = None # Node labels - - # Data splits - self.train_mask = None # Boolean mask for training nodes - self.val_mask = None # Boolean mask for validation nodes - self.test_mask = None # Boolean mask for test nodes -``` - -- **Importance**: We are currently using the default api_type='pyg' to load the data. It is important to note that when - api_type='pyg', `self.graph` should be an instance of `torch_geometric.data.Data`. In your implementation, make sure to - use our defined Dataset class to build your code. -- Additional attributes like `self.dataset.features` (node features), `self.dataset.labels` (node labels), - and `self.dataset.train_mask` (training split) are also available if your logic requires them. - -### Miscellaneous Tips - -- **Reference Implementation**: The `ModelExtractionAttack0` class is a fully implemented attack example. Study it for - inspiration or as a template. -- **Flexibility**: Add as many helper functions as needed within your class to keep your code clean and modular. -- **Backbone Models**: We provide several basic backbone models like `GCN, GraphSAGE`. You can use or add more - at `from models.nn import GraphSAGE`. - -By following these guidelines, you can seamlessly integrate your custom attack or defense strategies into PyGIP. Happy -coding! - -## Internal Code Submission Guideline - -For internal team members with write access to the repository: - -1. Always Use Feature/Fix Branches - -- Never commit directly to the main or develop branch. -- Create a new branch for each feature, bug fix. - -```shell -git checkout -b feat/your-feature-name -``` - -```shell -git checkout -b fix/your-fix-name -``` - -2. Keep Commits Clean & Meaningful - -- feat: add data loader for graph dataset -- fix: resolve crash on edge cases - -Use clear commit messages following the format: - -```shell -: -``` - -3. Test Before Pushing - -- Test your implementation in `example.py`, and compare the performance with the results in original paper. - -4. 
Push to Internal Branch - -- Always run `git pull origin pygip-release` before pushing your changes -- Submit a pull request targeting the `pygip-release` branch -- Write a brief summary describing the features you’ve added, how to run your method, and how to evaluate its - performance - -Push to the remote feature branch. - -```shell -git push origin feat/your-feature-name -``` - -## External Pull Request Guideline - -Refer to [guidline](.github/CONTRIBUTING.md) - -## License - -MIT License - -## Contact - -For questions or contributions, please contact blshen@fsu.edu. +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.56 | 0.289 | +| BFA-F | 0.61 | 0.430 | +| BFA-L | 0.53 | 0.133 | +| random | 0.76 | 0.824 | +| mettack | 0.68 | 0.235 | + +### Inductive-F Detection Rate Comparison + +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.48 | 0.941 | +| BFA-F | 0.57 | 0.882 | +| BFA-L | 0.68 | 0.529 | +| random | 0.64 | 1.0 | +| mettack | 0.92 | 1.0 | + +### Inductive-L Detection Rate Comparison + +| Attack | Our Detection Rate | Paper Detection Rate | +|-----------|--------------------|----------------------| +| BFA | 0.59 | 0.901 | +| BFA-F | 0.53 | 0.852 | +| BFA-L | 0.67 | 0.569 | +| random | 0.72 | 1.0 | +| mettack | 0.92 | 1.0 | diff --git a/models/defense/QueryBasedVerification.py b/models/defense/QueryBasedVerification.py new file mode 100644 index 0000000..51bc7a1 --- /dev/null +++ b/models/defense/QueryBasedVerification.py @@ -0,0 +1,1233 @@ +from .base import BaseDefense +import torch +import torch.nn.functional as F +from torch.optim import Adam +from models.nn import GCN +import numpy as np +import random +from collections import Counter +from torch_geometric.utils import to_networkx, from_networkx, to_undirected +import networkx as nx +import copy +import torch.optim as optim +import dgl +from itertools import combinations +from tqdm import tqdm + + +class QueryBasedVerificationDefense(BaseDefense): + supported_api_types = {"dgl"} + supported_datasets = {} + def __init__(self, dataset, attack_node_fraction, model_path=None): + super().__init__(dataset, attack_node_fraction) + self.model_path = model_path + + + + def defend(self, fingerprint_mode='inductive', knowledge='full', attack_type='bitflip', + k=5, num_trials=10, use_edge_perturbation=False, verbose=True, **kwargs): + + """ + Main defense routine. Generates fingerprints, runs attacks, and verifies integrity. + Returns a dict with per-trial and average metrics. 
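+
+        Example (illustrative sketch, not part of the paper's protocol; it
+        assumes a DGL-backed dataset loader such as the framework's Cora
+        dataset -- adapt the import to your setup):
+
+            from datasets import Cora
+
+            dataset = Cora()
+            defense = QueryBasedVerificationDefense(dataset, attack_node_fraction=0.1)
+            results = defense.defend(fingerprint_mode='inductive', knowledge='full',
+                                     attack_type='bitflip', k=5, num_trials=20)
+            print(results['average_detection_rate'])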
+ """ + trial_results = [] + for trial in range(num_trials): + if verbose: + print(f"\n=== Trial {trial+1}/{num_trials} ===") + + + model_clean = self._train_target_model() + acc_clean = self._evaluate_accuracy(model_clean, self.dataset) + + + fingerprints = self._generate_fingerprints(model_clean, mode=fingerprint_mode, knowledge=knowledge, k=k, + perturb_fingerprints=use_edge_perturbation, + perturb_budget=kwargs.get('perturb_budget', 5), + **kwargs) + + + bit = kwargs.pop('bit', 30) + bfa_variant = kwargs.pop('bfa_variant', 'BFA') + + poisoned_model, attack_info = self._run_attack( + model_clean, + attack_type=attack_type, + knowledge=knowledge, + bit=bit, + bfa_variant=bfa_variant, + **kwargs + ) + + poisoned_dataset = copy.deepcopy(self.dataset) + if 'graph' in attack_info: + poisoned_dataset.graph_data = attack_info['graph'] + acc_poisoned = self._evaluate_accuracy(poisoned_model, poisoned_dataset) + + + flipped_info = self._evaluate_fingerprints(poisoned_model, fingerprints) + + flip_rate = flipped_info['flip_rate'] + acc_drop = acc_clean - acc_poisoned + num_flipped = len(flipped_info['flipped']) + num_total = len(fingerprints) + detection_rate = num_flipped / num_total if num_total > 0 else 0.0 + + if verbose: + print(f"Clean Accuracy: {acc_clean:.4f}") + print(f"Poisoned Accuracy: {acc_poisoned:.4f}") + print(f"Accuracy Drop: {acc_drop:.4f}") + print(f"Flip Rate: {flip_rate:.4f}") + print(f"Detection Rate: {detection_rate:.4f}") + + + trial_results.append({ + 'flip_rate': flip_rate, + 'accuracy_drop': acc_drop, + 'detection_rate': detection_rate + }) + + + avg_flip_rate = sum(r['flip_rate'] for r in trial_results) / num_trials + avg_acc_drop = sum(r['accuracy_drop'] for r in trial_results) / num_trials + avg_detection_rate = sum(r['detection_rate'] for r in trial_results) / num_trials + + + + return { + 'trial_results': trial_results, + 'average_flip_rate': avg_flip_rate, + 'average_accuracy_drop': avg_acc_drop, + 'average_detection_rate': avg_detection_rate + } + + + def _get_features(self): + return self.graph_data.ndata['feat'] if hasattr(self.graph_data, 'ndata') else self.graph_data.x + + + def _train_target_model(self, epochs=200): + """ + Trains target GCN model according to protocol in + Wu et al. (2023), Section 6.1 for graph node classification. + + Returns + ------- + model : torch.nn.Module + The trained GCN model. 
+ """ + model = GCN( + feature_number=self.dataset.feature_number, + label_number=self.dataset.label_number + ).to(self.device) + print(f"Training target model on device: {self.device} ...") + + optimizer = Adam(model.parameters(), lr=0.02) + loss_fn = torch.nn.NLLLoss() + + features = self._get_features().to(self.device) + labels = self.dataset.labels.to(self.device) + train_mask = self.dataset.train_mask.to(self.device) + val_mask = getattr(self.dataset, "val_mask", None) + if val_mask is None: + val_mask = self.dataset.test_mask + val_mask = val_mask.to(self.device) + + for epoch in range(epochs): + model.train() + logits = model(self.graph_data.to(self.device), features) + log_probs = F.log_softmax(logits, dim=1) + loss = loss_fn(log_probs[train_mask], labels[train_mask]) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (epoch + 1) % 10 == 0 or epoch == 0: + model.eval() + with torch.no_grad(): + val_logits = model(self.graph_data.to(self.device), features) + val_log_probs = F.log_softmax(val_logits, dim=1) + val_pred = val_log_probs[val_mask].max(1)[1] + val_acc = (val_pred == labels[val_mask]).float().mean().item() + print(f"Epoch {epoch+1}: Loss={loss.item():.4f} | Val Acc={val_acc:.4f}") + + return model + + def _load_model(self, model_path): + model = GCN( + in_feats=self.dataset.feature_number, + hidden_feats=16, + out_feats=self.dataset.label_number + ) + model.load_state_dict(torch.load(model_path)) + return model + + + def _generate_fingerprints(self, model, mode='transductive', knowledge='full', k=5, **kwargs): + """ + Wrapper for fingerprint generation based on mode and knowledge level. + Returns: + List of fingerprints + """ + if mode == 'transductive': + generator = TransductiveFingerprintGenerator( + model=model, + dataset=self.dataset, + candidate_fraction=kwargs.get('candidate_fraction', 1.0), + random_seed=kwargs.get('random_seed', None), + device=self.device, + randomize=kwargs.get('randomize', True), + ) + fingerprints = generator.generate_fingerprints(k=k, method=knowledge) + + unified_fingerprints = [(self.graph_data, node_id, label) for (node_id, label) in fingerprints] + + elif mode == 'inductive': + generator = InductiveFingerprintGenerator( + model=model, + shadow_graph=self.dataset.graph_data, + knowledge=knowledge, + candidate_fraction=kwargs.get('candidate_fraction', 0.3), + num_fingerprints=k, + randomize=kwargs.get('randomize', True), + random_seed=kwargs.get('random_seed', None), + device=self.device, + perturb_fingerprints=kwargs.get('perturb_fingerprints', False), + perturb_budget=kwargs.get('perturb_budget', 5), + ) + fingerprints = generator.generate_fingerprints(method=knowledge) + if kwargs.get('perturb_fingerprints', False): + for i, (graph, node_idx, label) in enumerate(fingerprints): + generator.shadow_graph = graph + generator.greedy_edge_perturbation( + node_idx=node_idx, + perturb_budget=kwargs.get('perturb_budget', 5), + knowledge=knowledge + ) + fingerprints[i] = (generator.shadow_graph, node_idx, label) + + unified_fingerprints = fingerprints + + else: + raise ValueError("Unknown fingerprinting mode. Use 'transductive' or 'inductive'.") + + return unified_fingerprints + + def _evaluate_fingerprints(self, model, fingerprints): + """ + Checks if fingerprinted nodes have changed labels under the given model. + + Args: + model: The model to evaluate. + fingerprints: List of (graph, node_id, label) tuples. 
+ + Returns: + results: { + 'flipped': List[Tuple[node_id, old_label, new_label]], + 'flip_rate': float + } + """ + model.eval() + flipped = [] + + with torch.no_grad(): + for graph, node_id, expected_label in fingerprints: + x = graph.ndata['feat'] if hasattr(graph, 'ndata') else graph.x + logits = model(graph.to(self.device), x.to(self.device)) + pred = logits[node_id].argmax().item() + if pred != expected_label: + flipped.append((node_id, expected_label, pred)) + + return { + 'flipped': flipped, + 'flip_rate': len(flipped) / len(fingerprints) if fingerprints else 0.0 + } + + + def _run_attack(self, model, attack_type='mettack', knowledge='full', **kwargs): + """ + Run the specified attack on the model. + Returns: + poisoned_model: torch.nn.Module + metadata: dict with info about the attack + """ + if attack_type == 'bitflip': + bit = kwargs.get('bit', 30) + bfa_variant = kwargs.get('bfa_variant', 'BFA') + attacker = BitFlipAttack(model, attack_type=bfa_variant, bit=bit) + attack_info = attacker.apply() + return model, attack_info + + elif attack_type == 'random': + perturbed_graph = self._random_edge_addition_poisoning( + node_fraction=kwargs.get('node_fraction', 0.1), + edges_per_node=kwargs.get('edges_per_node', 5), + random_seed=kwargs.get('random_seed', None), + ) + poisoned_model = self._retrain_poisoned_model( + poisoned_graph=perturbed_graph, + epochs=kwargs.get('epochs', 200), + ) + return poisoned_model, {'type': 'random_poison', 'graph': perturbed_graph} + + elif attack_type == 'mettack': + num_edges = self.graph_data.num_edges() + poison_frac = kwargs.get('poison_frac', 0.05) + n_perturbations = int(poison_frac * num_edges) + + helper = MettackHelper( + graph=self.graph_data, + features=self._get_features(), + labels=self.dataset.labels, + train_mask=self.dataset.train_mask, + val_mask=getattr(self.dataset, 'val_mask', None), + test_mask=self.dataset.test_mask, + n_perturbations=n_perturbations, + device=self.device, + max_perturbations=kwargs.get('max_perturbations', 50), + surrogate_epochs=kwargs.get('surrogate_epochs', 30), + candidate_sample_size=kwargs.get('candidate_sample_size', 20), + ) + poisoned_graph, attack_metrics = helper.run() + poisoned_model = self._retrain_poisoned_model( + poisoned_graph=poisoned_graph, + epochs=kwargs.get('epochs', 200), + ) + return poisoned_model, {'type': 'mettack', 'metrics': attack_metrics, 'graph': poisoned_graph} + + else: + raise ValueError(f"Unsupported attack_type: {attack_type}") + + + def _random_edge_addition_poisoning(self, node_fraction=0.1, edges_per_node=5, random_seed=None): + """ + Poison a fraction of nodes by adding random edges. 
+
+        Args:
+            node_fraction: Fraction of nodes to poison (e.g., 0.1 = 10%)
+            edges_per_node: Number of random edges to add per poisoned node
+            random_seed: Optional seed
+
+        Returns:
+            poisoned_graph: DGLGraph
+        """
+        if random_seed is not None:
+            random.seed(random_seed)
+            torch.manual_seed(random_seed)
+
+        poisoned_graph = copy.deepcopy(self.graph_data)
+        num_nodes = poisoned_graph.num_nodes()
+        num_poisoned_nodes = int(node_fraction * num_nodes)
+        poisoned_nodes = random.sample(range(num_nodes), num_poisoned_nodes)
+
+        new_edges = []
+
+        for src in poisoned_nodes:
+            for _ in range(edges_per_node):
+                dst = random.randint(0, num_nodes - 1)
+                if src != dst and \
+                        not poisoned_graph.has_edges_between(src, dst) and \
+                        not poisoned_graph.has_edges_between(dst, src):
+                    new_edges.append((src, dst))
+                    new_edges.append((dst, src))
+
+        if new_edges:
+            src, dst = zip(*new_edges)
+            poisoned_graph.add_edges(src, dst)
+
+        return poisoned_graph
+
+    def _retrain_poisoned_model(self, poisoned_graph, epochs=200):
+        """
+        Retrain target GCN using the poisoned graph structure.
+
+        Args:
+            poisoned_graph: DGLGraph (with new random edges added)
+            epochs: Number of training epochs
+
+        Returns:
+            model: Trained GCN model
+        """
+        dataset_poisoned = copy.deepcopy(self.dataset)
+        dataset_poisoned.graph_data = poisoned_graph
+
+        defense = QueryBasedVerificationDefense(dataset=dataset_poisoned, attack_node_fraction=0.1)
+        model = defense._train_target_model(epochs=epochs)
+        return model
+
+    def _evaluate_accuracy(self, model, dataset):
+        """
+        Evaluates test accuracy of the given model on the dataset.
+
+        Args:
+            model: Trained GCN model
+            dataset: Dataset object (provides features, labels, test_mask, graph)
+
+        Returns:
+            accuracy: float (test accuracy, 0-1)
+        """
+        model.eval()
+        # Read features from the dataset being evaluated rather than from
+        # self.graph_data, so a poisoned dataset is scored on its own graph.
+        graph = dataset.graph_data
+        features = (graph.ndata['feat'] if hasattr(graph, 'ndata') else graph.x).to(self.device)
+        labels = dataset.labels.to(self.device)
+        test_mask = dataset.test_mask.to(self.device)
+
+        with torch.no_grad():
+            logits = model(graph.to(self.device), features)
+            pred = logits.argmax(dim=1)
+            correct = (pred[test_mask] == labels[test_mask]).float()
+            accuracy = correct.sum().item() / test_mask.sum().item()
+        return accuracy
+
+    def run_full_pipeline(self, attack_type='random', mode='transductive', knowledge='full', k=5, trials=1, **kwargs):
+        """
+        Runs the full fingerprinting + attack + evaluation pipeline.
+
+        Parameters:
+            attack_type: 'random', 'bitflip', or 'mettack'
+            mode: 'transductive' or 'inductive'
+            knowledge: 'full' or 'limited'
+            k: number of fingerprints
+            trials: number of repeated trials
+            kwargs: extra params for attack or fingerprinting
+
+        Prints per-trial results and summary statistics.
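+
+        Example (sketch; same assumptions as `defend`):
+
+            defense.run_full_pipeline(attack_type='mettack', mode='transductive',
+                                      knowledge='limited', k=5, trials=3,
+                                      poison_frac=0.005)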
+ """ + flip_rates = [] + acc_drops = [] + + for trial in range(trials): + print(f"\n=== Trial {trial+1}/{trials} ===") + + model_clean = self._train_target_model() + acc_clean = self._evaluate_accuracy(model_clean, self.dataset) + print(f"Clean model accuracy: {acc_clean:.4f}") + + fingerprints = self._generate_fingerprints(model_clean, mode=mode, knowledge=knowledge, k=k, **kwargs) + + model_poisoned, attack_meta = self._run_attack(model_clean, attack_type=attack_type, knowledge=knowledge, **kwargs) + acc_poisoned = self._evaluate_accuracy(model_poisoned, self.dataset) + print(f"Poisoned model accuracy: {acc_poisoned:.4f}") + + eval_result = self._evaluate_fingerprints(model_poisoned, fingerprints) + flip_rate = eval_result['flip_rate'] + print(f"Fingerprint flip rate: {flip_rate:.4f}") + for (nid, old, new) in eval_result['flipped']: + print(f" Node {nid}: {old} → {new}") + + flip_rates.append(flip_rate) + acc_drops.append(acc_clean - acc_poisoned) + + print("\n=== Summary ===") + print(f"Avg Accuracy Drop: {np.mean(acc_drops):.4f}") + print(f"Avg Fingerprint Flip Rate: {np.mean(flip_rates):.4f}") + + + + +class TransductiveFingerprintGenerator: + def __init__(self, model, dataset, candidate_fraction=0.3, random_seed=None, device='cpu', randomize=True): + self.device = torch.device(device) + self.model = model.to(self.device) + self.dataset = dataset + self.graph_data = dataset.graph_data + self.candidate_fraction = candidate_fraction + self.random_seed = random_seed + self.randomize = randomize + + def _get_features(self): + """Backend-agnostic feature getter (DGL or PyG).""" + return self.graph_data.ndata['feat'] if hasattr(self.graph_data, 'ndata') else self.graph_data.x + + def get_candidate_nodes(self): + """Randomly sample a subset of nodes as candidates.""" + all_nodes = torch.arange(self.graph_data.num_nodes()) + num_candidates = max(1, int(len(all_nodes) * self.candidate_fraction)) + + if self.randomize and self.candidate_fraction < 1.0: + generator = torch.Generator(device=self.device) + if self.random_seed is not None: + generator.manual_seed(self.random_seed) + idx = torch.randperm(len(all_nodes), generator=generator)[:num_candidates] + return all_nodes[idx] + return all_nodes + + def compute_fingerprint_scores_full(self, candidate_nodes): + """Full-knowledge fingerprint scores (gradient-based).""" + self.model.eval() + scores = [] + x = self._get_features().to(self.device) + logits = self.model(self.graph_data.to(self.device), x) + + for node in candidate_nodes: + self.model.zero_grad() + logit = logits[node] + label = logit.argmax().item() + loss = F.cross_entropy(logit.unsqueeze(0), torch.tensor([label], device=self.device)) + loss.backward(retain_graph=True) + grad_norm = sum((p.grad ** 2).sum().item() for p in self.model.parameters() if p.grad is not None) + scores.append(grad_norm) + + return torch.tensor(scores, device=self.device) + + def compute_fingerprint_scores_limited(self, candidate_nodes): + """Limited-knowledge fingerprint scores (confidence margin).""" + self.model.eval() + x = self._get_features().to(self.device) + with torch.no_grad(): + logits = self.model(self.graph_data.to(self.device), x) + probs = F.softmax(logits, dim=1) + labels = probs.argmax(dim=1) + scores = 1.0 - probs[candidate_nodes, labels[candidate_nodes]] + return scores + + def select_top_fingerprints(self, scores, candidate_nodes, k, method='full'): + """Selects top-k fingerprint nodes after filtering out extreme score outliers.""" + q = 0.99 if method == 'full' else 1.0 + threshold = 
torch.quantile(scores, q) + mask = scores <= threshold + + filtered_scores = scores[mask] + filtered_candidates = candidate_nodes[mask] + + if filtered_scores.size(0) < k: + k = filtered_scores.size(0) + + topk = torch.topk(filtered_scores, k) + return filtered_candidates[topk.indices], topk.values + + def generate_fingerprints(self, k=5, method='full'): + candidate_nodes = self.get_candidate_nodes().to(self.device) + x = self._get_features().to(self.device) + + with torch.no_grad(): + logits = self.model(self.graph_data.to(self.device), x) + labels = logits.argmax(dim=1) + + if method == 'full': + scores = self.compute_fingerprint_scores_full(candidate_nodes) + elif method == 'limited': + scores = self.compute_fingerprint_scores_limited(candidate_nodes) + else: + raise ValueError("method must be 'full' or 'limited'") + + class_to_candidates = {} + for i, node in enumerate(candidate_nodes): + cls = int(labels[node]) + class_to_candidates.setdefault(cls, []).append((node.item(), scores[i].item())) + + rng = random.Random(self.random_seed) + class_list = list(class_to_candidates.keys()) + rng.shuffle(class_list) + + fingerprints = [] + for cls in class_list: + class_nodes = sorted(class_to_candidates[cls], key=lambda x: x[1], reverse=True) + top_node = class_nodes[0][0] + fingerprints.append((top_node, cls)) + if len(fingerprints) >= k: + break + + if len(fingerprints) < k: + fingerprint_nodes, _ = self.select_top_fingerprints(scores, candidate_nodes, k, method=method) + fingerprints = [(int(n), int(labels[n])) for n in fingerprint_nodes] + + return fingerprints + + +class InductiveFingerprintGenerator: + def __init__(self, model, dataset, shadow_graph=None, knowledge='limited', + candidate_fraction=0.3, num_fingerprints=5, + randomize=True, random_seed=None, device='cpu', + perturb_fingerprints=False, perturb_budget=5): + self.device = torch.device(device) + self.model = model.to(self.device) + self.dataset = dataset + self.shadow_graph = shadow_graph if shadow_graph is not None else dataset.graph_data + self.knowledge = knowledge + self.candidate_fraction = candidate_fraction + self.num_fingerprints = num_fingerprints + self.randomize = randomize + self.random_seed = random_seed + self.perturb_fingerprints = perturb_fingerprints + self.perturb_budget = perturb_budget + + if self.random_seed is not None: + torch.manual_seed(self.random_seed) + random.seed(self.random_seed) + + + def _get_features(self): + return self.shadow_graph.ndata['feat'] if hasattr(self.shadow_graph, 'ndata') else self.shadow_graph.x + + + def get_candidate_nodes(self): + all_nodes = torch.arange(self.shadow_graph.num_nodes()) + num_candidates = max(1, int(len(all_nodes) * self.candidate_fraction)) + + if self.randomize and self.candidate_fraction < 1.0: + generator = torch.Generator(device=self.device) + if self.random_seed is not None: + generator.manual_seed(self.random_seed) + idx = torch.randperm(len(all_nodes), generator=generator)[:num_candidates] + candidates = all_nodes[idx] + else: + candidates = all_nodes + + return candidates + + + def compute_fingerprint_score(self, node_idx, graph_override=None): + """ + Computes the fingerprint score for a given node according to knowledge mode. + If graph_override is provided, scoring is done on that graph instead of shadow_graph. 
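+
+        Scores, as implemented below:
+          - 'limited': 1 - max softmax probability for the node (lower model
+            confidence => higher score).
+          - 'full': squared L2 norm of the loss gradient with respect to the
+            node's input features.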
+ """ + graph = graph_override if graph_override is not None else self.shadow_graph + x = (graph.ndata['feat'] if hasattr(graph, 'ndata') else graph.x).to(self.device) + self.model.eval() + + if self.knowledge == 'limited': + with torch.no_grad(): + logits = self.model(graph.to(self.device), x) + probs = torch.softmax(logits[node_idx], dim=0) + pred_class = probs.argmax().item() + return 1 - probs[pred_class].item() + + elif self.knowledge == 'full': + x.requires_grad_(True) + logits = self.model(graph.to(self.device), x) + pred = logits[node_idx] + label = pred.argmax().item() + + self.model.zero_grad() + loss = torch.nn.functional.nll_loss( + torch.log_softmax(pred.unsqueeze(0), dim=1), + torch.tensor([label], device=self.device) + ) + loss.backward(retain_graph=True) + + grad = x.grad[node_idx] + grad_norm_sq = (grad ** 2).sum().item() + x.requires_grad_(False) + x.grad = None + return grad_norm_sq + else: + raise ValueError("knowledge must be 'limited' or 'full'") + + + def generate_fingerprint_nodes(self): + """ + Step 3: Identifies and returns the top-k (num_fingerprints) nodes with the highest + fingerprint scores from the candidate set. (Section 4.2.2) + + Returns: + List[int]: Indices of selected fingerprint nodes. + """ + candidates = self.get_candidate_nodes() + scores = [] + for idx in candidates: + score = self.compute_fingerprint_score(idx) + scores.append((score, int(idx))) + + scores.sort(reverse=True) + selected = [idx for (_, idx) in scores[:self.num_fingerprints]] + return selected + + + def save_fingerprint_tuples(self, node_indices): + self.model.eval() + x = self._get_features().to(self.device) + with torch.no_grad(): + logits = self.model(self.shadow_graph.to(self.device), x) + labels = logits.argmax(dim=1).cpu().numpy() + return [(self.shadow_graph, int(idx), int(labels[idx])) for idx in node_indices] + + + def generate_fingerprints(self, method='full'): + """ + Generate inductive fingerprints for model watermarking. + + Parameters: + method (str): 'full' for gradient-based or 'limited' for output-based + + Returns: + List of fingerprints + """ + if method == 'full': + return self._generate_full() + elif method == 'limited': + return self._generate_limited() + else: + raise ValueError(f"Invalid fingerprinting method: '{method}'") + + def _generate_full(self): + """ + Implements full knowledge fingerprint generation (gradient-based). + Based on Section 4.2.1 and 5.2 of Wu et al. (2023). + """ + self.knowledge = 'full' + print("[Fingerprint] Generating FULL knowledge fingerprints...") + fingerprint_nodes = self.generate_fingerprint_nodes() + + if self.perturb_fingerprints: + print("[Fingerprint] Applying greedy feature perturbation (FULL)...") + self.greedy_perturb_fingerprints(fingerprint_nodes) + + return self.save_fingerprint_tuples(fingerprint_nodes) + + def _generate_limited(self): + """ + Implements limited knowledge fingerprint generation (output-based). + Based on Section 4.2.2 and 5.2 of Wu et al. (2023). 
+ """ + self.knowledge = 'limited' + print("[Fingerprint] Generating LIMITED knowledge fingerprints...") + fingerprint_nodes = self.generate_fingerprint_nodes() + + if self.perturb_fingerprints: + print("[Fingerprint] Applying greedy feature perturbation (LIMITED)...") + self.greedy_perturb_fingerprints(fingerprint_nodes) + + return self.save_fingerprint_tuples(fingerprint_nodes) + + + def greedy_perturb_fingerprints(self, node_indices): + """ + Greedily perturbs each fingerprint node's features (not edges) to increase its + fingerprint score, without changing the predicted label. + + - For each node, for each feature dimension: + - Add or subtract a small epsilon. + - Accept change if predicted label stays the same and fingerprint score increases. + - Stop after perturb_budget attempts or no improvement. + + Returns: + List[int]: Indices of perturbed fingerprint nodes (features in shadow_graph are updated in-place). + """ + epsilon = 0.01 + features = self._get_features().clone().detach().to(self.device) + self.shadow_graph = self.shadow_graph.to(self.device) + + for idx in node_indices: + num_tries = 0 + improved = True + while num_tries < self.perturb_budget and improved: + improved = False + current_score = self.compute_fingerprint_score(idx, graph_override=self.shadow_graph) + + self.model.eval() + with torch.no_grad(): + logits = self.model(self.shadow_graph, features) + pred_label = logits[idx].argmax().item() + + original_features = features[idx].clone() + for dim in range(features.shape[1]): + for direction in [+1, -1]: + features[idx][dim] += direction * epsilon + + self.model.eval() + with torch.no_grad(): + logits_new = self.model(self.shadow_graph, features) + new_pred_label = logits_new[idx].argmax().item() + new_score = self.compute_fingerprint_score(idx, graph_override=self.shadow_graph) + + if new_pred_label == pred_label and new_score > current_score: + current_score = new_score + improved = True + num_tries += 1 + else: + features[idx][dim] = original_features[dim] + + if num_tries >= self.perturb_budget: + break + if num_tries >= self.perturb_budget: + break + + if hasattr(self.shadow_graph, 'ndata'): + self.shadow_graph.ndata['feat'] = features + else: + self.shadow_graph.x = features + return node_indices + + def greedy_edge_perturbation(self, node_idx, perturb_budget=5, knowledge='full'): + """ + Dispatch to greedy edge perturbation strategy based on verifier knowledge level. + + Args: + node_idx (int): Fingerprint node index. + perturb_budget (int): Number of edge perturbations allowed. + knowledge (str): 'full' or 'limited' + """ + if knowledge == 'full': + self._greedy_edge_perturbation_f(node_idx, perturb_budget) + elif knowledge == 'limited': + self._greedy_edge_perturbation_l(node_idx, perturb_budget) + else: + raise ValueError("knowledge must be 'full' or 'limited'") + + + def _greedy_edge_perturbation_f(self, node_idx, perturb_budget): + """ + Full knowledge edge perturbation (Inductive-F). + Increases fingerprint score using model gradients while preserving prediction. 
+ """ + + g_nx = to_networkx(self.shadow_graph.to('cpu'), to_undirected=True) + x = self._get_features().to(self.device) + self.model.eval() + + with torch.no_grad(): + original_pred = self.model(self.shadow_graph.to(self.device), x)[node_idx].argmax().item() + + def score_fn(modified_graph): + return self.compute_fingerprint_score(node_idx, graph_override=modified_graph) + + neighbors = list(g_nx.neighbors(node_idx)) + non_neighbors = list(set(range(self.shadow_graph.num_nodes())) - set(neighbors) - {node_idx}) + + applied = 0 + while applied < perturb_budget: + best_delta = 0 + best_graph = None + best_action = None + + for nbr in non_neighbors: + temp_g = copy.deepcopy(g_nx) + temp_g.add_edge(node_idx, nbr) + g_temp = from_networkx(temp_g).to(self.device) + with torch.no_grad(): + pred = self.model(g_temp, x)[node_idx].argmax().item() + if pred != original_pred: + continue + delta = score_fn(g_temp) - score_fn(self.shadow_graph) + if delta > best_delta: + best_delta = delta + best_graph = g_temp + best_action = ('add', nbr) + + for nbr in neighbors: + temp_g = copy.deepcopy(g_nx) + if temp_g.has_edge(node_idx, nbr): + temp_g.remove_edge(node_idx, nbr) + g_temp = from_networkx(temp_g).to(self.device) + with torch.no_grad(): + pred = self.model(g_temp, x)[node_idx].argmax().item() + if pred != original_pred: + continue + delta = score_fn(g_temp) - score_fn(self.shadow_graph) + if delta > best_delta: + best_delta = delta + best_graph = g_temp + best_action = ('remove', nbr) + + if best_graph is None: + break + self.shadow_graph = best_graph + g_nx = to_networkx(best_graph.to('cpu'), to_undirected=True) + + if best_action[0] == 'add': + non_neighbors.remove(best_action[1]) + neighbors.append(best_action[1]) + else: + neighbors.remove(best_action[1]) + non_neighbors.append(best_action[1]) + + applied += 1 + + def _greedy_edge_perturbation_l(self, node_idx, perturb_budget): + """ + Limited knowledge edge perturbation (Inductive-L). + Uses confidence margin (1 - confidence) as proxy for fingerprint sensitivity. 
+ """ + + g_nx = to_networkx(self.shadow_graph.to('cpu'), to_undirected=True) + x = self._get_features().to(self.device) + self.model.eval() + + with torch.no_grad(): + logits = self.model(self.shadow_graph.to(self.device), x) + original_pred = logits[node_idx].argmax().item() + original_conf = F.softmax(logits[node_idx], dim=0)[original_pred].item() + original_score = 1 - original_conf + + def score_fn(modified_graph): + with torch.no_grad(): + logits = self.model(modified_graph.to(self.device), x) + pred = logits[node_idx].argmax().item() + if pred != original_pred: + return -1 + conf = F.softmax(logits[node_idx], dim=0)[pred].item() + return 1 - conf + + neighbors = list(g_nx.neighbors(node_idx)) + non_neighbors = list(set(range(self.shadow_graph.num_nodes())) - set(neighbors) - {node_idx}) + + applied = 0 + while applied < perturb_budget: + best_delta = 0 + best_graph = None + best_action = None + + for nbr in non_neighbors: + temp_g = copy.deepcopy(g_nx) + temp_g.add_edge(node_idx, nbr) + g_temp = from_networkx(temp_g).to(self.device) + new_score = score_fn(g_temp) + delta = new_score - original_score + if new_score >= 0 and delta > best_delta: + best_delta = delta + best_graph = g_temp + best_action = ('add', nbr) + + for nbr in neighbors: + temp_g = copy.deepcopy(g_nx) + if temp_g.has_edge(node_idx, nbr): + temp_g.remove_edge(node_idx, nbr) + g_temp = from_networkx(temp_g).to(self.device) + new_score = score_fn(g_temp) + delta = new_score - original_score + if new_score >= 0 and delta > best_delta: + best_delta = delta + best_graph = g_temp + best_action = ('remove', nbr) + + if best_graph is None: + break + self.shadow_graph = best_graph + g_nx = to_networkx(best_graph.to('cpu'), to_undirected=True) + + if best_action[0] == 'add': + non_neighbors.remove(best_action[1]) + neighbors.append(best_action[1]) + else: + neighbors.remove(best_action[1]) + non_neighbors.append(best_action[1]) + + applied += 1 + +class BitFlipAttack: + def __init__(self, model, attack_type='random', bit=0): + self.model = model + self.attack_type = attack_type + self.bit = bit + + def _get_target_params(self): + params = [p for p in self.model.parameters() if p.requires_grad and p.numel() > 0] + if self.attack_type in ['random', 'BFA']: + return params + elif self.attack_type == 'BFA-F': + return [params[0]] + elif self.attack_type == 'BFA-L': + return [params[-1]] + else: + raise ValueError(f"Unknown attack_type {self.attack_type}") + + def _true_bit_flip(self, tensor, index=None, bit=0): + a = tensor.detach().cpu().numpy().copy() + flat = a.ravel() + if index is None: + index = np.random.randint(0, flat.size) + old_val = flat[index] + int_view = np.frombuffer(flat[index].tobytes(), dtype=np.uint32)[0] + int_view ^= (1 << bit) + new_val = np.frombuffer(np.uint32(int_view).tobytes(), dtype=np.float32)[0] + flat[index] = new_val + a = flat.reshape(a.shape) + tensor.data = torch.from_numpy(a).to(tensor.device) + return old_val, new_val, index + + def apply(self): + params = self._get_target_params() + with torch.no_grad(): + layer_idx = random.randrange(len(params)) + param = params[layer_idx] + idx = random.randrange(param.numel()) + old_val, new_val, actual_idx = self._true_bit_flip(param, index=idx, bit=self.bit) + return { + 'layer': layer_idx, + 'param_idx': actual_idx, + 'old_val': old_val, + 'new_val': new_val, + 'bit': self.bit, + 'attack_type': self.attack_type + } + + + +class MettackHelper: + def __init__(self, graph, features, labels, train_mask, val_mask, test_mask, + n_perturbations=5, 
device='cpu', max_perturbations=50,
+                 surrogate_epochs=30, candidate_sample_size=20):
+        self.device = device
+        self.graph = dgl.add_self_loop(graph).to(self.device)
+        self.features = features.to(self.device)
+        self.labels = labels.to(self.device)
+        self.train_mask = train_mask.to(self.device)
+        self.surrogate_epochs = surrogate_epochs
+        self.candidate_sample_size = candidate_sample_size
+        if val_mask is not None:
+            self.val_mask = val_mask.to(self.device)
+        else:
+            # Pass the mask that is already on self.device so the train mask
+            # reassigned inside _create_val_mask_from_train stays on device.
+            self.val_mask = self._create_val_mask_from_train(self.train_mask).to(self.device)
+
+        self.test_mask = test_mask.to(self.device)
+
+        self.n_perturbations = n_perturbations
+
+        in_feats = features.shape[1]
+        n_classes = int(labels.max().item()) + 1
+        self.surrogate = GCN(in_feats, n_classes).to(self.device)
+
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+        self.modified_edges = set()
+
+        original_graph_no_self_loop = dgl.remove_self_loop(graph)
+        self.original_edges = set(zip(original_graph_no_self_loop.edges()[0].cpu().numpy(),
+                                      original_graph_no_self_loop.edges()[1].cpu().numpy()))
+
+        self.candidate_edges = self._get_candidate_edges()
+
+    def _create_val_mask_from_train(self, train_mask):
+        """
+        Create a validation mask by taking a subset of training nodes.
+        This is needed when the dataset doesn't provide a validation mask.
+        """
+        train_indices = torch.where(train_mask)[0]
+        n_val = min(500, len(train_indices) // 4)
+
+        perm = torch.randperm(len(train_indices))
+        val_indices = train_indices[perm[:n_val]]
+
+        val_mask = torch.zeros_like(train_mask, dtype=torch.bool)
+        val_mask[val_indices] = True
+
+        # Remove the validation nodes from the training mask.
+        self.train_mask = train_mask.clone()
+        self.train_mask[val_indices] = False
+
+        return val_mask
+
+    def run(self):
+        """
+        Main entrypoint to run the Mettack algorithm.
+        Returns:
+            poisoned_graph (DGLGraph): The perturbed graph with edges changed.
+            metrics (dict): Metrics for before/after attack, for evaluation.
+        """
+        print("Starting Mettack attack...")
+
+        print("Training surrogate model...")
+        self._train_surrogate()
+
+        print("Applying structure attack...")
+        poisoned_graph = self._apply_structure_attack()
+
+        print("Evaluating attack results...")
+        metrics = self._evaluate(poisoned_graph)
+
+        return poisoned_graph, metrics
+
+    def _train_surrogate(self):
+        """
+        Trains a surrogate GCN on the clean graph.
+        (Matches Wu et al., Section 6.1)
+        """
+        optimizer = optim.Adam(self.surrogate.parameters(), lr=0.01, weight_decay=5e-4)
+        self.surrogate.train()
+
+        for epoch in range(self.surrogate_epochs):
+            optimizer.zero_grad()
+            logits = self.surrogate(self.graph, self.features)
+            loss = F.cross_entropy(logits[self.train_mask], self.labels[self.train_mask])
+            loss.backward()
+            optimizer.step()
+
+            # Log every 10 epochs; the default surrogate_epochs is only 30, so
+            # a 50-epoch interval would print just once.
+            if epoch % 10 == 0:
+                self.surrogate.eval()
+                with torch.no_grad():
+                    val_logits = self.surrogate(self.graph, self.features)
+                    val_acc = self._compute_accuracy(val_logits[self.val_mask],
+                                                     self.labels[self.val_mask])
+                    print(f"Surrogate epoch {epoch}: Val Acc = {val_acc:.4f}")
+                self.surrogate.train()
+
+    def _apply_structure_attack(self):
+        """
+        Runs the Mettack structure perturbation loop (bi-level optimization).
+        - At each step, modify the adjacency matrix (add/remove an edge).
+        - Select the perturbation that maximizes surrogate model loss on the validation nodes.
+        - Repeat up to n_perturbations times.
+        Returns a new DGLGraph with edges modified.
+        (See Appendix A.2 in Wu et al.)
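+
+        Note: this is a simplified greedy variant of Mettack. Instead of
+        meta-gradients, each candidate edit is scored by briefly fine-tuning a
+        copy of the surrogate on the perturbed graph and measuring validation
+        loss (see _compute_attack_loss).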
+ """ + current_graph = copy.deepcopy(self.graph) + perturbed_edges = set() + + for step in range(self.n_perturbations): + print(f"Perturbation step {step + 1}/{self.n_perturbations}") + + best_edge = None + best_loss = -float('inf') + best_action = None + + + candidate_sample = np.random.choice(len(self.candidate_edges), + min(self.candidate_sample_size, len(self.candidate_edges)), + replace=False) + + + for idx in tqdm(candidate_sample, desc="Evaluating candidates"): + edge = self.candidate_edges[idx] + + if edge in perturbed_edges or (edge[1], edge[0]) in perturbed_edges: + continue + + for action in ['add', 'remove']: + if action == 'add' and edge in self.original_edges: + continue + if action == 'remove' and edge not in self.original_edges: + continue + + temp_graph = self._apply_single_perturbation(current_graph, edge, action) + + attack_loss = self._compute_attack_loss(temp_graph) + + if attack_loss > best_loss: + best_loss = attack_loss + best_edge = edge + best_action = action + + if best_edge is not None: + current_graph = self._apply_single_perturbation(current_graph, best_edge, best_action) + perturbed_edges.add(best_edge) + self.modified_edges.add((best_edge, best_action)) + print(f"Applied {best_action} edge {best_edge} with loss increase: {best_loss:.4f}") + else: + print("No beneficial perturbation found, stopping early.") + break + + return current_graph + + def _get_candidate_edges(self): + """ + Generate candidate edges for perturbation. + Includes both existing edges (for removal) and non-existing edges (for addition). + """ + n_nodes = self.graph.num_nodes() + + all_possible_edges = [] + for i in range(n_nodes): + for j in range(i + 1, n_nodes): + all_possible_edges.append((i, j)) + + return all_possible_edges[:min(10000, len(all_possible_edges))] + + def _apply_single_perturbation(self, graph, edge, action): + """ + Apply a single edge perturbation (add or remove) to the graph. + """ + temp_graph = copy.deepcopy(graph) + + if action == 'add': + temp_graph.add_edges([edge[0], edge[1]], [edge[1], edge[0]]) + elif action == 'remove': + src, dst = temp_graph.edges() + edge_ids = [] + + for i, (s, d) in enumerate(zip(src.cpu().numpy(), dst.cpu().numpy())): + if (s == edge[0] and d == edge[1]) or (s == edge[1] and d == edge[0]): + edge_ids.append(i) + + if edge_ids: + temp_graph.remove_edges(edge_ids) + + temp_graph = dgl.add_self_loop(temp_graph) + + return temp_graph + + def _compute_attack_loss(self, perturbed_graph): + """ + Compute the attack loss on a perturbed graph. + This measures how much the surrogate model's performance degrades. + Uses proper bi-level optimization as in the original Mettack paper. + """ + + temp_surrogate = copy.deepcopy(self.surrogate) + temp_surrogate.train() + + + optimizer = optim.Adam(temp_surrogate.parameters(), lr=0.01) + + for _ in range(5): + optimizer.zero_grad() + logits = temp_surrogate(perturbed_graph, self.features) + loss = F.cross_entropy(logits[self.train_mask], self.labels[self.train_mask]) + loss.backward() + optimizer.step() + + + temp_surrogate.eval() + with torch.no_grad(): + val_logits = temp_surrogate(perturbed_graph, self.features) + val_loss = F.cross_entropy(val_logits[self.val_mask], self.labels[self.val_mask]) + + return val_loss.item() + + def _evaluate(self, poisoned_graph): + """ + Evaluates GCN accuracy before/after poisoning, etc. 
+ """ + metrics = {} + + + self.surrogate.eval() + with torch.no_grad(): + clean_logits = self.surrogate(self.graph, self.features) + clean_acc = self._compute_accuracy(clean_logits[self.test_mask], + self.labels[self.test_mask]) + metrics['clean_test_acc'] = clean_acc + + + poisoned_model = GCN(self.features.shape[1], + int(self.labels.max().item()) + 1).to(self.device) + optimizer = optim.Adam(poisoned_model.parameters(), lr=0.01, weight_decay=5e-4) + + poisoned_model.train() + for epoch in range(200): + optimizer.zero_grad() + logits = poisoned_model(poisoned_graph, self.features) + loss = F.cross_entropy(logits[self.train_mask], self.labels[self.train_mask]) + loss.backward() + optimizer.step() + + + poisoned_model.eval() + with torch.no_grad(): + poisoned_logits = poisoned_model(poisoned_graph, self.features) + poisoned_acc = self._compute_accuracy(poisoned_logits[self.test_mask], + self.labels[self.test_mask]) + metrics['poisoned_test_acc'] = poisoned_acc + + metrics['accuracy_drop'] = clean_acc - poisoned_acc + metrics['num_perturbations'] = len(self.modified_edges) + + + return metrics + + def _compute_accuracy(self, logits, labels): + """Helper function to compute accuracy.""" + _, predicted = torch.max(logits, 1) + correct = (predicted == labels).sum().item() + return correct / len(labels) diff --git a/models/defense/__init__.py b/models/defense/__init__.py index a507949..7432a8f 100644 --- a/models/defense/__init__.py +++ b/models/defense/__init__.py @@ -1,5 +1,6 @@ from .base import BaseDefense from .SurviveWM2 import OptimizedWatermarkDefense +from .QueryBasedVerification import QueryBasedVerificationDefense from .WatermarkDefense import ( WatermarkByRandomGraph, ) @@ -9,5 +10,6 @@ __all__ = [ 'BaseDefense', 'WatermarkByRandomGraph', - 'OptimizedWatermarkDefense' + 'OptimizedWatermarkDefense', + 'QueryBasedVerificationDefense' ]