Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/defog/_sbm_isolated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import sys
import json
import numpy as np

try:
import graph_tool.all as gt
from scipy.stats import chi2
except ImportError:
print("ERROR: graph-tool and scipy required", file=sys.stderr)
sys.exit(1)

def main():
data = json.loads(sys.stdin.read())
edges = data['edges']
p_intra = data['p_intra']
p_inter = data['p_inter']
strict = data['strict']
refinement_steps = data['refinement_steps']

gt_g = gt.Graph()
if edges:
gt_g.add_edge_list(edges)

try:
state = gt.minimize_blockmodel_dl(gt_g)
except ValueError:
print("False")
return

# Refine using merge-split MCMC
for _ in range(refinement_steps):
state.multiflip_mcmc_sweep(beta=np.inf, niter=10)

b = gt.contiguous_map(state.get_blocks())
state = state.copy(b=b)
e = state.get_matrix()
n_blocks = state.get_nonempty_B()
node_counts = state.get_nr().get_array()[:n_blocks]
edge_counts = e.todense()[:n_blocks, :n_blocks]

if strict:
if (node_counts > 40).sum() > 0 or (node_counts < 20).sum() > 0 or n_blocks > 5 or n_blocks < 2:
print("False")
return

max_intra_edges = node_counts * (node_counts - 1)
est_p_intra = np.diagonal(edge_counts) / (max_intra_edges + 1e-6)

max_inter_edges = node_counts.reshape((-1, 1)) @ node_counts.reshape((1, -1))
np.fill_diagonal(edge_counts, 0)
est_p_inter = edge_counts / (max_inter_edges + 1e-6)

W_p_intra = (est_p_intra - p_intra) ** 2 / (est_p_intra * (1 - est_p_intra) + 1e-6)
W_p_inter = (est_p_inter - p_inter) ** 2 / (est_p_inter * (1 - est_p_inter) + 1e-6)

W = W_p_inter.copy()
np.fill_diagonal(W, W_p_intra)
p = 1 - chi2.cdf(abs(W), 1)
p = p.mean()

if p > 0.9:
print("True")
else:
print("False")

if __name__ == "__main__":
main()
201 changes: 201 additions & 0 deletions examples/defog/dataset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import numpy as np
import tensorlayerx as tlx
from gammagl.data import Graph
from datasets import (
PlanarGraphDataset, TreeGraphDataset, SBMGraphDataset, Comm20GraphDataset,
QM9Gen, GuacaMolDataset, ZINC250kGen, MOSESDataset, TLSGraphDataset
)

def create_synthetic_dataset(num_graphs=100, min_nodes=10, max_nodes=20,
num_node_types=2, num_edge_types=2, p_edge=0.3):
r"""Create a synthetic graph dataset for testing."""
graphs = []
for _ in range(num_graphs):
n = np.random.randint(min_nodes, max_nodes + 1)
node_labels = np.random.randint(0, num_node_types, size=n)
x = np.eye(num_node_types, dtype=np.float32)[node_labels]
adj = (np.random.rand(n, n) < p_edge).astype(np.float32)
adj = np.triu(adj, k=1)
adj = adj + adj.T
np.fill_diagonal(adj, 0)
src, dst = np.nonzero(adj)
edge_index = np.stack([src, dst], axis=0).astype(np.int64)

if num_edge_types > 1:
edge_labels = np.random.randint(1, num_edge_types, size=len(src))
edge_attr = np.eye(num_edge_types, dtype=np.float32)[edge_labels]
else:
edge_attr = np.ones((len(src), 1), dtype=np.float32)

g = Graph(
x=tlx.convert_to_tensor(x),
edge_index=tlx.convert_to_tensor(edge_index),
edge_attr=tlx.convert_to_tensor(edge_attr),
y=tlx.convert_to_tensor(np.zeros(1, dtype=np.float32)),
)
graphs.append(g)
return graphs


def compute_dataset_infos(dataset, num_node_types, num_edge_types):
r"""Compute dataset statistics dynamically."""
total_graphs = len(dataset)
node_counts = []
node_type_counts = np.zeros(num_node_types, dtype=np.float32)
edge_type_counts = np.zeros(num_edge_types, dtype=np.float32)

for idx in range(total_graphs):
g = dataset[idx]
x_val = g.x
x_np = x_val if isinstance(x_val, np.ndarray) else tlx.convert_to_numpy(x_val)
n = x_np.shape[0]
node_counts.append(n)

node_labels = np.argmax(x_np, axis=-1)
for label in node_labels:
node_type_counts[label] += 1

if getattr(g, 'edge_attr', None) is not None:
ea_val = g.edge_attr
ea_np = ea_val if isinstance(ea_val, np.ndarray) else tlx.convert_to_numpy(ea_val)
edge_type_sums = ea_np.sum(axis=0)
if edge_type_sums.shape[0] > 1:
edge_type_counts[1:] += edge_type_sums[1:]

total_pairs = n * (n - 1)
n_edges = g.edge_index.shape[1] if getattr(g, 'edge_index', None) is not None else 0
n_no_edge = total_pairs - n_edges
edge_type_counts[0] += n_no_edge

if (idx + 1) <= 3 or (idx + 1) % 50000 == 0 or (idx + 1) == total_graphs:
print(f" Processing graph {idx + 1}/{total_graphs}...")

max_n = max(node_counts)
node_dist = np.zeros(max_n + 1, dtype=np.float32)
for nc in node_counts:
node_dist[nc] += 1
node_dist = node_dist / node_dist.sum()

return {
'output_dims': {
'X': num_node_types,
'E': num_edge_types,
'y': 0,
},
'node_types': node_type_counts,
'edge_types': edge_type_counts,
'max_n_nodes': max_n,
'node_dist': node_dist,
}


class SpectreNodeTransform:
def __init__(self, num_node_types=2):
self.num_node_types = num_node_types

def __call__(self, data):
n = data.x.shape[0]
x_np = np.zeros((n, self.num_node_types), dtype=np.float32)
x_np[:, 0] = 1.0
data.x = tlx.convert_to_tensor(x_np, dtype=tlx.float32)
y_val = data.y if hasattr(data, 'y') and data.y is not None else np.zeros((1, 0), dtype=np.float32)
if isinstance(y_val, np.ndarray) and y_val.size == 0:
y_val = np.zeros((1, 0), dtype=np.float32)
data.y = tlx.convert_to_tensor(y_val, dtype=tlx.float32) if isinstance(y_val, np.ndarray) else y_val
return data


class GenericNodeTransform:
def __call__(self, data):
y_val = data.y if hasattr(data, 'y') and data.y is not None else np.zeros((1, 0), dtype=np.float32)
data.y = tlx.convert_to_tensor(y_val, dtype=tlx.float32) if isinstance(y_val, np.ndarray) else y_val
return data


def load_real_dataset(name, root=None, conditional=False, target='mu', remove_h=None):
if name == 'planar':
ds_cls = PlanarGraphDataset
num_node_types, num_edge_types = 2, 2
convert_spectre = True
elif name == 'tree':
ds_cls = TreeGraphDataset
num_node_types, num_edge_types = 2, 2
convert_spectre = True
elif name == 'sbm':
ds_cls = SBMGraphDataset
num_node_types, num_edge_types = 2, 2
convert_spectre = True
elif name == 'comm20':
ds_cls = Comm20GraphDataset
num_node_types, num_edge_types = 2, 2
convert_spectre = True
elif name == 'qm9':
ds_cls = QM9Gen
qm9_remove_h = True if remove_h is None else bool(remove_h)
num_node_types, num_edge_types = (4, 5) if qm9_remove_h else (5, 5)
convert_spectre = False
elif name == 'guacamol':
ds_cls = GuacaMolDataset
num_node_types, num_edge_types = 12, 5
convert_spectre = False
elif name == 'zinc250k':
ds_cls = ZINC250kGen
num_node_types, num_edge_types = 9, 4
convert_spectre = False
elif name == 'moses':
ds_cls = MOSESDataset
num_node_types, num_edge_types = 8, 5
convert_spectre = False
elif name == 'tls':
ds_cls = TLSGraphDataset
num_node_types, num_edge_types = 9, 2
convert_spectre = False
else:
raise ValueError(f"Unknown dataset: {name}")

kwargs = {'root': root} if root else {}
if name == 'qm9':
kwargs['remove_h'] = True if remove_h is None else remove_h
kwargs['aromatic'] = True
if name in ('qm9', 'tls') and conditional:
kwargs['conditional'] = True
kwargs['target'] = target

# Inject transform
transform = SpectreNodeTransform(num_node_types) if convert_spectre else GenericNodeTransform()
kwargs['transform'] = transform

train_ds = ds_cls(split='train', **kwargs)
val_ds = ds_cls(split='val', **kwargs)
test_ds = ds_cls(split='test', **kwargs)

print(f"Dataset {name}: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

if name in ('qm9', 'guacamol', 'zinc250k', 'moses'):
dataset_infos = {'output_dims': {'X': num_node_types, 'E': num_edge_types, 'y': 0}}
stats = ds_cls.STATS_REMOVE_H if name == 'qm9' and kwargs.get('remove_h') else getattr(ds_cls, 'STATS_WITH_H', getattr(ds_cls, 'STATS', None))

if stats:
atom_decoder = stats.get('atom_names', getattr(ds_cls, 'ATOM_DECODER', []))
dataset_infos['node_types'] = stats['node_types'].astype(np.float32).copy()
dataset_infos['edge_types'] = stats['edge_types'].astype(np.float32).copy()
if 'n_nodes' in stats:
dataset_infos['node_dist'] = stats['n_nodes'].astype(np.float32).copy()
dataset_infos['max_n_nodes'] = int(stats.get('max_n_nodes', len(stats['n_nodes']) - 1))
dataset_infos['remove_h'] = kwargs.get('remove_h', True)
dataset_infos['valencies'] = list(stats['valencies'])
if 'valency_distribution' in stats:
dataset_infos['valency_distribution'] = stats['valency_distribution'].astype(np.float32).copy()
dataset_infos['atom_weights'] = dict(stats['atom_weights'])
dataset_infos['max_weight'] = float(stats['max_weight'])
dataset_infos['atom_decoder'] = list(atom_decoder)
else:
dataset_infos = compute_dataset_infos(train_ds, num_node_types, num_edge_types)

test_labels = None
if conditional:
test_labels_list = [tlx.convert_to_numpy(g.y).flatten() for g in test_ds]
if test_labels_list and len(test_labels_list[0]) > 0:
test_labels = tlx.convert_to_tensor(np.stack(test_labels_list, axis=0).astype(np.float32))

return train_ds, val_ds, test_ds, dataset_infos, num_node_types, num_edge_types, test_labels
35 changes: 35 additions & 0 deletions examples/defog/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""DeFoG-specific graph generation datasets.

These datasets are designed for discrete flow matching graph generation
and contain heavy preprocessing (dense adjacency matrices, atom/bond
distribution statistics, eigenvalue caching, etc.) that is specific to
the DeFoG training pipeline.

They are intentionally kept in ``examples/defog/datasets/`` rather than
``gammagl/datasets/`` to avoid polluting the core GammaGL package with
optional heavy dependencies (RDKit, graph-tool, etc.).
"""

from .spectre_dataset import (
PlanarGraphDataset,
TreeGraphDataset,
SBMGraphDataset,
Comm20GraphDataset,
)
from .qm9_dataset import QM9Gen
from .moses_dataset import MOSESDataset
from .guacamol_dataset import GuacaMolDataset
from .zinc250k_dataset import ZINC250kGen
from .tls_dataset import TLSGraphDataset

__all__ = [
'PlanarGraphDataset',
'TreeGraphDataset',
'SBMGraphDataset',
'Comm20GraphDataset',
'QM9Gen',
'MOSESDataset',
'GuacaMolDataset',
'ZINC250kGen',
'TLSGraphDataset',
]
Loading