diff --git a/examples/unifews/.gitignore b/examples/unifews/.gitignore new file mode 100644 index 000000000..234538ec2 --- /dev/null +++ b/examples/unifews/.gitignore @@ -0,0 +1,201 @@ +# This repo +backup/* +result*/ +*/result/ +*/results/ +save* +save/* +wheels/* +precompute/prop.cpp + +# my ignore +try* +.vscode/ +.DS_Store + +# CMake +CMakeLists.txt.user +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake +_deps + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[co] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +build*/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ \ No newline at end of file diff --git a/examples/unifews/convert_data.py b/examples/unifews/convert_data.py new file mode 100644 index 000000000..085a9f554 --- /dev/null +++ b/examples/unifews/convert_data.py @@ -0,0 +1,133 @@ +"""Data conversion helpers for UNIFEWS. 
+ +These are example-local utilities for converting OGB/PyG datasets +into the format expected by the data processor. Not intended as +shared GammaGL utilities. +""" + +import os +import argparse +import json +import numpy as np +import scipy.sparse as sp +import gammagl.datasets as ggl_ds +import gammagl.transforms as ggl_t +import gammagl.utils as ggl_utils + +from gammagl.utils.data_processor import DataProcess, edgeidx2adj + + +class DataProcess_OGB(DataProcess): + def __init__(self, name, dataset_root, path='../data/', rrz=0.5, seed=0): + super().__init__(name, path=path, rrz=rrz, seed=seed) + self.dataset_root = dataset_root + + @property + def n_train(self): + if self.idx_train is None: + self.input(['labels']) + return len(self.idx_train) + + def fetch(self): + dataset = ggl_ds.OGB(self.name, root=self.dataset_root) + split_idx = dataset.get_idx_split() + idx_train, idx_val, idx_test = split_idx["train"], split_idx["valid"], split_idx["test"] + graph, labels = dataset[0] + + row, col = graph.edge_index[0].numpy(), graph.edge_index[1].numpy() + row, col = np.concatenate([row, col], axis=0), np.concatenate([col, row], axis=0) + deg = np.bincount(row) + idx_zero = np.where(deg == 0)[0] + if len(idx_zero) > 0: + print(f"Warning: removing {len(idx_zero)} isolated nodes: {idx_zero}!") + + nodelst = deg.nonzero()[0] + idxnew = np.full(graph.num_nodes, -1) + idxnew[nodelst] = np.arange(len(nodelst)) + self._n = len(nodelst) + row, col = idxnew[row], idxnew[col] + self.adj_matrix = edgeidx2adj(row, col, self.n) + self._m = self.adj_matrix.nnz + + self.attr_matrix = graph.x.numpy()[nodelst] + assert (labels.ndim == 2 and labels.shape[1] == 1) or labels.ndim == 1, "label shape error" + self.labels = labels.numpy().flatten()[nodelst] + + self.idx_train = idxnew[idx_train] + self.idx_train = self.idx_train[self.idx_train > -1] + self.idx_val = idxnew[idx_val] + self.idx_val = self.idx_val[self.idx_val > -1] + self.idx_test = idxnew[idx_test] + self.idx_test = 
self.idx_test[self.idx_test > -1] + + +class DataProcess_PyGFlickr(DataProcess): + def __init__(self, name, raw_root, path='../data/', rrz=0.5, seed=0): + super().__init__(name, path=path, rrz=rrz, seed=seed) + self.raw_root = raw_root + + def fetch(self): + f = np.load(os.path.join(self.raw_root, 'adj_full.npz')) + self.adj_matrix = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape']) + self.to_undirected() + self.attr_matrix = np.load(os.path.join(self.raw_root, 'feats.npy'), allow_pickle=True) + + ys = [-1] * self.attr_matrix.shape[0] + with open(os.path.join(self.raw_root, 'class_map.json')) as f: + class_map = json.load(f) + for key, item in class_map.items(): + ys[int(key)] = item + self.labels = np.array(ys).flatten() + if self.labels.min() < 0: + print(f"Warning: negative label: {self.labels.min()}") + + with open(os.path.join(self.raw_root, 'role.json')) as f: + role = json.load(f) + self.idx_train = np.array(role['tr']) + self.idx_val = np.array(role['va']) + self.idx_test = np.array(role['te']) + + +class DataProcess_PyG(DataProcess): + def __init__(self, name, dataset_root, path='../data/', rrz=0.5, seed=0): + super().__init__(name, path=path, rrz=rrz, seed=seed) + self.dataset_root = dataset_root + + def fetch(self): + dataset = ggl_ds.Coauthor(root=self.dataset_root, name='CS', transform=ggl_t.ToUndirected()) + graph = dataset[0] + degree = ggl_utils.degree(graph.edge_index[0], graph.num_nodes).numpy() + idx_zero = np.where(degree == 0)[0] + if len(idx_zero) > 0: + print(f"Warning: removing {len(idx_zero)} isolated nodes: {idx_zero}!") + + self._n = graph.num_nodes + self.adj_matrix = edgeidx2adj(graph.edge_index[0].numpy(), graph.edge_index[1].numpy(), self.n) + self._m = self.adj_matrix.nnz + self.attr_matrix = graph.x.numpy() + self.labels = graph.y.numpy().flatten() + if self.labels.min() < 0: + print(f"Warning: negative label: {self.labels.min()}") + + +if __name__ == '__main__': + parser = 
argparse.ArgumentParser(description='Convert OGB/PyG datasets for UNIFEWS') + parser.add_argument('--dataset', type=str, default='ogbn-papers100M', choices=['ogbn-papers100M', 'flickr', 'cs']) + parser.add_argument('--dataset-root', type=str, required=True, help='Root directory for the source dataset') + parser.add_argument('--output-path', type=str, default='./data/', help='Output directory') + args = parser.parse_args() + + if args.dataset.startswith('ogbn'): + ds = DataProcess_OGB(args.dataset, dataset_root=args.dataset_root, path=args.output_path) + elif args.dataset == 'flickr': + ds = DataProcess_PyGFlickr(args.dataset, raw_root=args.dataset_root, path=args.output_path) + elif args.dataset == 'cs': + ds = DataProcess_PyG(args.dataset, dataset_root=args.dataset_root, path=args.output_path) + else: + raise ValueError(f"Unknown dataset: {args.dataset}") + + ds.fetch() + ds.calculate(['deg', 'idx_train']) + print(ds) + os.makedirs(os.path.join(ds.path, ds.name), exist_ok=False) + ds.output(['adjtxt', 'adjnpz', 'adjl', 'attr_matrix', 'deg', 'labels', 'attribute']) diff --git a/examples/unifews/loader.py b/examples/unifews/loader.py new file mode 100644 index 000000000..1f9789928 --- /dev/null +++ b/examples/unifews/loader.py @@ -0,0 +1,92 @@ +"""Example-local data loading for UNIFEWS decouple mode. + +This module depends on the optional Cython precompute extension. +Only import this if you need decouple mode (run_mlp.py). 
+""" + +import os +import gc +import numpy as np +import tensorlayerx as tlx +from dotmap import DotMap + +from gammagl.utils.data_processor import DataProcess, DataProcess_inductive, matstd_clip + + +def dmap2dct(chnname, dmap, processor): + typedct = {'sgc': -2, 'gbp': -3, + 'sgc_agp': 0, 'gbp_agp': 1, + 'sgc_thr': 2, 'gbp_thr': 3} + ctype = chnname.split('_')[0] + dct = {} + dct['type'] = typedct[chnname] + dct['hop'] = dmap.hop + dct['dim'] = processor.nfeat + dct['delta'] = dmap.delta if type(dmap.delta) is float else 1e-5 + dct['alpha'] = dmap.alpha if (type(dmap.alpha) is float and not (ctype == 'sgc')) else 0 + dct['rra'] = (1 - dmap.rrz) if type(dmap.rrz) is float else 0 + dct['rrb'] = dmap.rrz if type(dmap.rrz) is float else 0 + return dct + + +def load_embedding(datastr, algo, algo_chn, + datapath="./data/", + inductive=False, multil=False, + seed=0, **kwargs): + # Lazy import of precompute Cython extension (only needed for decouple mode) + try: + from examples.unifews.precompute.prop import A2Prop + except ImportError as e: + raise RuntimeError( + "UNIFEWS decouple mode requires building examples/unifews/precompute. 
" + "Run: cd examples/unifews/precompute && python setup.py build_ext --inplace" + ) from e + + dp = DataProcess(datastr, path=datapath, seed=seed) + dp.input(['labels', 'attr_matrix', 'deg']) + if inductive: + dpi = DataProcess_inductive(datastr, path=datapath, seed=seed) + dpi.input(['attr_matrix', 'deg']) + else: + dpi = dp + + if (datastr.startswith('cora') or datastr.startswith('citeseer') or datastr.startswith('pubmed')): + dp.calculate(['idx_train']) + else: + dp.input(['idx_train', 'idx_val', 'idx_test']) + idx = {'train': tlx.convert_to_tensor(dp.idx_train, tlx.int64), + 'val': tlx.convert_to_tensor(dp.idx_val, tlx.int64), + 'test': tlx.convert_to_tensor(dp.idx_test, tlx.int64)} + + if multil: + dp.calculate(['labels_oh']) + dp.labels_oh[dp.labels_oh < 0] = 0 + labels = tlx.convert_to_tensor(dp.labels_oh, tlx.float32) + else: + dp.labels[dp.labels < 0] = 0 + labels = tlx.convert_to_tensor(dp.labels.flatten(), tlx.int64) + labels = {'train': labels[idx['train']], + 'val': labels[idx['val']], + 'test': labels[idx['test']]} + + n, m = dp.n, dp.m + nfeat, nclass = dp.nfeat, dp.nclass + if seed >= 15: + print(dp) + + py_a2prop = A2Prop() + py_a2prop.load(os.path.join(datapath, datastr), m, n, seed) + chn = dmap2dct(algo, DotMap(algo_chn), dp) + + feat = dp.attr_matrix.transpose().astype(np.float32, order='C') + macs_pre, time_pre = py_a2prop.compute(1, [chn], feat) + feat = feat.transpose() + feat = matstd_clip(feat, idx['train'], with_mean=True) + + feats = {'val': tlx.convert_to_tensor(feat[idx['val']], tlx.float32), + 'test': tlx.convert_to_tensor(feat[idx['test']], tlx.float32)} + feats['train'] = tlx.convert_to_tensor(feat[idx['train']], tlx.float32) + + del feat + gc.collect() + return feats, labels, idx, nfeat, nclass, macs_pre/1e9, time_pre diff --git a/examples/unifews/precompute/__init__.py b/examples/unifews/precompute/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/unifews/precompute/algprop.cpp 
b/examples/unifews/precompute/algprop.cpp new file mode 100644 index 000000000..49be97f5b --- /dev/null +++ b/examples/unifews/precompute/algprop.cpp @@ -0,0 +1,342 @@ +/* + * Author: nyLiao + * File Created: 2023-04-19 + * File: algprop.cpp + * Ref: [AGP](https://github.com/wanghzccls/AGP-Approximate_Graph_Propagation) + */ +#include "algprop.h" +using namespace std; +using namespace Eigen; + +// ==================== +double get_curr_time() { + long long time = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count(); + return static_cast(time) / 1000000.0; +} + +float get_proc_memory(){ + struct rusage r_usage; + getrusage(RUSAGE_SELF,&r_usage); + return r_usage.ru_maxrss/1000000.0; +} + +float get_stat_memory(){ + long rss; + std::string ignore; + std::ifstream ifs("/proc/self/stat", std::ios_base::in); + ifs >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore + >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore >> ignore + >> ignore >> ignore >> ignore >> rss; + + long page_size_kb = sysconf(_SC_PAGE_SIZE) / 1024; + return rss * page_size_kb / 1000000.0; +} + +inline void update_maxr(const float r, float &maxrp, float &maxrn) { + if (r > maxrp) + maxrp = r; + else if (r < maxrn) + maxrn = r; +} + +// ==================== +namespace algprop { +// Load graph related data +void A2prop::load(string dataset, uint mm, uint nn, uint seedd) { + m = mm; + n = nn; + seed = seedd; + + // Load graph adjacency + el = vector(m); // edge list sorted by source node degree + pl = vector(n + 1); + string dataset_el = dataset + "/adj_el.bin"; + const char *p1 = dataset_el.c_str(); + if (FILE *f1 = fopen(p1, "rb")) { + size_t rtn = fread(el.data(), sizeof el[0], el.size(), f1); + if (rtn != m) + cout << "Error! " << dataset_el << " Incorrect read!" << endl; + fclose(f1); + } else { + cout << dataset_el << " Not Exists." 
<< endl; + exit(1); + } + string dataset_pl = dataset + "/adj_pl.bin"; + const char *p2 = dataset_pl.c_str(); + if (FILE *f2 = fopen(p2, "rb")) { + size_t rtn = fread(pl.data(), sizeof pl[0], pl.size(), f2); + if (rtn != n + 1) + cout << "Error! " << dataset_pl << " Incorrect read!" << endl; + fclose(f2); + } else { + cout << dataset_pl << " Not Exists." << endl; + exit(1); + } + + deg = Eigen::ArrayXf::Zero(n); + for (uint i = 0; i < n; i++) { + deg(i) = pl[i + 1] - pl[i]; + if (deg(i) <= 0) { + deg(i) = 1; + // cout << i << " "; + } + } +} + +// Computation call entry +float A2prop::compute(uint nchnn, Channel* chnss, Eigen::Map &feat, float &ttime) { + // Node-specific array + chns = chnss; + assert(nchnn <= 4); + dega = Eigen::ArrayXf::Zero(n); + dinva = Eigen::ArrayXf::Zero(n); + dinvb = Eigen::ArrayXf::Zero(n); + for (uint c = 0; c < nchnn; c++) { + dega = deg.pow(chns[c].rra); + dinva = 1 / dega; + dinvb = 1 / deg.pow(chns[c].rrb); + } + + // Feat is ColMajor, shape: (n, c*F) + int fsum = feat.cols(); + int it = 0; + map_feat = Eigen::ArrayXf::LinSpaced(fsum, 0, fsum - 1); + // random_shuffle(map_feat.data(), map_feat.data() + map_feat.size()); + cout << "feat dim: " << feat.cols() << ", nodes: " << feat.rows() << ", edges: " << m << ". 
"; + + // Feature-specific array + dlt_p = Eigen::ArrayXf::Zero(fsum); + dlt_n = Eigen::ArrayXf::Zero(fsum); + maxf_p = Eigen::ArrayXf::Zero(fsum); + maxf_n = Eigen::ArrayXf::Zero(fsum); + map_chn = Eigen::ArrayXi::Zero(fsum); + macs = Eigen::ArrayXf::Zero(fsum); + // Loop each feature index `it`, inside channel index `i` + for (uint c = 0; c < nchnn; c++) { + for (int i = 0; i < chns[c].dim; i++) { + for (uint u = 0; u < n; u++) { + if (feat(u, i) > 0) + dlt_p(it) += feat(u, it) * pow(deg(u), chns[c].rrb); + else + dlt_n(it) += feat(u, it) * pow(deg(u), chns[c].rrb); + update_maxr(feat(u, it), maxf_p(it), maxf_n(it)); + } + if (dlt_p(it) == 0) + dlt_p(it) = 1e-12; + if (dlt_n(it) == 0) + dlt_n(it) = -1e-12; + dlt_p(it) *= chns[c].delta / (1 - chns[c].alpha); + dlt_n(it) *= chns[c].delta / (1 - chns[c].alpha); + map_chn(it) = c; + it++; + } + } + + // Begin propagation + cout << "Propagating..." << endl; + struct timeval ttod_start, ttod_end; + double ttod, tclk; + gettimeofday(&ttod_start, NULL); + tclk = get_curr_time(); + int dim_top = 0; + int start, ends = dim_top; + + vector threads; + for (it = 1; it <= fsum % NUMTHREAD; it++) { + start = ends; + ends += ceil((float)fsum / NUMTHREAD); + if (chns[0].type < 0) + threads.push_back(thread(&A2prop::feat_ori, this, feat, start, ends)); + else + threads.push_back(thread(&A2prop::feat_chn, this, feat, start, ends)); + } + for (; it <= NUMTHREAD; it++) { + start = ends; + ends += fsum / NUMTHREAD; + if (chns[0].type < 0) + threads.push_back(thread(&A2prop::feat_ori, this, feat, start, ends)); + else + threads.push_back(thread(&A2prop::feat_chn, this, feat, start, ends)); + } + for (int t = 0; t < NUMTHREAD; t++) + threads[t].join(); + vector().swap(threads); + + tclk = get_curr_time() - tclk; + gettimeofday(&ttod_end, NULL); + ttod = ttod_end.tv_sec - ttod_start.tv_sec + (ttod_end.tv_usec - ttod_start.tv_usec) / 1000000.0; + cout << "[Pre] Prop time: " << ttod << " s, "; + cout << "Clock time: " << tclk << " s, "; + 
cout << "Max PRAM: " << get_proc_memory() << " GB, "; + cout << "End RAM: " << get_stat_memory() << " GB, "; + cout << "MACs: " << macs.sum()/1e9 << " G" << endl; + ttime = ttod; + return macs.sum(); +} + +// ==================== +// Feature embs +void A2prop::feat_chn(Eigen::Ref feats, int st, int ed) { + uint seedt = seed; + Eigen::VectorXf res0(n), res1(n); + Eigen::Map rprev(res1.data(), n), rcurr(res0.data(), n); + + // Loop each feature `ift`, index `it` + for (int it = st; it < ed; it++) { + const uint ift = map_feat(it); + const Channel chn = chns[0]; + const float alpha = chn.alpha; + vector plshort(pl), plshort2(pl); + Eigen::Map feati(feats.col(ift).data(), n); + + const float dlti_p = dlt_p(ift); + const float dlti_n = dlt_n(ift); + const float dltinv_p = 1 / dlti_p; + const float dltinv_n = 1 / dlti_n; + float maxr_p = maxf_p(ift); // max positive residue + float maxr_n = maxf_n(ift); // max negative residue + uint maccnt = 0; + + // Init residue + res1.setZero(); + res0 = feats.col(ift); + feati.setZero(); + rprev = res1; + rcurr = res0; + + // Loop each hop `il` + int il; + for (il = 0; il < chn.hop; il++) { + // Early termination + if ((maxr_p <= dlti_p) && (maxr_n >= dlti_n)) + break; + rcurr.swap(rprev); + rcurr.setZero(); + + // Loop each node `u` + for (uint u = 0; u < n; u++) { + const float old = rprev(u); + float thr_p = old * dltinv_p; + float thr_n = old * dltinv_n; + // if ((!chn.is_acc) && (m < 1e9)) { + if (!chn.is_acc) { + rcurr(u) += old; + } + if (thr_p > 1 || thr_n > 1) { + float oldb = 0; + if (chn.is_acc) { + feati(u) += old * alpha; + oldb = old * (1-alpha) * dinvb(u); + } + + // Loop each neighbor index `im`, node `v` + uint iv, iv2; + const uint ivmax = (chn.is_thr) ? 
plshort[u+1] : pl[u+1]; + for (iv = pl[u]; iv < ivmax; iv++) { + const uint v = el[iv]; + const float da_v = dega(v); + if (thr_p > da_v || thr_n > da_v) { + maccnt++; + if (chn.is_acc) + rcurr(v) += oldb * dinva(v); + else + rcurr(v) += old / deg(v); + update_maxr(rcurr(v), maxr_p, maxr_n); + } else { + // plshort[u+1] = iv; + break; + } + } + + iv2 = iv; + const float ran = (float)RAND_MAX / (rand_r(&seedt) % RAND_MAX); + thr_p *= ran; + thr_n *= ran; + const uint ivmax2 = (chn.is_thr) ? plshort2[u+1] : pl[u+1]; + for (; iv < ivmax2; iv++) { + const uint v = el[iv]; + const float da_v = dega(v); + if (thr_p > da_v) { + maccnt++; + rcurr(v) += dlti_p * dinva(v); + update_maxr(rcurr(v), maxr_p, maxr_n); + } else if (thr_n > da_v) { + maccnt++; + rcurr(v) += dlti_n * dinva(v); + update_maxr(rcurr(v), maxr_p, maxr_n); + } else { + break; + } + } + plshort[u+1] = (iv + iv2) / 2; + if (m < 1e8){ + plshort2[u+1] = (iv + pl[u+1]) / 2; + } + + } else { + if (chn.is_acc) + feati(u) += old; + } + } + } + + feati += rcurr; + macs(ift) += (float)maccnt; + } +} + + +void A2prop::feat_ori(Eigen::Ref feats, int st, int ed) { + Eigen::VectorXf res0(n), res1(n); + Eigen::Map rprev(res1.data(), n), rcurr(res0.data(), n); + + // Loop each feature `ift`, index `it` + for (int it = st; it < ed; it++) { + const uint ift = map_feat(it); + const Channel chn = chns[0]; + const float alpha = chn.alpha; + Eigen::Map feati(feats.col(ift).data(), n); + uint maccnt = 0; + + // Init residue + res1.setZero(); + res0 = feats.col(ift); + feati.setZero(); + rprev = res1; + rcurr = res0; + + // Loop each hop `il` + int il; + for (il = 0; il < chn.hop; il++) { + rcurr.swap(rprev); + rcurr.setZero(); + + // Loop each node `u` + for (uint u = 0; u < n; u++) { + const float old = rprev(u); + float oldb = 0; + if (chn.is_acc) { + feati(u) += old * alpha; + oldb = old * (1-alpha) * dinvb(u); + } + + // Loop each neighbor index `im`, node `v` + uint iv; + for (iv = pl[u]; iv < pl[u+1]; iv++) { + const uint 
v = el[iv]; + maccnt++; + if (chn.is_acc) + rcurr(v) += oldb * dinva(v); + else + rcurr(v) += old / deg(v); + } + } + } + + feati += rcurr; + macs(ift) += (float)maccnt; + } +} + +} // namespace propagation diff --git a/examples/unifews/precompute/algprop.h b/examples/unifews/precompute/algprop.h new file mode 100644 index 000000000..e9081af2a --- /dev/null +++ b/examples/unifews/precompute/algprop.h @@ -0,0 +1,80 @@ +/* + * Author: nyLiao + * File Created: 2023-04-19 + * File: algprop.h + */ +#ifndef ALGPROP_H +#define ALGPROP_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#pragma warning(push, 0) +#include +#pragma warning(pop) + +using namespace std; +using namespace Eigen; +typedef unsigned int uint; + +namespace algprop { + const int NUMTHREAD = 32; // Number of threads + + struct Channel { // channel scheme + int type; + // -2: SGC, -3: APPNP + // 0: SGC_AGP, 1: APPNP_AGP + // 2: SGC_thr, 3: APPNP_thr + bool is_thr; // is threshold + bool is_acc; // is accumulate (APPNP) + + int hop; // propagation hop + int dim; // feature dimension + float delta; // absolute error + float alpha; // summation decay, alpha=0 for SGC + float rra, rrb; // left & right normalization + }; + + class A2prop{ + public: + uint m,n,seed; // edges, nodes, seed + vector el; + vector pl; + Eigen::ArrayXf map_feat; // permuted index -> index in feats + Eigen::ArrayXi map_chn; // index in chns -> channel type + Eigen::ArrayXf macs; // MACs per feature + + Channel* chns; // channel schemes + Eigen::ArrayXf deg; // node degree vector + Eigen::ArrayXf dega, dinva; // left-norm degree, inversed deg_a + Eigen::ArrayXf dinvb; // right-norm degree, inversed deg_b + Eigen::ArrayXf dlt_p, dlt_n; // absolute error (positive, negative) + Eigen::ArrayXf maxf_p, maxf_n; // max feature coefficient + + void load(string dataset, 
uint mm, uint nn, uint seedd); + float compute(uint nchnn, Channel* chnss, Eigen::Map &feat, float &time); + + void feat_chn(Eigen::Reffeats,int st,int ed); + void feat_ori(Eigen::Reffeats,int st,int ed); + }; +} + +#endif // ALGPROP_H diff --git a/examples/unifews/precompute/prop.pxd b/examples/unifews/precompute/prop.pxd new file mode 100644 index 000000000..b13245fe5 --- /dev/null +++ b/examples/unifews/precompute/prop.pxd @@ -0,0 +1,28 @@ +# cython: language_level=3 +from libcpp.string cimport string +from libcpp cimport bool +from eigency.core cimport * + +ctypedef unsigned int uint + +cdef extern from "algprop.cpp": + pass + +cdef extern from "algprop.h" namespace "algprop": + cdef struct Channel: + int type + bool is_thr + bool is_acc + + int hop + int dim + float delta + float alpha + float rra + float rrb + + cdef cppclass A2prop: + A2prop() except+ + # dataset, m, n, seed + void load(string, uint, uint, uint) + float compute(uint, Channel*, Map[MatrixXf] &, float &) diff --git a/examples/unifews/precompute/prop.pyx b/examples/unifews/precompute/prop.pyx new file mode 100644 index 000000000..d244c5db1 --- /dev/null +++ b/examples/unifews/precompute/prop.pyx @@ -0,0 +1,32 @@ +from libc.stdlib cimport malloc, free +from prop cimport A2prop, Channel + +cdef class A2Prop: + cdef A2prop c_a2prop + + def __cinit__(self): + self.c_a2prop = A2prop() + + def load(self, str dataset, unsigned int m, unsigned int n, unsigned int seed): + self.c_a2prop.load(dataset.encode(), m, n, seed) + + def compute(self, unsigned int nchn, chns, np.ndarray feat): + cdef: + Channel* c_chns = malloc(nchn * sizeof(Channel)) + float res1, res2 + res2 = 0.0 + for i in range(nchn): + c_chns[i].type = chns[i]['type'] + c_chns[i].is_thr = (chns[i]['type'] > 1) + c_chns[i].is_acc = (chns[i]['type'] % 2 == 1) + + c_chns[i].hop = chns[i]['hop'] + c_chns[i].dim = chns[i]['dim'] + c_chns[i].delta = chns[i]['delta'] + c_chns[i].alpha = chns[i]['alpha'] + c_chns[i].rra = chns[i]['rra'] + 
c_chns[i].rrb = chns[i]['rrb'] + + res1 = self.c_a2prop.compute(nchn, c_chns, Map[MatrixXf](feat), res2) + free(c_chns) + return res1, res2 diff --git a/examples/unifews/precompute/setup.py b/examples/unifews/precompute/setup.py new file mode 100644 index 000000000..d6c29d4d4 --- /dev/null +++ b/examples/unifews/precompute/setup.py @@ -0,0 +1,18 @@ +from distutils.core import setup,Extension +from Cython.Build import cythonize +import eigency + +setup( + author='nyLiao', + version='0.0.1', + install_requires=['Cython>=0.2.15','eigency>=1.77'], + packages=['precompute'], + python_requires='>=3', + ext_modules=cythonize(Extension( + name='prop', + sources=['prop.pyx'], + language='c++', + extra_compile_args=["-std=c++11", "-O3", "-fopenmp"], + include_dirs=["."] + eigency.get_includes(), + )) +) diff --git a/examples/unifews/readme.md b/examples/unifews/readme.md new file mode 100644 index 000000000..65103e507 --- /dev/null +++ b/examples/unifews/readme.md @@ -0,0 +1,131 @@ +# UNIFEWS: You Need Fewer Operations for Efficient Graph Neural Networks + +- Paper link: [UNIFEWS: You Need Fewer Operations for Efficient Graph Neural Networks](https://arxiv.org/abs/2403.13268) +- Author's code repo: [https://github.com/gdmnl/Unifews](https://github.com/gdmnl/Unifews) + +## Verified Backend + +Currently verified: **Torch** only. Other backends (Paddle, TensorFlow, MindSpore) are not yet tested. + +## Dependencies + +```bash +pip install dotmap +``` + +Optional for FLOPs counting: +```bash +pip install ptflops +``` + +## Dataset Statistics + +| Dataset | # Nodes | # Edges | # Classes | +|----------|---------|---------|-----------| +| Cora | 2,708 | 10,556 | 7 | +| Citeseer | 3,327 | 9,228 | 6 | +| Pubmed | 19,717 | 88,651 | 3 | + +## Iterate Mode (run_single.py) + +Iterate mode trains GNNs with integrated pruning. 
**No Cython compilation required.** + +```bash +ALGO=gat_unifews # gat, gcn, gcn_unifews, gcn2, gcn2_unifews, gsage, gsage_unifews +for DATASTR in cora citeseer pubmed +do + for THRA in 0.0e+00 5.0e-02 + do + for THRW in 0.0e+00 5.0e-02 1.0e-01 + do + for SEED in 42 + do + OUTDIR=./save/${DATASTR}/${ALGO}/${SEED}-${THRA}-${THRW} + mkdir -p ${OUTDIR} + OUTFILE=${OUTDIR}/out.txt + TL_BACKEND="torch" python -u run_single.py --seed ${SEED} --dev ${1:--1} --algo ${ALGO} --thr_a ${THRA} --thr_w ${THRW} --data ${DATASTR} --path ./data/ --epochs 200 --patience 20 --lr 0.001 --weight_decay 1e-5 --layer 2 --hidden 512 --dropout 0.5 >> ${OUTFILE} & + echo $! && wait + done + done + done +done +``` + + +## Decouple Mode (run_mlp.py) + +Decouple mode uses precomputed propagation matrices with Cython acceleration. +**Requires building the Cython extension first:** + +```bash +pip install Cython eigency +cd examples/unifews/precompute +python setup.py build_ext --inplace +cd .. +``` + +Then run the experiment. `--inductive` and `--multil` are on/off switches (omit them to keep the default of off), and `alpha`/`rrz` are fixed inside `run_mlp.py` rather than exposed as command-line options: + +```bash +for DATASTR in cora +do + for ALGO in gbp_thr # sgc + do + for THRA in 0.0e+00 1.0e-06 2.0e-06 1.0e-05 5.0e-05 + do + for THRW in 0.0e+00 5.0e-02 1.5e-01 3.0e-01 5.0e-01 7.0e-01 + do + for SEED in 42 + do + OUTDIR=./save/${DATASTR}/${ALGO}/${SEED}-${THRA}-${THRW} + mkdir -p ${OUTDIR} + OUTFILE=${OUTDIR}/out.txt + TL_BACKEND="torch" python -u run_mlp.py --seed ${SEED} --dev ${1:--1} --algo ${ALGO} --thr_a ${THRA} --thr_w ${THRW} --data ${DATASTR} --path ./data/ --epochs 200 --patience 20 --batch 512 --lr 0.01 --weight_decay 1e-5 --layer 2 --hidden 512 --dropout 0.5 --hop 20 >> ${OUTFILE} & + echo $! && wait + done + done + done + done +done +``` + +## Data Preparation + +Standard Cora/Citeseer/Pubmed datasets are downloaded automatically. 
+ +For OGB or PyG datasets, use the conversion script: +```bash +python convert_data.py --dataset ogbn-papers100M --dataset-root /path/to/dataset --output-path ./data/ +``` + +## Results + +Results below are from single-seed runs. Backend: Torch, seed: 42. + +### GAT +| dataset | paper original | our original | paper unifews | our unifews | +|----------|---------------|-------------|--------------|------------| +| cora | 86.44 (+/-0.55) | 87.73 (+/-0.79) | 86.20 (+/-1.14) | 88.59 | +| citeseer | 71.55 (+/-1.52) | 73.19 (+/-0.56) | 69.97 (+/-2.63) | 73.71 | +| pubmed | 84.56 (+/-0.37) | 84.34 (+/-0.41) | 80.08 (+/-4.65) | 84.44 | + +### GCN +| dataset | paper original | our original | paper unifews | our unifews | +|----------|---------------|-------------|--------------|------------| +| cora | 88.37 (+/-0.95) | 82.74 (+/-1.42) | 87.06 (+/-1.46) | 87.94 | +| citeseer | 74.07 (+/-0.91) | 72.23 (+/-0.62) | 71.37 (+/-3.50) | 74.91 | +| pubmed | 84.75 (+/-0.51) | 83.41 (+/-0.51) | 81.68 (+/-6.07) | 84.58 | + +### GCNII +| dataset | paper original | our original | paper unifews | our unifews | +|----------|---------------|-------------|--------------|------------| +| cora | 88.59 (+/-0.52) | 82.74 (+/-1.42) | 87.06 (+/-0.51) | 85.13 | +| citeseer | 75.35 (+/-0.80) | 73.23 | 71.03 (+/-2.06) | 72.15 | +| pubmed | 85.77 (+/-0.65) | 88.28 | 88.38 (+/-0.29) | 89.62 | + +### GraphSAGE +| dataset | paper original | our original | paper unifews | our unifews | +|----------|---------------|-------------|--------------|------------| +| cora | 88.21 (+/-0.15) | 88.59 | 86.03 (+/-1.89) | 88.26 | +| citeseer | 73.11 (+/-1.41) | 74.31 | 72.41 (+/-2.00) | 76.35 | +| pubmed | 88.22 (+/-0.56) | 88.01 | 84.54 (+/-2.92) | 88.49 | diff --git a/examples/unifews/run_mlp.py b/examples/unifews/run_mlp.py new file mode 100644 index 000000000..c34345de2 --- /dev/null +++ b/examples/unifews/run_mlp.py @@ -0,0 +1,251 @@ +import gc +import random +import argparse +import numpy as np + +try: + import 
ptflops + HAS_PTFLOPS = True +except ImportError: + HAS_PTFLOPS = False + +import tensorlayerx as tlx +from tensorlayerx import nn +from tensorlayerx.dataflow import Dataset, DataLoader +from tensorlayerx.model import WithLoss, TrainOneStep + +import sys +sys.path.append("/root/GammaGL") +from gammagl.utils.logger_unifews import Logger, ModelLogger +from examples.unifews.loader import load_embedding +import gammagl.utils.metric_unifews as metric +from gammagl.models.gnn_unifews import flops_modules_dict +import gammagl.models.mlp_unifews as models + +np.set_printoptions(linewidth=160, edgeitems=5, threshold=20, + formatter=dict(float=lambda x: "%.6e" % x)) + +# ========== Training settings ========== +parser = argparse.ArgumentParser() +parser.add_argument('-f', '--seed', type=int, default=11) +parser.add_argument('-v', '--dev', type=int, default=0) +parser.add_argument('-m', '--algo', type=str, default='sgc') +parser.add_argument('-n', '--suffix', type=str, default='') +parser.add_argument('-a', '--thr_a', type=float, default=0.0001) +parser.add_argument('-w', '--thr_w', type=float, default=0.1) +parser.add_argument('-l', '--layer', type=int, default=2) +parser.add_argument('-p', '--hop', type=int, default=2) +parser.add_argument('--data', type=str, default='cora') +parser.add_argument('--path', type=str, default='./data/') +parser.add_argument('--epochs', type=int, default=200) +parser.add_argument('--patience', type=int, default=20) +parser.add_argument('--lr', type=float, default=0.001) +parser.add_argument('--weight_decay', type=float, default=1e-5) +parser.add_argument('--hidden', type=int, default=512) +parser.add_argument('--dropout', type=float, default=0.5) +parser.add_argument('--batch', type=int, default=64) +parser.add_argument('--inductive', action='store_true', default=False) +parser.add_argument('--multil', action='store_true', default=False) +args = parser.parse_args() + +# Construct chn config (previously loaded from JSON) +args.chn = {'hop': 
args.hop, 'delta': args.thr_a, 'alpha': 0.0, 'rrz': 0.5} + +num_thread = 0 if args.data in ['cora', 'citeseer', 'pubmed'] else 8 +random.seed(args.seed) +np.random.seed(args.seed) +tlx.set_seed(args.seed) + +if '_' not in args.algo: args.thr_a, args.thr_w = 0.0, 0.0 +args.chn['delta'] = args.thr_a +args.chn['hop'] = args.hop if isinstance(args.hop, int) else args.chn['hop'] +flag_run = f"{args.seed}-{args.thr_a:.1e}-{args.thr_w:.1e}" + +logger = Logger(args.data, args.algo, flag_run=flag_run) +logger.save_opt(args) +model_logger = ModelLogger(logger, patience=args.patience, cmp='max', + prefix=f'model{args.suffix}', storage='state') +stopwatch = metric.Stopwatch() + +# ========== Data Load ========== +feat, labels, idx, nfeat, nclass, macs_pre, time_pre = load_embedding( + datastr=args.data, datapath=args.path, algo=args.algo, algo_chn=args.chn, + inductive=args.inductive, multil=args.multil, seed=args.seed +) + +# ========== Model Init ========== +model = models.MLP_unifews( + nlayer=args.layer, nfeat=nfeat, nhidden=args.hidden, nclass=nclass, + thr_w=args.thr_w, dropout=args.dropout, layer=args.algo, +) + +# Trigger model build with a dummy forward pass +dummy_x = tlx.convert_to_tensor(np.random.randn(1, nfeat), dtype=tlx.float32) +_ = model(dummy_x, edge_idx=None) +model.reset_parameters() + +if logger.lvl_config > 1: + print(type(model).__name__, args.algo, args.thr_a, args.thr_w) +model_logger.register(model, save_init=False) + +# ========== Training Components ========== +loss_fn = tlx.losses.sigmoid_cross_entropy_with_logits if args.multil else tlx.losses.softmax_cross_entropy_with_logits +net_with_loss = WithLoss(model, loss_fn) + +train_weights = model.trainable_weights + +optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay) +train_one_step = TrainOneStep(net_with_loss, optimizer, train_weights) + +best_val_acc = 0.0 +patience_counter = 0 +lr_factor = 0.5 +lr_patience = 15 + +# ========== Dataset ========== +class 
def train(epoch, ld=loader_train, verbose=False):
    """Run one training epoch over ``ld``.

    The first mini-batch of the epoch switches the model into a pruning
    scheme; every following mini-batch re-uses that state via
    ``('keep', 'keep')``.

    Args:
        epoch: 1-based epoch number; compared against ``args.epochs // 2`` to
            choose the first scheme argument.
        ld: iterable of ``(x, y)`` mini-batches.
        verbose: accepted for signature parity with the other helpers; unused.

    Returns:
        Tuple ``(mean batch loss, stopwatch.time)``.
    """
    model.train()
    stopwatch.reset()

    # First scheme argument used for batch 0: 'pruneall' in the first half of
    # training, 'pruneinc' afterwards.
    first_scheme = 'pruneinc'
    if epoch < args.epochs // 2:
        first_scheme = 'pruneall'

    batch_losses = []
    for step, batch in enumerate(ld):
        feats, targets = batch
        if step:
            model.set_scheme('keep', 'keep')
        else:
            model.set_scheme(first_scheme, 'pruneall')

        # Only the optimisation step itself is timed.
        stopwatch.start()
        step_loss = train_one_step(feats, targets)
        stopwatch.pause()
        batch_losses.append(float(step_loss))

    gc.collect()
    return np.mean(batch_losses), stopwatch.time
def get_numel_safe(m):
    """Best-effort weight count for ``m``.

    Tries ``m.get_numel()`` first; when that returns a tuple its second
    element is used. If the call (or the indexing) fails, falls back to the
    total number of entries in ``m.trainable_weights`` divided by 1000, or
    ``0.0`` when that attribute is missing.

    Args:
        m: any object; typically the trained model.

    Returns:
        The value reported by ``get_numel`` or the fallback estimate.
    """
    try:
        res = m.get_numel()
        return res[1] if isinstance(res, tuple) else res
    except Exception:
        # Deliberately best-effort, but only for ordinary errors: a bare
        # ``except`` would also swallow KeyboardInterrupt / SystemExit.
        if hasattr(m, 'trainable_weights'):
            return sum(np.prod(p.shape) for p in m.trainable_weights) / 1000.0
        return 0.0
acc: {acc_test:.5f}", flush=True) +if logger.lvl_config > 1: + print(f"[Pre] time: {time_pre:.4f} s, MACs: {macs_pre:.4f} G") + print(f"[Train] time: {time_tol.val:.4f} s, avg: {time_tol.avg*100:.1f} ms, MACs: {macs_tol.val:.3f} G, avg: {macs_tol.avg:.1f} G") + print(f"[Test] time: {time_test:.4f} s, MACs: {macs_test:.4f} G, Num adj: {numel_a:.3f} k, Num weight: {numel_w:.3f} k") + print(f"Train MACs: {macs_wtr:.4f} G, Pre MACs: {macs_pre:.4f} G, Test MACs: {macs_wte:.4f} G") +if logger.lvl_config > 2: + + import os + save_dir = './save' + os.makedirs(save_dir, exist_ok=True) + + logger_tab = Logger(args.data, args.algo, flag_run=flag_run, dir=save_dir) + logger_tab.file_log = os.path.join(save_dir, f'log_mb_{flag_run}.csv') + + hstr, cstr = logger_tab.str_csvg( + data=args.data, algo=args.algo, seed=args.seed, thr_a=args.thr_a, thr_w=args.thr_w, + acc_test=acc_test, conv_epoch=epoch_conv, epoch=epoch, time_train=time_tol.val, + macs_train=macs_tol.val, macs_a=macs_pre, macs_wtr=macs_wtr, macs_wte=macs_wte, + time_test=time_test, macs_test=macs_test, numel_a=numel_a, numel_w=numel_w, + hop=args.chn['hop'], layer=args.layer, time_pre=time_pre + ) + logger_tab.print_header(hstr, cstr) + \ No newline at end of file diff --git a/examples/unifews/run_single.py b/examples/unifews/run_single.py new file mode 100644 index 000000000..d74298547 --- /dev/null +++ b/examples/unifews/run_single.py @@ -0,0 +1,207 @@ +import os +import random +import argparse +import numpy as np + +import tensorlayerx as tlx +from tensorlayerx import nn +from tensorlayerx.model import WithLoss, TrainOneStep + +from gammagl.utils.logger_unifews import Logger, ModelLogger, prepare_opt +from gammagl.utils.loader_unifews import load_edgelist +import gammagl.utils.metric_unifews as metric +from gammagl.layers.conv.gcn_unifews import identity_n_norm +from gammagl.models.gnn_unifews import flops_modules_dict +import gammagl.models.mlp_unifews as mlp_model +import gammagl.models.gnn_unifews as gnn_model 
class SemiSpvzLoss(WithLoss):
    """Loss wrapper that scores only a selected subset of node outputs.

    ``forward`` runs the wrapped network on the given features and edges,
    gathers the rows listed in the third element of ``data`` and passes them,
    together with ``label``, to the loss function.
    """

    def __init__(self, net, loss_fn):
        super().__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, label):
        """Return the loss of the gathered logits against ``label``.

        Args:
            data: three-element sequence ``(x, edge_idx, train_idx)``.
            label: targets passed unchanged to the loss function.
        """
        features, graph, labelled_nodes = data
        predictions = tlx.gather(self._backbone(features, graph), labelled_nodes)
        return self._loss_fn(predictions, label)
action='store_true', default=False, help='inductive setting') +parser.add_argument('--multil', action='store_true', default=False, help='multi-label classification') + +#args = prepare_opt(parser) +args = parser.parse_args() + +if args.dev >= 0: + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.dev) + +# ========== random seed +random.seed(args.seed) +np.random.seed(args.seed) +tlx.set_seed(args.seed) + +if not ('_' in args.algo): + args.thr_a, args.thr_w = 0.0, 0.0 + + +flag_run = f"{args.seed}-{args.thr_a:.1e}-{args.thr_w:.1e}" +logger = Logger(args.data, args.algo, flag_run=flag_run) +logger.save_opt(args) +model_logger = ModelLogger(logger, + patience=args.patience, + cmp='max', + prefix='model'+args.suffix, + storage='state') +stopwatch = metric.Stopwatch() + +# ========== download data +adj, feat, labels, idx, nfeat, nclass = load_edgelist( + datastr=args.data, datapath=args.path, + inductive=args.inductive, multil=args.multil, seed=args.seed +) + +if args.algo.split('_')[0] in ['gcn2']: + model = gcn2_model.SandwitchThr( + nlayer=args.layer, nfeat=nfeat, nhidden=args.hidden, nclass=nclass, + thr_a=args.thr_a, thr_w=args.thr_w, dropout=args.dropout, layer=args.algo + ) +elif args.algo.split('_')[0] == 'mlp': + model = mlp_model.MLP_unifews( + nlayer=args.layer, nfeat=nfeat, nhidden=args.hidden, nclass=nclass, + thr_w=args.thr_w, dropout=args.dropout, layer='mlp' + ) +else: + model = gnn_model.GNNThr( + nlayer=args.layer, nfeat=nfeat, nhidden=args.hidden, nclass=nclass, + thr_a=args.thr_a, thr_w=args.thr_w, dropout=args.dropout, layer=args.algo + ) + +model.reset_parameters() + + +adj['train'] = identity_n_norm( + adj['train'], edge_weight=None, num_nodes=feat['train'].shape[0], + rnorm=1, diag=None +) + +if logger.lvl_config > 1: + print(type(model).__name__, args.algo, args.thr_a, args.thr_w) +if logger.lvl_config > 2: + print(model) + +model_logger.register(model, save_init=False) + +loss_fn = tlx.losses.sigmoid_cross_entropy if args.multil else 
def eval(x, edge_idx, y, idx_split, verbose=False):
    """Evaluate the model on one split and return its micro-F1 score.

    Args:
        x: node features fed to the model.
        edge_idx: graph structure fed to the model.
        y: ground-truth labels of the evaluated split.
        idx_split: index used to select the split's rows from the model output.
        verbose: accepted for signature parity; unused.

    Returns:
        Tuple ``(score, stopwatch.time, predictions, y)``.
    """
    model.eval()
    model.set_scheme('keep', 'keep')
    scorer = metric.F1Calculator(nclass)

    # Only the forward pass (including the row selection) is timed.
    stopwatch.reset()
    stopwatch.start()
    pred = model(x, edge_idx, node_lock=tlx.convert_to_tensor([]))[idx_split]
    stopwatch.pause()

    # Multi-label: threshold the raw outputs at 0; otherwise take the arg-max class.
    pred = tlx.where(pred > 0, 1.0, 0.0) if args.multil else tlx.argmax(pred, axis=1)

    scorer.update(y, pred)
    return scorer.compute('micro'), stopwatch.time, pred, y
def safe_norm(x, p=2, axis=None, keepdims=False):
    """Return the Lp norm of ``x`` built from basic tlx reductions.

    Args:
        x: input tensor.
        p: order of the norm; must be a finite positive number. ``1`` and
            ``2`` use dedicated formulas, any other value uses the general
            ``(sum |x|^p)^(1/p)`` form.
        axis: axis (or axes) to reduce over; ``None`` reduces everything.
        keepdims: keep the reduced axes with size one.

    Returns:
        Tensor holding the norm.

    Raises:
        ValueError: if ``p`` is not positive.
    """
    # Previously every p other than 1 silently produced the L2 norm.
    if p <= 0:
        raise ValueError(f"Norm order p must be positive, got {p}")

    x_abs = tlx.abs(x)
    if p == 1:
        return tlx.reduce_sum(x_abs, axis=axis, keepdims=keepdims)
    if p == 2:
        return tlx.sqrt(tlx.reduce_sum(x_abs * x_abs, axis=axis, keepdims=keepdims))

    # General Lp norm.
    sum_p = tlx.reduce_sum(tlx.pow(x_abs, p), axis=axis, keepdims=keepdims)
    return tlx.pow(sum_p, 1.0 / p)
class RandomUnstructured(BasePruningMethod):
    """Prune randomly chosen entries of a tensor.

    ``amount`` is either a float in ``[0, 1]`` (fraction of entries to prune)
    or a non-negative int (absolute number of entries to prune).
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount) -> None:
        self._validate_amount(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a mask with randomly selected entries set to zero.

        Args:
            t: tensor whose shape and dtype drive the random draw.
            default_mask: mask from earlier pruning; entries already zero
                there stay zero in the result.

        Returns:
            Mask tensor with the shape of ``default_mask``.
        """
        tensor_size = _tensor_numel(t)
        nparams_toprune = self._get_prune_count(self.amount, tensor_size)
        # Clamp so an absolute count larger than the tensor prunes everything.
        nparams_toprune = max(0, min(nparams_toprune, tensor_size))

        mask = tlx.ones_like(default_mask)

        if nparams_toprune != 0:
            # Draw one uniform score per entry and prune the k smallest.
            prob = tlx.random_uniform(shape=tlx.get_tensor_shape(t), dtype=t.dtype)
            flat_prob = tlx.reshape(prob, (-1,))
            _, topk_indices = tlx.topk(flat_prob, k=nparams_toprune, largest=False)
            flat_mask = tlx.reshape(mask, (-1,))
            flat_mask = tlx.scatter_update(flat_mask, topk_indices,
                                           tlx.zeros_like(topk_indices, dtype=flat_mask.dtype))
            mask = tlx.reshape(flat_mask, tlx.get_tensor_shape(default_mask))

        # Combine with the incoming mask: the original started from all ones,
        # which resurrected entries that an earlier pruning had removed.
        return mask * default_mask

    def _validate_amount(self, amount):
        """Raise if ``amount`` is neither a ratio in [0, 1] nor a non-negative int."""
        if isinstance(amount, float):
            if not (0.0 <= amount <= 1.0):
                raise ValueError("The ratio of pruning must range from 0 to 1.")
        elif isinstance(amount, int):
            if amount < 0:
                raise ValueError("Pruning quantity cannot be negative.")
        else:
            raise TypeError("amount must be either an integer for absolute quantity or a float for pruning ratio.")

    def _get_prune_count(self, amount, tensor_size):
        """Translate ``amount`` into an absolute number of entries to prune."""
        if isinstance(amount, int):
            return amount
        return int(amount * tensor_size)
def rewind(module: "Module", name: str):
    """Drop the pruning bookkeeping stored on ``module`` for tensor ``name``.

    Removes the attributes ``<name>_orig`` and ``<name>_mask`` when present.
    The tensor stored under ``name`` itself is not touched.

    Args:
        module: object carrying the pruning attributes.
        name: base name of the pruned tensor, e.g. ``'weights'``.
    """
    # Each attribute is removed independently. The original deleted the mask
    # unconditionally once ``_orig`` existed (AttributeError if the mask was
    # missing) and left a lone mask in place, which keeps
    # ``prune.is_pruned`` reporting True.
    for suffix in ("_orig", "_mask"):
        attr = name + suffix
        if hasattr(module, attr):
            delattr(module, attr)
.gsage_unifews import SAGEConvRaw +from .gsage_unifews import SAGEConvThr __all__ = [ 'MessagePassing', @@ -75,7 +84,16 @@ 'HEATlayer', 'DHNConv', 'DNAConv', - 'RoheHANConv' + 'RoheHANConv', + 'ConvThr', + 'GCNConvRaw', + 'GCNConvThr', + 'GATv2ConvRaw', + 'GATv2ConvThr', + 'GCNIIConvRaw', + 'GCNIIConvThr', + 'SAGEConvRaw', + 'SAGEConvThr' ] classes = __all__ diff --git a/gammagl/layers/conv/gat_unifews.py b/gammagl/layers/conv/gat_unifews.py new file mode 100644 index 000000000..5e52c859e --- /dev/null +++ b/gammagl/layers/conv/gat_unifews.py @@ -0,0 +1,206 @@ +import os +from math import log +import numpy as np +from typing import Optional, Tuple, Union, Any + + +import tensorlayerx as tlx +import tensorlayerx.nn as nn +from tensorlayerx.nn import Linear +import gammagl as ggl + +Tensor = Any +Adj = Tensor +OptTensor = Optional[Tensor] +PairTensor = Tuple[Tensor, OptTensor] + +from gammagl.layers.conv import (GATV2Conv,MessagePassing) +GATv2Conv = GATV2Conv + +from gammagl.gglspeedup.prunes_gamma import ThrInPrune, rewind,prune, _tensor_numel +from gammagl.utils.logger_unifews import LayerNumLogger +from .gcn_unifews import identity_n_norm, gcn_norm, softmax, leaky_relu, scatter, maybe_num_nodes, reset_weight_, reset_bias_,ConvThr,add_remaining_self_loops,pow_with_pinv,normalize,norm + + +class GATv2ConvRaw(GATV2Conv): + def __init__(self, in_channels: int, out_channels: int, depth: int, + rnorm=None, diag=1., depth_inv=False, heads: int = 1, concat: bool = True, **kwargs): + + self.rnorm = rnorm + self.diag = diag + self.depth_inv = depth_inv + kwargs.pop('thr_a', None) + kwargs.pop('thr_w', None) + if depth == 0: + final_out = out_channels + heads = 1 + concat = False + else: + + final_out = out_channels + heads = heads + concat = concat + + super().__init__(in_channels, final_out, heads, concat,** kwargs) + + self.logger_a = LayerNumLogger() + self.logger_w = LayerNumLogger() + self.logger_in = LayerNumLogger() + self.logger_msg = LayerNumLogger() + 
class GATv2ConvThr(ConvThr, GATv2ConvRaw):
    """GATv2 layer with threshold-based pruning of weights and of kept indices.

    Reads ``scheme_a``, ``scheme_w``, ``threshold_a`` and ``threshold_w``,
    which are presumably provided by ``ConvThr`` (not visible here - confirm).
    """

    def __init__(self, *args, thr_a, thr_w, **kwargs):
        super().__init__(*args, thr_a=thr_a, thr_w=thr_w, **kwargs)
        self.prune_lst = [self.linear]
        # Indices retained by the last pruning pass; None until one has run.
        self.idx_keep = None

        self.register_forward_hook(self.prune_on_msg)

    def prune_on_msg(self, module, inputs, output):
        """Forward hook: zero low-norm rows of the output and remember the kept rows.

        NOTE(review): the hook is registered on this layer, so ``output`` is
        whatever ``forward`` returns and its first element looks like the
        node-level result rather than per-edge messages, although the row
        count is used as ``num_edges`` - confirm the intended tensor.

        Returns:
            ``output`` with its first element replaced (a tuple when the
            input was a list or tuple), or the bare tensor otherwise.
        """
        is_seq = isinstance(output, (list, tuple))
        msg_tensor = output[0] if is_seq else output

        # Average over the middle axis of 3-D tensors. The keyword is
        # ``axis`` (as used elsewhere in this class); ``dim`` was a bug.
        if len(tlx.get_tensor_shape(msg_tensor)) == 3:
            msg_tensor = tlx.reduce_mean(msg_tensor, axis=1)

        num_edges = msg_tensor.shape[0]

        if self.scheme_a in ['pruneall', 'pruneinc']:
            # Rows whose L2 norm is below a fraction of the mean row norm are dropped.
            norm_feat_msg = tlx.sqrt(tlx.reduce_sum(tlx.square(msg_tensor), axis=1))
            norm_all_msg = tlx.reduce_sum(tlx.abs(norm_feat_msg)) / num_edges
            mask_prune = norm_feat_msg < (self.threshold_a * 0.1 * norm_all_msg)

            final_mask_to_keep = tlx.logical_not(mask_prune)

            msg_tensor = tlx.where(tlx.expand_dims(final_mask_to_keep, 1), msg_tensor, tlx.zeros_like(msg_tensor))
            edge_indices = tlx.arange(0, num_edges, dtype=tlx.int64)
            # Keep the index tensor 1-D even when a single row survives;
            # ``squeeze()`` would collapse it to 0-d and break ``shape[0]``.
            self.idx_keep = tlx.reshape(tlx.cast(edge_indices[final_mask_to_keep], tlx.int64), (-1,))

        elif self.scheme_a == 'keep':
            # Re-apply the selection remembered from the last pruning pass.
            keep_mask = tlx.zeros((num_edges,), dtype=tlx.bool)
            if self.idx_keep is not None:
                indices_to_save = self.idx_keep[self.idx_keep < num_edges]
                keep_mask = tlx.scatter_update(keep_mask, indices_to_save, tlx.ones_like(indices_to_save, dtype=tlx.bool))

            msg_tensor = tlx.where(tlx.expand_dims(keep_mask, 1), msg_tensor, tlx.zeros_like(msg_tensor))

        if is_seq:
            # ``tuple(...)`` makes the concatenation valid for list outputs too.
            return (msg_tensor,) + tuple(output[1:])
        return msg_tensor

    def forward(self, x, edge_tuple: PairTensor, node_lock: OptTensor = None, verbose: bool = False):
        """Apply the (optionally pruned) linear map and propagate over the edges.

        Args:
            x: input node features.
            edge_tuple: pair ``(edge_index, edge_weight)``.
            node_lock: accepted for interface parity; unused here.
            verbose: accepted for interface parity; unused here.

        Returns:
            ``(out, (edge_index, edge_weight))`` where the edge pair has been
            restricted to ``self.idx_keep`` when an edge scheme is active.

        Raises:
            NotImplementedError: if ``scheme_w`` is not 'pruneall', 'pruneinc' or 'keep'.
        """
        (edge_index, edge_weight) = edge_tuple
        H, C = self.heads, self.out_channels

        if self.scheme_w in ['pruneall', 'pruneinc']:
            # Clear bookkeeping of a previous pruning pass before re-pruning.
            if prune.is_pruned(self.linear):
                if self.scheme_w == 'pruneall':
                    rewind(self.linear, 'weights')
                else:
                    prune.remove(self.linear, 'weights')

            # Scale the weight threshold by the mean column norm of the input.
            norm_node_in = tlx.sqrt(tlx.reduce_sum(tlx.square(x), axis=0))
            norm_all_in = tlx.reduce_sum(norm_node_in) / x.shape[1]
            if norm_all_in > 1e-8:
                threshold_wi = self.threshold_w * norm_all_in
                ThrInPrune.apply(self.linear, 'weights', threshold_wi)
        elif self.scheme_w != 'keep':
            raise NotImplementedError()

        x = self.linear(x)
        x = tlx.reshape(x, (-1, H, C))

        self.logger_w.numel_before = _tensor_numel(self.linear.weights)
        self.logger_w.numel_after = int(tlx.convert_to_numpy(tlx.reduce_sum(tlx.cast(self.linear.weights != 0, tlx.float32))))

        out = self.propagate(x, edge_index, edge_weight=edge_weight, num_nodes=tlx.get_tensor_shape(x)[0])

        if self.concat:
            out = tlx.reshape(out, (-1, self.heads * self.out_channels))
        else:
            out = tlx.reduce_mean(out, axis=1)

        if self.bias is not None:
            out = out + self.bias

        if self.scheme_a in ['pruneall', 'pruneinc', 'keep']:
            num_edges = edge_index.shape[1]
            self.logger_a.numel_before = edge_index.shape[1]

            if self.idx_keep is None:
                self.idx_keep = tlx.arange(start=0, limit=num_edges, dtype=tlx.int64)

            # Discard remembered indices that do not exist in this edge set.
            self.idx_keep = self.idx_keep[self.idx_keep < num_edges]
            # 1-D even for a single index, so ``shape[0]`` and the column
            # selection below keep working.
            self.idx_keep = tlx.reshape(tlx.cast(self.idx_keep, tlx.int64), (-1,))
            self.logger_a.numel_after = self.idx_keep.shape[0]

            edge_index = edge_index[:, self.idx_keep]
            if edge_weight is not None:
                edge_weight = edge_weight[self.idx_keep]

        return out, (edge_index, edge_weight)
gammagl.layers.conv import ( + MessagePassing,GCNIIConv +) + + +from gammagl.gglspeedup.prunes_gamma import ThrInPrune, rewind,prune, _tensor_numel +from gammagl.utils.logger_unifews import LayerNumLogger +from .gcn_unifews import identity_n_norm, gcn_norm, softmax, leaky_relu, scatter, maybe_num_nodes, reset_weight_, reset_bias_,ConvThr,add_remaining_self_loops,pow_with_pinv,normalize,norm + + + +class GCNIIConvRaw(GCNIIConv): + def __init__(self, in_channels, out_channels, alpha, beta, variant=False, + rnorm=None, diag=1., depth_inv=False, **kwargs): + self.rnorm = rnorm + self.diag = diag + self.depth_inv = depth_inv + + kwargs.pop('thr_a', None) + kwargs.pop('thr_w', None) + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + alpha=alpha, + beta=beta, + variant=variant + ) + + self.logger_a = LayerNumLogger() + self.logger_w = LayerNumLogger() + self.logger_in = LayerNumLogger() + self.logger_msg = LayerNumLogger() + + def forward(self, x, x_0, edge_tuple: PairTensor): + edge_index, edge_weight = edge_tuple + + self.logger_a.numel_after = edge_index.shape[1] + total_w = _tensor_numel(self.linear.weights) + if self.variant: + total_w += _tensor_numel(self.linear0.weights) + self.logger_w.numel_after = total_w + + num_nodes = tlx.get_tensor_shape(x)[0] + return super().forward(x0=x_0, x=x, edge_index=edge_index, + edge_weight=edge_weight, num_nodes=num_nodes) + + def reset_parameters(self): + reset_weight_(self.linear.weights, self.in_channels, initializer='glorot') + if self.variant: + reset_weight_(self.linear0.weights, self.in_channels, initializer='glorot') + if hasattr(self, 'bias') and self.bias is not None: + reset_bias_(self.bias, self.out_channels, initializer='zeros') + + @classmethod + def cnt_flops(cls, module, input, output): + x_in, x_0, (edge_index, edge_weight) = input + x_out = output + f_in, f_out = x_in.shape[-1], x_out.shape[-1] + n, m = x_in.shape[0], edge_index.shape[1] + module.__flops__ += int(f_in * f_out * n) * 
class GCNIIConvThr(ConvThr, GCNIIConvRaw):
    """GCNII layer with threshold-based weight pruning and edge selection.

    Reads ``scheme_w`` and ``threshold_w``, which are presumably provided by
    ``ConvThr`` (not visible here - confirm).
    """

    def __init__(self, *args, thr_a, thr_w, **kwargs):
        super().__init__(*args, thr_a=thr_a, thr_w=thr_w, **kwargs)
        self.prune_lst = [self.linear]
        if self.variant:
            self.prune_lst.append(self.linear0)

        # Edge indices retained across calls; None until first use.
        self.idx_keep = None

    def forward(self, x, x_0, edge_tuple, node_lock=None, verbose=False):
        """Prune weights if requested, propagate, and mix with the initial features.

        Args:
            x: current node features.
            x_0: initial node features; only the first ``num_nodes`` rows are used.
            edge_tuple: pair ``(edge_index, edge_weight)``.
            node_lock: accepted for interface parity; unused here.
            verbose: accepted for interface parity; unused here.

        Returns:
            ``(out, (edge_index, edge_weight))`` with the edges restricted to
            ``self.idx_keep``.
        """
        (edge_index, edge_weight) = edge_tuple
        num_nodes = tlx.get_tensor_shape(x)[0]
        num_edges = edge_index.shape[1]
        self.current_edge_count = num_edges

        # Layers that carry prunable weights in this configuration.
        layers = [self.linear]
        if self.variant:
            layers.append(self.linear0)

        if self.scheme_w in ['pruneall', 'pruneinc']:
            # Clear bookkeeping of a previous pruning pass before re-pruning.
            clear = rewind if self.scheme_w == 'pruneall' else prune.remove
            for layer in layers:
                if prune.is_pruned(layer):
                    clear(layer, 'weights')

            # Scale the weight threshold by the mean column norm of the input.
            norm_node_in = tlx.sqrt(tlx.reduce_sum(tlx.square(x), axis=0))
            norm_all_in = tlx.reduce_sum(norm_node_in) / (x.shape[1] + 1e-10)

            if norm_all_in > 1e-8:
                threshold_wi = self.threshold_w * norm_all_in
                for layer in layers:
                    ThrInPrune.apply(layer, 'weights', threshold_wi)

        # Count every prunable layer, matching GCNIIConvRaw.forward, which
        # also adds linear0 for the variant form (the original counted only
        # self.linear here).
        self.logger_w.numel_before = sum(_tensor_numel(layer.weights) for layer in layers)
        self.logger_w.numel_after = sum(
            int(tlx.convert_to_numpy(tlx.reduce_sum(tlx.cast(layer.weights != 0, tlx.float32))))
            for layer in layers
        )

        self.logger_a.numel_before = num_edges
        if self.idx_keep is None:
            self.idx_keep = tlx.arange(start=0, limit=num_edges, dtype=tlx.int64)

        # Discard remembered indices that do not exist in this edge set.
        self.idx_keep = self.idx_keep[self.idx_keep < num_edges]
        # Keep the index tensor 1-D even for a single index; ``squeeze()``
        # produced a 0-d tensor that the gathers below cannot use as a list
        # of columns.
        self.idx_keep = tlx.reshape(tlx.cast(self.idx_keep, tlx.int64), (-1,))
        self.logger_a.numel_after = self.idx_keep.shape[0]

        edge_index = tlx.gather(edge_index, self.idx_keep, axis=1)
        if edge_weight is not None:
            edge_weight = tlx.gather(edge_weight, self.idx_keep)

        m = self.propagate(x, edge_index, edge_weight=edge_weight, num_nodes=num_nodes)

        # Blend the propagated signal with the initial features via alpha,
        # then the identity and linear maps via beta.
        m = m * (1 - self.alpha)
        x_0_part = x_0[:num_nodes] * self.alpha

        if not self.variant:
            out = m + x_0_part
            out = out * (1 - self.beta) + tlx.matmul(out, self.linear.weights) * self.beta
        else:
            out = (m * (1 - self.beta) + tlx.matmul(m, self.linear.weights) * self.beta +
                   x_0_part * (1 - self.beta) + tlx.matmul(x_0_part, self.linear0.weights) * self.beta)

        return out, (edge_index, edge_weight)
def reset_weight_(weight: Tensor, in_channels: int, initializer: Optional[str] = None) -> Tensor:
    """Re-initialise ``weight`` in place and return it.

    Args:
        weight: Parameter to reset. ``None`` is passed through unchanged.
        in_channels: Fan-in used for the ``'uniform'`` bound. A non-positive
            value leaves ``weight`` untouched.
        initializer: ``'glorot'``, ``'uniform'``, ``'kaiming_uniform'`` or
            ``None`` (treated as ``'kaiming_uniform'``).

    Returns:
        The same ``weight`` object that was passed in (or ``None``).

    Raises:
        RuntimeError: If ``initializer`` is not one of the supported names.
    """
    # Nothing to initialise: hand the argument back unchanged. This always
    # returns the weight, never a bare ``None`` for an existing parameter,
    # matching ``reset_bias_`` and the declared return type.
    if weight is None or in_channels <= 0:
        return weight

    shape = weight.shape
    device = weight.device if hasattr(weight, 'device') else None

    if initializer == 'glorot':
        new_weight = tlx.initializers.XavierUniform()(shape)
    elif initializer == 'uniform':
        bound = 1.0 / np.sqrt(in_channels)
        new_weight = tlx.initializers.RandomUniform(-bound, bound)(shape)
    elif initializer == 'kaiming_uniform' or initializer is None:
        new_weight = tlx.initializers.HeUniform()(shape)
    else:
        raise RuntimeError(f"Weight initializer '{initializer}' not supported")

    # Keep the fresh values on the device the parameter already lives on.
    if device is not None:
        new_weight = tlx.convert_to_tensor(new_weight, device=device)

    # Parameters exposing ``.data`` are rebound; anything else is overwritten
    # through its ``copy_`` method.
    if hasattr(weight, 'data'):
        weight.data = new_weight
    else:
        weight.copy_(new_weight)
    return weight
def scatter(src, index, axis=0, dim_size=None, reduce='sum'):
    """Reduce the entries of ``src`` into ``dim_size`` segments chosen by ``index``.

    Args:
        src: Values to reduce.
        index: Segment id for each entry of ``src``.
        axis: Accepted for signature compatibility; this function does not
            read it.
        dim_size: Number of output segments. When ``None`` it is taken as
            ``max(index) + 1``.
        reduce: One of ``'sum'``, ``'max'``, ``'mean'`` or ``'min'``.

    Returns:
        The result of the matching ``tlx.unsorted_segment_*`` op.

    Raises:
        ValueError: If ``reduce`` is not a supported reduction name.
    """
    if dim_size is None:
        dim_size = int(tlx.reduce_max(index)) + 1

    # Reduction name -> name of the TensorLayerX segment op that implements it.
    segment_ops = (
        ('sum', 'unsorted_segment_sum'),
        ('max', 'unsorted_segment_max'),
        ('mean', 'unsorted_segment_mean'),
        ('min', 'unsorted_segment_min'),
    )
    for reduce_name, op_name in segment_ops:
        if reduce == reduce_name:
            segment_fn = getattr(tlx, op_name)
            return segment_fn(src, index, num_segments=dim_size)

    raise ValueError(f"Unsupported reduce type: {reduce}")
def gcn_norm(edge_index, edge_weight, num_nodes, improved=False, add_self_loops=True, flow="source_to_target", dtype=tlx.float32):
    """Apply symmetric degree normalisation to the edge weights.

    Each edge weight ``w`` between ``row`` and ``col`` becomes
    ``deg[row]^-0.5 * w * deg[col]^-0.5``, where ``deg`` is the weighted
    degree accumulated over the target (or source) endpoint.

    Args:
        edge_index: Tensor whose first two rows are the edge endpoints.
        edge_weight: Per-edge weights, or ``None`` for all-ones weights.
        num_nodes: Number of nodes in the graph.
        improved: When ``True`` the added self-loops get weight 2 instead
            of 1.
        add_self_loops: Whether to replace/add one self-loop per node before
            normalising.
        flow: ``"source_to_target"`` accumulates degrees over ``edge_index[1]``;
            any other value accumulates over ``edge_index[0]``.
        dtype: Dtype used when default all-ones weights are created.

    Returns:
        ``(edge_index, edge_weight)`` after optional self-loop insertion and
        normalisation.
    """
    # Materialise default weights *before* adding self-loops so that the
    # loop fill value below is not overwritten by a blanket ones() tensor.
    if edge_weight is None:
        edge_weight = tlx.ones((edge_index.shape[1],), dtype=dtype)

    if add_self_loops:
        # ``improved`` was previously ignored and the fill value was always 1.0.
        fill_value = 2.0 if improved else 1.0
        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index[0], edge_index[1]
    idx = col if flow == "source_to_target" else row
    deg = scatter(edge_weight, idx, axis=0, dim_size=num_nodes, reduce='sum')
    # pow_with_pinv maps the infinities produced by zero-degree nodes to 0.
    deg_inv_sqrt = pow_with_pinv(deg, -0.5)
    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    return edge_index, edge_weight
self.idx_keep = tlx.convert_to_tensor([]) + self.prune_lst = [] + self.scheme_a = 'full' + self.scheme_w = 'full' + + def propagate_forward_print(self, module, inputs, output): + print(inputs[0], inputs[0].shape, inputs[1]) + print(output, output.shape) + + def get_idx_lock(self, edge_index, node_lock): + # Ensure starting with a valid empty tensor for concatenation + idx_lock = tlx.convert_to_tensor([], dtype=tlx.int64) + if edge_index.shape[1] == 0: + return idx_lock + + bs = int(2**28 / edge_index.shape[1]) if edge_index.shape[1] > 0 else 1 + + for i in range(0, node_lock.shape[0], bs): + batch = node_lock[i:min(i+bs, node_lock.shape[0])] + mask = tlx.reduce_any(tlx.expand_dims(edge_index[1], 0) == tlx.expand_dims(batch, 1), axis=0) + batch_idx = tlx.arange(0, edge_index.shape[1])[mask] + idx_lock = tlx.concat((idx_lock, tlx.cast(batch_idx, tlx.int64)), axis=0) + diag_mask = edge_index[0] == edge_index[1] + idx_diag = tlx.arange(0, edge_index.shape[1])[diag_mask] + + idx_lock = tlx.concat((idx_lock, tlx.cast(idx_diag, tlx.int64)), axis=0) + return tlx.unique(idx_lock) + +class GCNConvRaw(GCNConv): + + def __init__(self, in_channels, out_channels, + rnorm=None, diag=1., depth_inv=False, *args, **kwargs): + self.rnorm = rnorm + self.diag = diag + self.depth_inv = depth_inv + kwargs.pop('thr_a', None) + kwargs.pop('thr_w', None) + + super().__init__(in_channels, out_channels, *args, **kwargs) + self.logger_a = LayerNumLogger() + self.logger_w = LayerNumLogger() + self.logger_in = LayerNumLogger() + self.logger_msg = LayerNumLogger() + self.reset_parameters() + + def reset_parameters(self): + reset_weight_(self.linear.weights, self.in_channels, initializer='kaiming_uniform') + if self.bias is not None: + reset_bias_(self.bias, self.out_channels, initializer='zeros') + + def forward(self, x, edge_tuple: PairTensor, **kwargs): + (edge_index, edge_weight) = edge_tuple + self.logger_a.numel_after = edge_index.shape[1] + self.logger_w.numel_after = 
class GCNConvThr(ConvThr, GCNConvRaw):
    """GCN layer with threshold-based weight and edge pruning.

    ``scheme_w`` / ``scheme_a`` (set on the instance from outside) select
    whether weights / edges are pruned (``'pruneall'``, ``'pruneinc'``),
    kept as previously pruned (``'keep'``) or left alone (``'full'``).
    """

    def __init__(self, *args, thr_a, thr_w, **kwargs):
        super().__init__(*args, thr_a=thr_a, thr_w=thr_w, **kwargs)
        self.prune_lst = [self.linear]
        self.register_forward_hook(self.prune_on_msg)
        # None means "no pruning decision yet": forward() then keeps all edges.
        self.idx_keep = None

    def prune_on_msg(self, module, inputs, output):
        """Forward hook that masks rows of the hooked output by their L2 norm.

        NOTE(review): the hook is registered on this layer itself, so
        ``output`` is whatever ``forward`` returns and its first element is
        what gets masked here. The variable names suggest per-edge messages
        were intended; confirm against the original design.

        Returns:
            ``output`` with its first element replaced by the masked tensor
            (a tuple when ``output`` was a list/tuple, else a bare tensor).
        """
        msg_tensor = output[0] if isinstance(output, (list, tuple)) else output
        num_edges = msg_tensor.shape[0]

        if self.scheme_a in ['pruneall', 'pruneinc']:
            norm_feat_msg = tlx.sqrt(tlx.reduce_sum(tlx.square(msg_tensor), axis=1))
            norm_all_msg = tlx.reduce_sum(tlx.abs(norm_feat_msg)) / num_edges
            # Rows whose norm falls below a fraction of the mean norm are dropped.
            mask_prune = norm_feat_msg < (self.threshold_a * 0.1 * norm_all_msg)
            final_mask_to_keep = tlx.logical_not(mask_prune)

            msg_tensor = tlx.where(tlx.expand_dims(final_mask_to_keep, 1), msg_tensor, tlx.zeros_like(msg_tensor))
            edge_indices = tlx.arange(0, num_edges, dtype=tlx.int64)
            # Flatten rather than squeeze: squeeze() would turn a single
            # surviving index into a 0-d tensor and break later indexing.
            self.idx_keep = tlx.reshape(tlx.cast(edge_indices[final_mask_to_keep], tlx.int64), (-1,))

        elif self.scheme_a == 'keep':
            keep_mask = tlx.zeros((num_edges,), dtype=tlx.bool)
            if self.idx_keep is not None:
                indices_to_save = self.idx_keep[self.idx_keep < num_edges]
                keep_mask = tlx.scatter_update(keep_mask, indices_to_save, tlx.ones_like(indices_to_save, dtype=tlx.bool))

            msg_tensor = tlx.where(tlx.expand_dims(keep_mask, 1), msg_tensor, tlx.zeros_like(msg_tensor))

        if isinstance(output, (list, tuple)):
            # tuple() keeps the concatenation valid when output is a list.
            return (msg_tensor,) + tuple(output[1:])
        return msg_tensor

    def forward(self, x, edge_tuple: PairTensor, node_lock: OptTensor = None, verbose: bool = False):
        """Run the layer and return the output together with the kept edges.

        Args:
            x: Node feature tensor.
            edge_tuple: ``(edge_index, edge_weight)``; the weight may be ``None``.
            node_lock: Accepted for interface compatibility; not read here.
            verbose: Accepted for interface compatibility; not read here.

        Returns:
            ``(out, (edge_index, edge_weight))`` where the edge pair is
            restricted to ``self.idx_keep`` when an edge scheme is active.

        Raises:
            NotImplementedError: If ``scheme_w`` is ``'full'``.
        """
        (edge_index, edge_weight) = edge_tuple

        if self.scheme_w in ['pruneall', 'pruneinc']:
            # 'pruneall' restores the weights before re-pruning;
            # 'pruneinc' makes the current mask permanent and prunes further.
            if self.scheme_w == 'pruneall':
                if prune.is_pruned(self.linear):
                    rewind(self.linear, 'weights')
            else:
                if prune.is_pruned(self.linear):
                    prune.remove(self.linear, 'weights')

            # Threshold is scaled by the mean column norm of the input.
            norm_node_in = tlx.sqrt(tlx.reduce_sum(tlx.square(x), axis=0))
            norm_all_in = tlx.reduce_sum(norm_node_in) / x.shape[1]
            if norm_all_in > 1e-8:
                threshold_wi = self.threshold_w * norm_all_in
                ThrInPrune.apply(self.linear, 'weights', threshold_wi, dim=0)

            x = self.linear(x)

        elif self.scheme_w == 'keep':
            x = self.linear(x)
        elif self.scheme_w == 'full':
            raise NotImplementedError()

        self.logger_w.numel_before = _tensor_numel(self.linear.weights)
        self.logger_w.numel_after = int(tlx.convert_to_numpy(tlx.reduce_sum(tlx.cast(self.linear.weights != 0, tlx.float32))))

        self.idx_lock = None

        out = self.propagate(x, edge_index, edge_weight=edge_weight, num_nodes=tlx.get_tensor_shape(x)[0])
        if self.bias is not None:
            out = out + self.bias

        if self.scheme_a in ['pruneall', 'pruneinc', 'keep']:
            num_edges = edge_index.shape[1]
            self.logger_a.numel_before = edge_index.shape[1]

            if self.idx_keep is None:
                self.idx_keep = tlx.arange(start=0, limit=num_edges, dtype=tlx.int64)

            self.idx_keep = self.idx_keep[self.idx_keep < num_edges]
            # Always 1-D, so shape[0] and column slicing work for one survivor too.
            self.idx_keep = tlx.reshape(tlx.cast(self.idx_keep, tlx.int64), (-1,))
            self.logger_a.numel_after = self.idx_keep.shape[0]

            edge_index = edge_index[:, self.idx_keep]
            if edge_weight is not None:
                edge_weight = edge_weight[self.idx_keep]

        return out, (edge_index, edge_weight)

    @classmethod
    def cnt_flops(cls, module, input, output):
        """FLOPs hook: dense transform scaled by the kept-weight ratio, plus propagation."""
        x_in, _ = input
        x_out, (edge_index, edge_weight) = output
        f_in, f_out = x_in.shape[-1], x_out.shape[-1]
        n, m = x_in.shape[0], edge_index.shape[1]
        flops_bias = f_out if module.bias is not None else 0
        module.__flops__ += int((f_in * f_out * module.logger_w.ratio + flops_bias) * n)
        module.__flops__ += f_in * (m - n)
class SAGEConvThr(ConvThr, SAGEConvRaw):
    """GraphSAGE layer with threshold-based weight and edge pruning.

    ``scheme_w`` / ``scheme_a`` (set on the instance from outside) select
    whether weights / edges are pruned (``'pruneall'``, ``'pruneinc'``) or
    kept as previously pruned (``'keep'``).
    """

    def __init__(self, *args, thr_a, thr_w, **kwargs):
        super().__init__(*args, thr_a=thr_a, thr_w=thr_w, **kwargs)

        self.prune_lst = [self.fc_neigh]
        if self.aggr != 'gcn':
            self.prune_lst.append(self.fc_self)
        self.register_forward_hook(self.prune_on_msg)
        # ConvThr.__init__ leaves idx_keep as an empty tensor, which the
        # `is None` checks below would read as "keep no edges". Reset it to
        # None, as the GCN and GCNII threshold layers do.
        self.idx_keep = None

    def prune_on_msg(self, module, inputs, output):
        """Forward hook that masks rows of the hooked output by their L2 norm.

        NOTE(review): the hook is registered on this layer itself, so
        ``output`` is whatever ``forward`` returns and its first element is
        what gets masked here; confirm that this is the intended tensor.

        Returns:
            ``output`` with its first element replaced by the masked tensor
            (a tuple when ``output`` was a list/tuple, else a bare tensor).
        """
        msg_tensor = output[0] if isinstance(output, (list, tuple)) else output
        num_edges = msg_tensor.shape[0]

        if self.scheme_a in ['pruneall', 'pruneinc']:
            norm_feat_msg = tlx.sqrt(tlx.reduce_sum(tlx.square(msg_tensor), axis=1))
            norm_all_msg = tlx.reduce_sum(tlx.abs(norm_feat_msg)) / num_edges
            mask_prune = norm_feat_msg < (self.threshold_a * 0.1 * norm_all_msg)
            final_mask_to_keep = tlx.logical_not(mask_prune)

            msg_tensor = tlx.where(tlx.expand_dims(final_mask_to_keep, 1), msg_tensor, tlx.zeros_like(msg_tensor))
            edge_indices = tlx.arange(0, num_edges, dtype=tlx.int64)
            # Flatten rather than squeeze so one surviving index stays 1-D.
            self.idx_keep = tlx.reshape(tlx.cast(edge_indices[final_mask_to_keep], tlx.int64), (-1,))

        elif self.scheme_a == 'keep':
            keep_mask = tlx.zeros((num_edges,), dtype=tlx.bool)
            if self.idx_keep is not None:
                indices_to_save = self.idx_keep[self.idx_keep < num_edges]
                keep_mask = tlx.scatter_update(keep_mask, indices_to_save, tlx.ones_like(indices_to_save, dtype=tlx.bool))

            msg_tensor = tlx.where(tlx.expand_dims(keep_mask, 1), msg_tensor, tlx.zeros_like(msg_tensor))

        if isinstance(output, (list, tuple)):
            return (msg_tensor,) + tuple(output[1:])
        return msg_tensor

    def forward(self, x, edge_tuple: PairTensor, node_lock: OptTensor = None, verbose: bool = False):
        """Run the layer and return the output together with the kept edges.

        Args:
            x: Node features, either one tensor or a ``(source, target)`` pair.
            edge_tuple: ``(edge_index, edge_weight)``; the weight may be ``None``.
            node_lock: Accepted for interface compatibility; not read here.
            verbose: Accepted for interface compatibility; not read here.

        Returns:
            ``(out, (edge_index, edge_weight))`` where the edge pair is
            restricted to ``self.idx_keep`` when an edge scheme is active.
        """
        (edge_index, edge_weight) = edge_tuple
        x = (x, x) if tlx.is_tensor(x) else x

        # The self-transform only exists (and is only pruned) outside 'gcn' aggregation.
        targets = [self.fc_neigh]
        if self.aggr != 'gcn':
            targets.append(self.fc_self)

        if self.scheme_w in ['pruneall', 'pruneinc']:
            # 'pruneall' restores weights before re-pruning; 'pruneinc'
            # makes the current mask permanent and prunes further.
            for lin in targets:
                if prune.is_pruned(lin):
                    if self.scheme_w == 'pruneall':
                        rewind(lin, 'weights')
                    else:
                        prune.remove(lin, 'weights')

            norm_node_in = tlx.sqrt(tlx.reduce_sum(tlx.square(x[0]), axis=0))
            norm_all_in = tlx.reduce_sum(norm_node_in) / x[0].shape[1]
            if norm_all_in > 1e-8:
                threshold_wi = self.threshold_w * norm_all_in
                for lin in targets:
                    ThrInPrune.apply(lin, 'weights', threshold_wi, dim=0)

        # Weight statistics are summed over every prunable transform.
        total_w_before = 0
        total_w_after = 0
        for lin in targets:
            total_w_before += _tensor_numel(lin.weights)
            total_w_after += int(tlx.convert_to_numpy(tlx.reduce_sum(tlx.cast(lin.weights != 0, tlx.float32))))
        self.logger_w.numel_before = total_w_before
        self.logger_w.numel_after = total_w_after

        out = self.propagate(x[0], edge_index, edge_weight=edge_weight, num_nodes=x[1].shape[0])
        out = self.fc_neigh(out)
        if self.aggr != 'gcn':
            out += self.fc_self(x[1])

        if self.add_bias:
            out += self.bias

        self.idx_lock = None
        if self.scheme_a in ['pruneall', 'pruneinc', 'keep']:
            num_edges = edge_index.shape[1]
            self.logger_a.numel_before = num_edges

            if self.idx_keep is None:
                self.idx_keep = tlx.arange(start=0, limit=num_edges, dtype=tlx.int64)

            self.idx_keep = self.idx_keep[self.idx_keep < num_edges]
            # Always 1-D, so shape[0] and column slicing work for one survivor too.
            self.idx_keep = tlx.reshape(tlx.cast(self.idx_keep, tlx.int64), (-1,))
            self.logger_a.numel_after = self.idx_keep.shape[0]

            edge_index = edge_index[:, self.idx_keep]
            if edge_weight is not None:
                edge_weight = edge_weight[self.idx_keep]

        return out, (edge_index, edge_weight)
def reset_bn_(bn_module):
    """Restore a batch-norm module's parameters and running statistics.

    ``weight`` and ``moving_var`` are reset with ``tlx.initializers.Ones``,
    ``bias`` and ``moving_mean`` with ``tlx.initializers.Zeros``. Attributes
    that are missing or ``None`` are skipped; ``None`` modules are ignored.
    """
    if bn_module is None:
        return

    # (attribute name, initializer class name) in the order they are reset.
    reset_plan = (
        ('weight', 'Ones'),
        ('bias', 'Zeros'),
        ('moving_mean', 'Zeros'),
        ('moving_var', 'Ones'),
    )
    for attr_name, init_name in reset_plan:
        if not hasattr(bn_module, attr_name):
            continue
        current = getattr(bn_module, attr_name)
        if current is None:
            continue
        make_values = getattr(tlx.initializers, init_name)()
        fresh = tlx.convert_to_tensor(make_values(current.shape), dtype=current.dtype)
        tlx.assign(current, fresh)
def state2module(model, name):
    """Return the sub-module that owns the dotted state entry ``name``.

    Every dot-separated component except the last is followed with
    ``getattr`` starting from ``model``; a name without dots therefore
    resolves to ``model`` itself.
    """
    *owner_path, _leaf = name.split('.')
    owner = model
    for attr in owner_path:
        owner = getattr(owner, attr)
    return owner
class SandwitchThr(SandwitchGCNII):
    """``SandwitchGCNII`` whose convolutions may prune edges and weights.

    When ``layer`` contains an underscore the convolutions are called with
    ``node_lock``/``verbose`` and are expected to return
    ``(x, edge_idx)``; otherwise they are called like the plain layers.
    """

    def __init__(self, nlayer, nfeat, nhidden, nclass,
                 thr_a=0.0, thr_w=0.0, dropout: float = 0.0, layer: str = 'gcn2_unifews',
                 **kwargs):
        # Pop, don't just read: these are also forwarded explicitly below,
        # and leaving them in **kwargs makes the super() call raise
        # "got multiple values for keyword argument".
        alpha = kwargs.pop('alpha', 0.1)
        beta = kwargs.pop('beta', 0.5)
        variant = kwargs.pop('variant', False)

        super().__init__(nlayer=nlayer, nfeat=nfeat, nhidden=nhidden, nclass=nclass,
                         alpha=alpha, beta=beta, thr_a=thr_a, thr_w=thr_w,
                         dropout=dropout, variant=variant, layer=layer, **kwargs)

        self.apply_thr = '_' in layer
        self.normalize_adj = kwargs.get('normalize', False)
        self.add_self_loops = kwargs.get('add_self_loops', False)

    def _process_graph(self, x, edge_idx):
        """Optionally normalise the adjacency and/or add self-loops."""
        if not (self.normalize_adj or self.add_self_loops):
            return edge_idx
        edge_index, edge_weight = edge_idx if isinstance(edge_idx, (tuple, list)) else (edge_idx, None)
        num_nodes = x.shape[0]
        if self.normalize_adj:
            edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=self.add_self_loops, dtype=x.dtype)
        elif self.add_self_loops:
            edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, num_nodes=num_nodes)
        return (edge_index, edge_weight)

    def _apply_convs(self, x, edge_idx, convs, node_lock, verbose):
        """Shared body of ``forward`` and ``get_repre``: input projection plus ``convs``."""
        edge_idx = self._process_graph(x, edge_idx)

        x = self.lin_in(x)
        x = x_0 = self.act(x)
        x = self.dropout(x)

        for i, conv in enumerate(convs):
            if self.apply_thr:
                # Pruning layers hand back the reduced edge set for the next layer.
                x, edge_idx = conv(x, x_0, edge_idx, node_lock=node_lock, verbose=verbose)
            else:
                x = conv(x, x_0, edge_idx)
            if self.use_bn:
                x = self.norms[i](x)
            x = self.act(x)
            x = self.dropout(x)
        return x

    def forward(self, x, edge_idx, node_lock=tlx.convert_to_tensor([]), verbose=False):
        """Return class scores for every node."""
        x = self._apply_convs(x, edge_idx, self.convs, node_lock, verbose)
        return self.lin_out(x)

    def remove(self):
        """Make the current pruning masks permanent on every pruned sub-layer."""
        for conv in self.convs:
            if hasattr(conv, 'prune_lst'):
                for m in conv.prune_lst:
                    if prune.is_pruned(m):
                        prune.remove(m, 'weights')

    def get_repre(self, x, edge_idx, layer=None, node_lock=tlx.convert_to_tensor([]), verbose=False):
        """Return the hidden representation after the first ``layer`` convolutions.

        ``layer=None`` uses every convolution. ``layer=0`` returns the
        representation right after the input projection instead of being
        treated like ``None``.
        """
        if layer is None:
            layer = len(self.convs)
        return self._apply_convs(x, edge_idx, self.convs[:layer], node_lock, verbose)
def reset_bn_(bn_module):
    """Restore a batch-norm module's parameters and running statistics.

    ``weight`` and ``moving_var`` receive values from
    ``tlx.initializers.Ones``, ``bias`` and ``moving_mean`` from
    ``tlx.initializers.Zeros``, assigned through each attribute's ``.data``.
    Attributes that are missing or ``None`` are skipped; ``None`` modules
    are ignored.
    """
    if bn_module is None:
        return

    # (attribute name, initializer class name) in the order they are reset.
    reset_plan = (
        ('weight', 'Ones'),
        ('bias', 'Zeros'),
        ('moving_mean', 'Zeros'),
        ('moving_var', 'Ones'),
    )
    for attr_name, init_name in reset_plan:
        if not hasattr(bn_module, attr_name):
            continue
        current = getattr(bn_module, attr_name)
        if current is None:
            continue
        make_values = getattr(tlx.initializers, init_name)()
        current.data = make_values(current.shape)
'normalize': False, + 'rnorm': 0.5, + 'diag': 1.0, + 'depth_inv': True, + }, + 'gsage': { + 'aggr': 'mean', + 'improved': False, + 'normalize': False, + 'root_weight': True, + 'project': False, + 'bias': True, + 'rnorm': 0.5, + 'diag': 1.0, + 'depth_inv': False, + }, +} + +def state2module(model, name): + parts = name.split('.') + module = model + for part in parts[:-1]: + module = getattr(module, part) + return module + +def set_attr(module, key, value): + if hasattr(module, key): + setattr(module, key, value) + +class GNNThr(nn.Module): + def __init__(self, nlayer, nfeat, nhidden, nclass, thr_a=0.0, thr_w=0.0, dropout=0.0, layer='gcn', **kwargs): + super().__init__() + self.apply_thr = '_' in layer + self.dropout = nn.Dropout(dropout) + self.act = nn.ReLU() + self.use_bn = True + self.kwargs = kwargs + + Conv = layer_dict[layer] + layer_base = layer.split('_')[0] + for k, v in kwargs_default[layer_base].items(): + self.kwargs.setdefault(k, v) + + self.is_gat = layer_base.startswith('gat') + if self.is_gat: + self.heads = self.kwargs['heads'] + self.concat = self.kwargs['concat'] + self.feat_dim = self.heads * nhidden if self.concat else nhidden + else: + self.feat_dim = nhidden + + thr_a = [thr_a] * nlayer if not isinstance(thr_a, list) else thr_a + thr_w = [thr_w] * nlayer if not isinstance(thr_w, list) else thr_w + + self.depth_inv = self.kwargs.pop('depth_inv', False) + self.normalize_adj = self.kwargs.pop('normalize', False) + self.add_self_loops = self.kwargs.pop('add_self_loops', False) + self.cached = self.kwargs.pop('cached', False) + + self.mid_kwargs = self.kwargs.copy() + for k in ['improved', 'rnorm', 'diag']: + self.mid_kwargs.pop(k, None) + + self.final_kwargs = self.mid_kwargs.copy() + for k in ['heads', 'concat']: + self.final_kwargs.pop(k, None) + if self.is_gat: + self.final_kwargs['heads'] = 1 + self.final_kwargs['concat'] = False + + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + + self.convs.append(Conv(nfeat, nhidden, 
def get_repre(self, x, edge_idx, layer=None, node_lock=tlx.convert_to_tensor([]), verbose=False):
    """Return the output of convolution number ``layer``.

    Layers before ``layer`` are followed by batch norm (when ``use_bn``),
    activation and dropout; the requested layer's raw output is returned.

    Args:
        x: Node features.
        edge_idx: Graph connectivity, preprocessed by ``_process_graph``.
        layer: Index into ``self.convs``. ``None`` selects the last layer.
            ``0`` is a valid request and returns the first layer's output.
        node_lock: Forwarded to thresholded convolutions. The default
            tensor is created once, when the method is defined.
        verbose: Forwarded to thresholded convolutions.
    """
    # Compare against None explicitly: `layer or ...` would discard layer=0.
    if layer is None:
        layer = len(self.convs) - 1
    edge_idx = self._process_graph(x, edge_idx)

    def run(conv, feats):
        # Thresholded convolutions take the extra arguments and return a
        # pair whose first element is the feature tensor.
        if self.apply_thr:
            feats, _ = conv(feats, edge_idx, node_lock=node_lock, verbose=verbose)
            return feats
        return conv(feats, edge_idx)

    for i, conv in enumerate(self.convs[:layer]):
        x = run(conv, x)
        if self.use_bn:
            x = self.norms[i](x)
        x = self.act(x)
        x = self.dropout(x)
    return run(self.convs[layer], x)
def reset_bn_(bn_module):
    """Reset a batch-norm layer's parameters and running statistics in place.

    ``weight`` and ``moving_var`` are filled with ones, ``bias`` and
    ``moving_mean`` with zeros. Attributes that are missing or ``None`` are
    skipped, and a ``None`` module is a no-op.
    """
    if bn_module is None:
        return

    # (attribute name, tlx initializer class name), in reset order.
    defaults = (
        ('weight', 'Ones'),
        ('bias', 'Zeros'),
        ('moving_mean', 'Zeros'),
        ('moving_var', 'Ones'),
    )
    for attr, init_name in defaults:
        param = getattr(bn_module, attr, None)
        if param is None:
            continue
        initializer = getattr(tlx.initializers, init_name)()
        fresh = tlx.convert_to_tensor(initializer(param.shape), dtype=param.dtype)
        tlx.assign(param, fresh)
def set_attr(module, key, value):
    """Overwrite ``module.<key>`` with ``value`` only if it already exists.

    Objects that do not define ``key`` are left untouched, so the function
    can be applied to every sub-module without creating new attributes.
    """
    if not hasattr(module, key):
        return
    setattr(module, key, value)
def matnorm_inf_dual(m, axis=0):
    """Normalize positive and negative entries separately along ``axis``.

    After the call, the positive entries of every slice along ``axis`` sum
    to 1 and the negative entries sum to -1. Slices without positive (or
    negative) entries are left at zero for that sign.

    Args:
        m: Dense array to normalize. It is not modified.
        axis: Integer axis the sums are taken over (0 = per column).

    Returns:
        Array of the same shape as ``m``.
    """
    # Shape of the per-slice sums with the reduced axis kept as length 1,
    # so the division broadcasts correctly for any axis, not only axis=0.
    keep_shape = list(m.shape)
    keep_shape[axis] = 1

    pos = m.clip(min=0)
    possum = np.asarray(pos.sum(axis=axis)).reshape(keep_shape)
    possum[possum == 0] = 1  # Avoid sum = 0

    neg = m.clip(max=0)
    negsum = -np.asarray(neg.sum(axis=axis)).reshape(keep_shape)
    negsum[negsum == 0] = 1  # Avoid sum = 0

    return pos / possum + neg / negsum
def split_random(seed, n, n_train, n_val):
    """Randomly partition ``range(n)`` into train/val/test index arrays.

    NumPy's global RNG is seeded with ``seed``. The first ``n_train``
    entries of a permutation become the training set, the next ``n_val``
    the validation set, and all remaining indices the test set. Each
    returned array is sorted in ascending order.
    """
    np.random.seed(seed)
    shuffled = np.random.permutation(n)

    head, middle, _ = np.split(shuffled, [n_train, n_train + n_val])
    train_idx, val_idx = np.sort(head), np.sort(middle)

    # setdiff1d already returns its result sorted and de-duplicated.
    used = np.concatenate((train_idx, val_idx))
    test_idx = np.setdiff1d(np.arange(n), used)
    return train_idx, val_idx, test_idx
labels): + assert labels.ndim == 1, 'Only support 1D labels' + idx = np.arange(n) + train_idx, test_idx = train_test_split(idx, train_size=n_train, random_state=seed, stratify=labels) + val_idx, test_idx = train_test_split(test_idx, train_size=n_val, random_state=seed, stratify=labels[test_idx]) + return train_idx, val_idx, test_idx + + +# ==================== +class DataProcess(object): + def __init__(self, name: str, path: str='../data/', rrz: float=0.5, seed: int=0) -> None: + super().__init__() + self.name = name + self.path = path + self.rrz = rrz + self.seed = seed + + self._n = None + self._m = None + self._nfeat = None + self._nclass = None + + self.adjnpz_path = self._get_path('adj.npz') + self.adjtxt_path = self._get_path('adj.txt') + self.degree_path = self._get_path('degree.npz') + self.labels_path = self._get_path('labels.npz') + self.query_path = self._get_path('query.txt') + self.querytrain_path = self._get_path('query_train.txt') + self.feats_path = self._get_path('feats.npy') + self.featsnorm_path = self._get_path(f'feats_normt_{self.rrz:g}.npy') + + self.adj_matrix = None + self.deg = None + self.labels = None # Labels can be 1D array or 2D one hot + self.idx_train = None + self.idx_val = None + self.idx_test = None + self.attr_matrix = None + self.attr_matrix_norm = None + + def _get_path(self, fname): + return os.path.join(self.path, self.name, fname) + + @property + def n(self) -> int: + if self._n: + return self._n + # 1. 
Use cache adj matrix + if self.adj_matrix is not None: + self._n = self.adj_matrix.shape[0] + # 2: Use attribute file + elif os.path.isfile(self._get_path('attribute.txt')): + with open(self._get_path('attribute.txt'), 'r') as attr_f: + nline = attr_f.readline().rstrip() + self._n = int(''.join(filter(str.isdigit, nline))) + # 3: Use label length | WARNING: incorrect for inductive dataset + elif self.labels is not None: + self._n = len(self.labels) + return self._n + + @property + def n_train(self) -> int: + return len(self.idx_train) + + @property + def n_val(self) -> int: + return len(self.idx_val) + + @property + def n_test(self) -> int: + return len(self.idx_test) + + @property + def m(self) -> int: + if self._m: + return self._m + # 1: Use cache adj matrix + if self.adj_matrix is not None: + self._m = self.adj_matrix.nnz + # 2: Use attribute file + elif os.path.isfile(self._get_path('attribute.txt')): + with open(self._get_path('attribute.txt'), 'r') as attr_f: + nline = attr_f.readline().rstrip() + mline = attr_f.readline().rstrip() + self._m = int(''.join(filter(str.isdigit, mline))) + # 3: Count by wc -l + else: + import subprocess + self._m = int(subprocess.check_output(["wc", "-l", self.adjtxt_path]).split()[0]) + return self._m + + @property + def nfeat(self) -> int: + if self._nfeat: + return self._nfeat + if self.attr_matrix is None: + self.input(['attr_matrix']) + self._nfeat = self.attr_matrix.shape[1] + return self._nfeat + + @property + def nclass(self) -> int: + if self._nclass: + return self._nclass + if self.labels is None: + self.input(['labels']) + # 1D array + if self.labels.ndim == 1: + # self._nclass = int(self.labels.max()) + 1 + lb_nonnan = self.labels[~np.isnan(self.labels)] + self._nclass = int(lb_nonnan.max()) + 1 + # 2D one hot + else: + self._nclass = self.labels.shape[1] + return self._nclass + + def __str__(self) -> str: + s = f"n={self.n}, m={self.m}, F={self.nfeat}, C={self.nclass} | " + s += f"feat: {self.attr_matrix.shape}, 
label: {self.labels.shape} | " + s += f"{self.n_train}/{self.n_val}/{self.n_test}=" + s += f"{self.n_train/self.n:0.2f}/{self.n_val/self.n:0.2f}/{self.n_test/self.n:0.2f}" + return s + + def calculate(self, lst: List[str]) -> None: + for key in lst: + if key == 'deg': + assert self.adj_matrix is not None + assert np.sum(self.adj_matrix.diagonal()) == 0 + self.deg = self.adj_matrix.sum(1).A1 + elif key in ['idx_train', 'idx_val', 'idx_test']: + # n_train = NTRAIN_PER_CLASS * self.nclass + # n_val = NVAL_PER_CLASS * self.nclass + n_train = int(0.50 * self.n) + n_val = int(0.25 * self.n) + if 'paper' == self.name: + np.random.seed(self.seed) + self.input(['idx_train', 'idx_val', 'idx_test']) + rnd = np.concatenate((self.idx_train, self.idx_val, self.idx_test)) + rnd = np.random.permutation(rnd) + self.idx_train = np.sort(rnd[:n_train]) + self.idx_val = np.sort(rnd[n_train:n_train + n_val]) + self.idx_test = np.sort(rnd[n_train + n_val:]) + elif 'papers' in self.name: + # Common + n_train = NTRAIN_PER_CLASS * self.nclass + n_val = int(1.5 * NTRAIN_PER_CLASS * self.nclass) + self.input(['idx_train', 'idx_val', 'idx_test']) + rnd = np.concatenate((self.idx_train, self.idx_val, self.idx_test)) + idxrange = n_train + n_val + 2000 + rnd = np.random.permutation(rnd)[:idxrange] + idx_train, idx_val, idx_test = split_label(self.seed, len(rnd), NTRAIN_PER_CLASS, n_val, self.labels[rnd]) + # NIGCN + # self.input(['idx_train', 'idx_val', 'idx_test']) + # rnd = np.concatenate((self.idx_train, self.idx_val, self.idx_test)) + # n_train, n_val = 1000, 25000 + # idxrange = n_train + n_val + 25000 + # idx_train, idx_val, idx_test = split_random(self.seed, idxrange, n_train, n_val) + + self.idx_train = np.sort(rnd[idx_train]) + self.idx_val = np.sort(rnd[idx_val]) + self.idx_test = np.sort(rnd[idx_test]) + elif 'mag' in self.name: + # self.idx_train, self.idx_val, self.idx_test = split_label(self.seed, self.n, NTRAIN_PER_CLASS * 5, n_val, self.labels) + self.idx_train, self.idx_val, 
self.idx_test = split_stratify(self.seed, self.n, n_train * 5, n_val, self.labels) + elif 'ppi' in self.name or 'protein' in self.name or 'yelp' in self.name or 'amazon' in self.name: + self.idx_train, self.idx_val, self.idx_test = split_random(self.seed, self.n, n_train, n_val) + else: + # self.idx_train, self.idx_val, self.idx_test = split_random(self.seed, self.n, n_train, n_val) + # self.idx_train, self.idx_val, self.idx_test = split_label(self.seed, self.n, NTRAIN_PER_CLASS, n_val, self.labels) + self.idx_train, self.idx_val, self.idx_test = split_stratify(self.seed, self.n, n_train, n_val, self.labels) + elif key == 'labels_oh': + if self.labels.ndim == 2: + self.labels_oh = self.labels + else: + self.labels_oh = np.zeros((self.n, self.nclass), dtype=np.int8) + idx = ~ np.isnan(self.labels) + row = np.arange(self.labels.size) + self.labels_oh[row[idx], self.labels[idx]] = 1 + elif key == 'role': + self.role = {} + self.role['tr'] = self.idx_train.tolist() + self.role['va'] = self.idx_val.tolist() + self.role['te'] = self.idx_test.tolist() + elif key == 'mask': + self.mask = {} + self.mask['tr'] = np.zeros(self.n, dtype=bool) + self.mask['va'] = np.zeros(self.n, dtype=bool) + self.mask['te'] = np.zeros(self.n, dtype=bool) + self.mask['tr'][self.idx_train] = True + self.mask['va'][self.idx_val] = True + self.mask['te'][self.idx_test] = True + elif key == 'edge_idx': + self.edge_idx = self.adj_matrix.tocoo().copy() + self.edge_idx = np.vstack([self.edge_idx.row, self.edge_idx.col]) + elif key == 'attr_matrix_norm': + assert self.attr_matrix is not None + assert self.deg is not None + deg_pow = np.power(np.maximum(self.deg, 1e-12), 1 - self.rrz) + deg_pow = diag_sp(deg_pow).astype(np.float32) + self.attr_matrix_norm = deg_pow @ matstd(self.attr_matrix) + self.attr_matrix_norm = matnorm_inf_dual(self.attr_matrix_norm).astype(np.float32) + self.attr_matrix_norm = self.attr_matrix_norm.transpose().astype(np.float32, order='C') + else: + print("Key not exist: 
{}".format(key)) + + def input(self, lst: List[str]) -> None: + for key in lst: + if key == 'adjnpz': + self.adj_matrix = sp.load_npz(self.adjnpz_path) + # assert self.adj_matrix.diagonal().sum() == 0, "adj_matrix error" + elif key == 'adjtxt': + with open(self.adjtxt_path, 'r') as attr_f: + nline = attr_f.readline().rstrip() + self._n = int(''.join(filter(str.isdigit, nline))) + adjtxt = np.loadtxt(self.adjtxt_path) + self._m = adjtxt.shape[0] + ones = np.ones((self.m), dtype=np.int8) + self.adj_matrix = sp.coo_matrix( + (ones, (adjtxt[:, 0], adjtxt[:, 1])), + shape=(self.n, self.n)) + self.adj_matrix = self.adj_matrix.tocsr() + # assert self.adj_matrix.diagonal().sum() == 0, "adj_matrix error" + elif key == 'deg': + #self.deg = dict(np.load(self.degree_path))['arr_0'] + + data = np.load(self.degree_path, allow_pickle=True) + self.deg = data[next(iter(data.keys()))] + elif key == 'labels': + self.labels = dict(np.load(self.labels_path, allow_pickle=True))['labels'] + # assert (self.labels.dim()==2 and self.labels.shape[1]==1) or self.labels.dim()==1, "label shape error" + elif key == 'idx_train': + self.idx_train = dict(np.load(self.labels_path, allow_pickle=True))['idx_train'] + elif key == 'idx_val': + self.idx_val = dict(np.load(self.labels_path, allow_pickle=True))['idx_val'] + elif key == 'idx_test': + self.idx_test = dict(np.load(self.labels_path, allow_pickle=True))['idx_test'] + elif key == 'attr_matrix': + self.attr_matrix = np.load(self.feats_path) + elif key == 'attr_matrix_norm': + self.attr_matrix_norm = np.load(self.featsnorm_path) + else: + print("Key not exist: {}".format(key)) + + def output(self, lst: List[str]) -> None: + for key in lst: + if key == 'adjnpz': + self.adj_matrix = self.adj_matrix.tocsr() + assert sp.isspmatrix_csr(self.adj_matrix) + sp.save_npz(self.adjnpz_path, self.adj_matrix) + elif key == 'adjtxt': + self.adj_matrix = self.adj_matrix.tocoo() + with open(self.adjtxt_path, 'w') as f: + f.write("# {:d}\n".format(self.n)) + for i 
def output_split(self, attr_matrix: Optional[np.ndarray], spt: int=10, name: str='feats') -> None:
    """Save a large matrix as ``spt`` consecutive row blocks.

    The matrix is cut along its first axis into ``spt`` parts of
    ``n // spt`` rows each; the last part also receives the leftover rows.
    Part ``i`` is written to ``<name>_<i>.npy`` in the dataset directory.

    Args:
        attr_matrix: Matrix to split; ``None`` means ``self.attr_matrix``.
        spt: Number of parts.
        name: File name prefix of the parts.
    """
    from tqdm import trange

    matrix = self.attr_matrix if attr_matrix is None else attr_matrix
    n_rows = matrix.shape[0]
    rows_per_part = n_rows // spt
    last_part = spt - 1

    for part in trange(spt):
        start = part * rows_per_part
        # The final part absorbs the remainder of the integer division.
        stop = n_rows if part == last_part else start + rows_per_part
        np.save(self._get_path('{}_{}.npy'.format(name, part)), matrix[start:stop, :])
if __name__ == '__main__':
    # Manual smoke run: load the text edge list, features and labels of one
    # dataset, derive degrees and the split, save the degree and query
    # files, then print the summary.
    processor = DataProcess('pubmed', seed=0)
    processor.input(['adjtxt', 'attr_matrix', 'labels'])
    processor.calculate(['deg', 'idx_train'])
    processor.output(['deg', 'query'])
    print(processor)
def feature_extraction(S, X, Label):
    """Collect the statistics of an attributed graph used by the generator.

    Args:
        S: Adjacency matrix; only entries with row index < column index
            are counted, each as one undirected edge.
        X: Node attribute matrix.
        Label: Integer class label of every node.

    Returns:
        ``(M, D, class_size, H, theta)``: the two outputs of
        ``calc_class_features``, the fraction of nodes per class (list),
        the output of ``calc_attr_cor``, and the node degrees sorted in
        descending order (list).
    """
    k = max(Label) + 1
    M, D = calc_class_features(S, k, Label)
    H = calc_attr_cor(X, Label)

    # Group node ids by class, then turn the group sizes into fractions.
    members = [[] for _ in range(k)]
    for node, label in enumerate(Label):
        members[label].append(node)
    counts = np.array([len(nodes) for nodes in members])
    class_size = counts / sum(counts)

    # Node degree: every upper-triangular entry adds one to both endpoints.
    theta = np.zeros(len(Label))
    rows, cols = S.nonzero()
    upper = rows < cols
    np.add.at(theta, rows[upper], 1)
    np.add.at(theta, cols[upper], 1)

    return M, D, list(class_size), H, sorted(theta, reverse=True)
def calc_attr_cor(X, Label):
    """Return the mean attribute vector of every class.

    Args:
        X: Dense ``(n, d)`` attribute matrix.
        Label: Length-``n`` sequence of integer class labels in ``0..k-1``.

    Returns:
        ``(d, k)`` float array whose column ``c`` is the mean of the rows of
        ``X`` labelled ``c``. A class without members yields a NaN column
        (0 / 0).
    """
    labels = np.asarray(Label)
    feats = np.asarray(X)
    k = int(labels.max()) + 1

    # One-hot membership matrix (n, k): column c selects the nodes of class c.
    member = (labels[:, None] == np.arange(k)).astype(float)
    counts = member.sum(axis=0)

    # (d, n) @ (n, k) gives the per-class attribute sums in one product.
    return (feats.T @ member) / counts
def adjust(n,k,U,C,M): # fitting latent factors to given statistics by minimizing loss
    """Sharpen the rows of ``U`` so each class matches its row of ``M``.

    For every class ``l`` a temperature ``Th`` is grid-searched over
    ``0.01 .. 0.99`` (step 0.01). The one minimizing the norm between
    ``M[l]`` and the class average of the element-wise product of the two
    factor vectors is applied to all rows of that class.

    ``U`` is modified in place. ``U_prime`` starts as a deep copy of ``U``
    and receives either the same sharpened row (``M[l][l] >= 1/k``) or its
    ``reverse`` (otherwise).

    Args:
        n: Not used in this function.
        k: Number of classes.
        U: Indexable by node; each row is a length-``k`` vector.
        C: Class index of every node.
        M: ``k x k`` target matrix.

    Returns:
        ``(U, U_prime)``.
    """
    U_prime = copy.deepcopy(U)
    # partition[l] lists the node indices whose class is l
    partition = []
    for l in range(k):
        partition.append([])
    for i in range(len(C)):
        partition[C[i]].append(i)

    # Freezing function: raise to the power 1/Th and renormalize to sum 1
    def freez_func(q,Th):
        return q**(1/Th) / np.sum(q**(1/Th))

    # Reverse function: complement the vector, then rescale every entry
    # except position l by U_tmp[l] / (sum of the complement - U_tmp[l])
    def reverse(U_tmp,l):
        U_ = 1 - U_tmp
        sum_U_ = sum(U_) - U_tmp[l]
        for i in range(k):
            if i != l:
                U_[i] = U_[i] * U_tmp[l] / sum_U_
        return U_
    flag=0  # never read afterwards
    for l in range(k):
        loss_min = float('inf')
        if M[l][l] >= 1/k:
            # Diagonal entry at least 1/k: both factors are the sharpened row
            for Th in np.arange(0.01,1,0.01):
                sum_estimated = np.zeros(k)
                for i in partition[l]:
                    sum_estimated += freez_func(U[i],Th) * freez_func(U[i],Th)
                loss_tmp = la.norm(M[l]-sum_estimated/len(partition[l]))
                if loss_tmp < loss_min:
                    loss_min = loss_tmp
                    Th_min = Th
            # Apply the best temperature to every node of the class
            for i in partition[l]:
                U[i] = freez_func(U[i],Th_min)
                U_prime[i] = U[i]
        else:
            # Diagonal entry below 1/k: second factor is the reversed row
            for Th in np.arange(0.01,1,0.01):
                sum_estimated = np.zeros(k)
                for i in partition[l]:
                    sum_estimated += freez_func(U[i],Th) * reverse(freez_func(U[i],Th),l)
                loss_tmp = la.norm(M[l]-sum_estimated/len(partition[l]))
                if loss_tmp < loss_min:
                    loss_min = loss_tmp
                    Th_min = Th
            for i in partition[l]:
                U[i] = freez_func(U[i],Th_min)
                U_prime[i] = reverse(U[i],l)
#        print(Th_min)
    return U, U_prime
def edge_construction(n, U, k, U_prime, step, theta, r):
    """Generate undirected edges from latent factors via lookup tables.

    For each node ``i`` a target class is drawn with weights ``U[i]`` and a
    partner ``j`` is looked up in ``U_prime[class]`` at a uniformly random
    slot. The edge is added while both endpoints are below their target
    degree in ``theta``.

    Args:
        n: Number of nodes.
        U: ``n x k`` array; row ``i`` weights the target class of node ``i``.
        k: Number of classes.
        U_prime: Per-class lookup tables as produced by ``ITS_U_prime``.
        step: Width of one lookup slot (``random() / step`` picks the slot).
        theta: Target degree of every node.
        r: Maximum number of sampling rounds per node.

    Returns:
        tuple: ``(S, count_list)`` with ``S`` a symmetric ``dok_matrix`` and
        ``count_list[i]`` the number of rounds spent on node ``i``
        (monitoring only, not part of the algorithm).
    """
    S = sparse.dok_matrix((n, n))
    degree_list = np.zeros(n)
    count_list = []
    classes = list(range(k))
    # Neighbours found so far; a node handled earlier may already be linked
    # to node i and must not be counted a second time.
    neighbours = [set() for _ in range(n)]

    for i in range(n):  # for each node
        count = 0
        # Candidates that must not be picked (again): i itself, its current
        # neighbours and every node already tried for i.
        ng_list = {i} | neighbours[i]
        while count < r and degree_list[i] < theta[i]:
            to_classes = random.choices(
                classes, k=int(theta[i] - degree_list[i]), weights=U[i]
            )  # choose classes based on U[i]
            for to_class in to_classes:
                table = U_prime[to_class]
                for _ in range(50):
                    # Clamp: float rounding can make the table one slot short.
                    slot = min(int(random.random() / step), len(table) - 1)
                    j = table[slot]
                    if j not in ng_list:
                        ng_list.add(j)
                        break
                else:
                    # No fresh candidate in 50 draws. Previously the last
                    # (already used) j fell through and its edge was counted
                    # again, inflating both degrees.
                    continue
                if degree_list[j] < theta[j] and i != j:
                    S[i, j] = 1
                    S[j, i] = 1  # undirected edge
                    degree_list[i] += 1
                    degree_list[j] += 1
                    neighbours[i].add(j)
                    neighbours[j].add(i)
            count += 1
        count_list.append(count)
    return S, count_list


def ITS_U_prime(n, k, U_prime, step):
    """Build per-class lookup tables for inverse transform sampling.

    Args:
        n: Number of nodes.
        k: Number of classes.
        U_prime: ``n x k`` array; column ``c`` weights nodes for class ``c``.
        step: Spacing of the probability grid ``0, step, 2*step, ... < 1``.

    Returns:
        list: ``k`` lists; entry ``t`` of list ``c`` is the node assigned to
        grid point ``t * step`` for class ``c``.
    """
    class_list = []
    UT = U_prime.transpose()
    grid = np.arange(0, 1, step)
    for i in range(k):  # cumulative distribution of each class
        cdf = np.cumsum(UT[i] / sum(UT[i]))

        class_tmp = []
        node_counter = 0
        for l in grid:
            if node_counter >= n - 1:
                class_tmp.append(n - 1)
            elif cdf[node_counter] > l:
                class_tmp.append(node_counter)
            else:
                # The walk advances at most one node per grid point.
                node_counter += 1
                class_tmp.append(node_counter)
        class_list.append(class_tmp)
    return class_list


def adjust_att(n, k, d, U, C, H):
    """Sharpen the attribute-class matrix so that ``P @ V`` best matches ``H``.

    ``P`` is the ``k x k`` matrix of mean latent factors per class. For each
    attribute the temperature from ``0.1, 0.15, ... , 1.05`` with the smallest
    2-norm loss is applied to that attribute's row.

    Args:
        n: Number of nodes (unused, kept for interface compatibility).
        k: Number of classes.
        d: Number of attributes.
        U: ``n x k`` array of latent factors.
        C: Class id of each node.
        H: ``d x k`` float array of attribute-class proportions.

    Returns:
        Copy of ``H`` with every row replaced by its sharpened version.
    """
    V = copy.deepcopy(H)
    members = [[] for _ in range(k)]
    for node, cls in enumerate(C):
        members[cls].append(node)

    def freez_func(q, Th):
        """Freezing function: ``q**(1/Th)`` renormalised to sum 1."""
        powered = q ** (1 / Th)
        return powered / np.sum(powered)

    P = np.zeros((k, k))
    for l in range(k):
        if members[l]:  # an empty class keeps a zero row instead of 0/0
            P[l] = U[members[l]].mean(axis=0)

    thresholds = np.arange(0.1, 1.1, 0.05)
    for delta in range(d):
        loss = [
            np.linalg.norm(H[delta] - P @ freez_func(V[delta], Th).T)
            for Th in thresholds
        ]
        # Apply the temperature that was actually evaluated. The former
        # 0.1 * (argmin + 1) assumed a 0.1 grid step while the grid uses 0.05.
        V[delta] = freez_func(V[delta], thresholds[int(np.argmin(loss))])
    return V
def attribute_generation(n, d, k, U, V, C, omega, att_type, att_scale):
    """Generate node attributes from latent factors.

    Args:
        n: Number of nodes.
        d: Number of attributes.
        k: Number of classes.
        U: ``n x k`` array of latent factors.
        V: ``d x k`` attribute-class array.
        C: Class id of each node.
        omega: Base deviation of the per-class Gaussian noise ("normal" only).
        att_type: ``"normal"`` or ``"Bernoulli"``.
        att_scale: Per-attribute scale applied in the Bernoulli case.

    Returns:
        numpy.ndarray: ``n x d`` attribute matrix with NaNs replaced by
        ``np.nan_to_num``.

    Raises:
        NotImplementedError: If ``att_type`` is not supported. (The former
            ``raise NotImplemented`` produced a ``TypeError``.)
    """
    X = U @ V.T

    if att_type == "normal":
        classes = np.asarray(C)
        for i in range(d):  # each attribute dimension
            # deviation of each class, within +-1% of omega
            clus_dev = np.random.uniform(omega * 0.99, omega * 1.01, k)
            # One draw per node with its class's deviation; avoids storing a
            # length-1 array into a scalar slot.
            X[:, i] += np.random.normal(0.0, clus_dev[classes])
        # normalisation to [0, 1] per attribute
        for i in range(d):
            X[:, i] -= min(X[:, i])
            X[:, i] /= max(X[:, i])
    elif att_type == "Bernoulli":
        for i in range(d):
            X[:, i] = X[:, i] * att_scale[i]
        # Subtracting a uniform draw and thresholding at 0 samples each entry
        # as 1 with probability X[p, i].
        X = X - np.random.rand(n, d)
        X[X >= 0] = 1
        X[X < 0] = 0
    else:
        raise NotImplementedError(
            "unsupported att_type %r; expected 'normal' or 'Bernoulli'" % (att_type,)
        )
    return np.nan_to_num(X)
def adjust_woAP(n, k, U, C, density):  # skip adjusting phases
    """Build ``U_prime`` without the adjusting phase (ablation study).

    Rows of nodes in heterophilous classes (``density[class] < 1/k``) are
    replaced by their reversed form; all other rows are copied from ``U``.

    Args:
        n: Number of nodes (unused, kept for interface compatibility).
        k: Number of classes.
        U: ``n x k`` array of latent factors; not modified.
        C: Class id of each node.
        density: Per-class diagonal of the class-preference matrix.

    Returns:
        Copy of ``U`` with the rows of heterophilous nodes reversed.
    """
    U_prime = copy.deepcopy(U)

    def reverse(row, l, k):
        flipped = 1 - row
        denom = sum(flipped) - row[l]
        for h in range(k):
            if h != l:
                flipped[h] = flipped[h] * row[l] / denom
        return flipped

    for node, cls in enumerate(C):
        if density[cls] < 1 / k:  # heterophilous class
            # Write the node's own row. The original indexed U_prime with a
            # stale loop variable and kept overwriting the last node's row.
            U_prime[node] = reverse(U[node], cls, k)
    return U_prime
def edge_construction_wo_ITS(n, U, k, U_primeT, theta, r):
    """Generate edges without inverse transform sampling (ablation study).

    Partners of node ``i`` are drawn directly with weights
    ``(U @ U_primeT)[i]``.

    Args:
        n: Number of nodes.
        U: ``n x k`` array of latent factors.
        k: Number of classes (unused, kept for interface compatibility).
        U_primeT: ``k x n`` array, the transpose of ``U_prime``.
        theta: Target degree of every node.
        r: Maximum number of sampling rounds per node.

    Returns:
        tuple: ``(S, count_list)`` with ``S`` a symmetric ``dok_matrix`` and
        ``count_list[i]`` the number of rounds spent on node ``i``.
    """
    S = sparse.dok_matrix((n, n))
    degree_list = np.zeros(n)
    count_list = []
    nodes = list(range(n))

    reconst = U @ U_primeT  # simply construct the probability matrix
    for i in range(n):
        count = 0
        while count < r and degree_list[i] < theta[i]:
            to_nodes = random.choices(
                nodes, k=int(theta[i] - degree_list[i]), weights=reconst[i]
            )
            for j in to_nodes:
                # Skip self-loops, saturated nodes and edges that already
                # exist. Re-adding an existing edge used to inflate both
                # degree counters.
                if i == j or degree_list[j] >= theta[j] or S[i, j] != 0:
                    continue
                S[i, j] = 1
                S[j, i] = 1
                degree_list[i] += 1
                degree_list[j] += 1
            count += 1
        count_list.append(count)  # monitoring only, not part of the algorithm
    return S, count_list


def gencat(M, D, H, class_size=None, n=3000, m=5000, p=3., max_deg=None,
           theta=None, phi_c=1, omega=0.2, r=20, step=50, att_type="normal",
           woAP=False, woITS=False):
    """Generate an attributed graph with class labels (GenCAT main entry).

    Users set either ``n`` and ``m`` or ``theta``, and either ``class_size``
    or ``phi_c``.

    Args:
        M: Class-preference mean, ``k x k``.
        D: Class-preference deviation, ``k x k``.
        H: Attribute-class correlation, ``d x k``.
        class_size: Class-size distribution of length ``k``.
        n: Number of nodes.
        m: Expected number of edges.
        p: Parameter of the power-law degree distribution.
        max_deg: Upper bound of node degree; defaults to ``int(n / 10)``.
        theta: Node degree sequence of length ``n``; overrides ``n`` and ``m``.
        phi_c: Parameter of the class-size distribution.
        omega: Deviation of the normal noise used for attributes.
        r: Number of iterations for generating edges.
        step: Number of lookup steps per node for inverse transform sampling.
        att_type: ``"Bernoulli"`` or ``"normal"``.
        woAP: Skip the adjusting phase (ablation).
        woITS: Skip inverse transform sampling (ablation).

    Returns:
        tuple: ``(S_gen, X, C)``, the adjacency matrix, attribute matrix and
        class list.
    """
    k = M.shape[0]
    d = H.shape[0]

    # `is None` instead of `== None`: array arguments would otherwise compare
    # element-wise and make the condition ambiguous.
    if theta is None:
        # node degree generation
        if max_deg is None:
            max_deg = int(n / 10)
        theta = node_deg(n, m, max_deg, p)
    else:
        n = len(theta)
        m = int(sum(theta) / 2)

    # class generation
    if class_size is None:
        class_size = class_size_gen(k, phi_c)
    U, C, density = latent_factor_gen(n, k, M, D, class_size)

    # adjusting phase
    if not woAP:
        U, U_prime = adjust(n, k, U, C, M)
    else:
        print("woAP")
        # adjust_woAP expects the per-class density vector, not the matrix M.
        U_prime = adjust_woAP(n, k, U, C, density)

    if not woITS:
        # Inverse Transform Sampling
        its_step = 1 / (n * step)
        U_prime_CDF = ITS_U_prime(n, k, U_prime, its_step)

        # Edge generation
        S_gen, count_list = edge_construction(n, U, k, U_prime_CDF, its_step, theta, r)
    else:
        print("woITS")
        S_gen, count_list = edge_construction_wo_ITS(n, U, k, U_prime.T, theta, r)

    # Attribute: keep the original row sums as scales, then normalise rows.
    att_scale = np.sum(H, axis=1)
    H_ = copy.deepcopy(H)
    for i in range(d):
        H_[i] = H_[i] / sum(H_[i])

    # adjust attribute-class proportion
    V = adjust_att(n, k, d, U, C, H_)

    # Attribute generation
    X = attribute_generation(n, d, k, U, V, C, omega, att_type, att_scale)

    return S_gen, X, C
def gencat_simple(n, m, density, H, class_size=None, max_deg=None, p=3.,
                  theta=None, phi_c=1, omega=0.2, r=50, step=100,
                  att_type="normal"):
    """Generate a graph from per-class densities only (simplified input).

    The class-preference mean ``M`` gets ``density[l]`` on the diagonal and
    ``(1 - density[l]) / (k - 1)`` everywhere else in row ``l``.

    Args:
        n: Number of nodes.
        m: Expected number of edges.
        density: Diagonal of the class-preference mean, length ``k``.
        H: Attribute-class correlation, ``d x k``.
        class_size: Optional Dirichlet concentration for sampling latent
            factors; when omitted ``class_generation`` is used.
        max_deg: Upper bound of node degree; defaults to ``int(n / 10)``.
        p: Parameter of the power-law degree distribution.
        theta: Node degree sequence; overrides ``n`` and ``m`` when given.
        phi_c: Parameter forwarded to ``class_generation``.
        omega: Deviation of the normal noise used for attributes.
        r: Number of iterations for generating edges.
        step: Number of lookup steps per node for inverse transform sampling.
        att_type: ``"Bernoulli"`` or ``"normal"``.

    Returns:
        tuple: ``(S_gen, X, C)``.
    """
    d = H.shape[0]
    k = len(density)

    # node degree generation (`is None`: theta may be an array)
    if theta is None:
        if max_deg is None:
            max_deg = int(n / 10)
        theta = node_deg(n, m, max_deg, p)
    else:
        n = len(theta)
        m = int(sum(theta) / 2)

    # generate class preference mean from the given diagonal elements
    M = np.zeros((k, k))
    for l1 in range(k):
        for l2 in range(k):
            if l1 == l2:
                M[l1][l2] = density[l1]
            else:
                M[l1][l2] = (1 - density[l1]) / (k - 1)

    # class generation
    if class_size is None:
        U, C = class_generation(n, k, phi_c)
    else:
        # Previously U and C stayed undefined on this path (NameError below).
        # Mirror class_generation with the caller's concentration vector.
        U = np.random.dirichlet(class_size, n)
        C = [np.argmax(row) for row in U]

    # adjusting phase
    U, U_prime = adjust(n, k, U, C, M)

    # Inverse Transform Sampling
    its_step = 1 / (n * step)
    U_prime_CDF = ITS_U_prime(n, k, U_prime, its_step)

    # Edge generation
    S_gen, count_list = edge_construction(n, U, k, U_prime_CDF, its_step, theta, r)

    # Attribute: keep the original row sums as scales, then normalise rows.
    att_scale = np.sum(H, axis=1)
    H_ = copy.deepcopy(H)
    for i in range(d):
        H_[i] = H_[i] / sum(H_[i])
    V = adjust_att(n, k, d, U, C, H_)

    # Attribute generation
    X = attribute_generation(n, d, k, U, V, C, omega, att_type, att_scale)

    return S_gen, X, C
def class_generation(n, k, phi_c):
    """Sample latent factors from a Dirichlet and derive class labels.

    Args:
        n: Number of nodes.
        k: Number of classes.
        phi_c: Parameter forwarded to ``class_size_gen``.

    Returns:
        tuple: ``(U, C)`` with ``U`` the ``n x k`` Dirichlet samples and ``C``
        the arg-max class of every row.

    Note:
        Prints the class-size histogram and terminates the process with exit
        status 1 when some class has no member.
    """
    # Local import so the exit path does not depend on a module-level `sys`.
    import sys

    class_size = class_size_gen(k, phi_c)

    U = np.random.dirichlet(class_size, n)
    C = [np.argmax(row) for row in U]  # class assignment list (finally, R^{n})

    counter = [C.count(i) for i in range(k)]
    print("class size disribution : ", end="")
    print(counter)
    if 0 in counter:
        print('Error! There is a class which has no member.')
        sys.exit(1)

    return U, C


##########################################################################################
############################## for reproduction ##########################################
##########################################################################################

def class_reproduction(k, S, Label):
    """Extract class statistics from an existing labelled graph.

    Args:
        k: Number of classes.
        S: Adjacency matrix of the given graph.
        Label: Class id of every node.

    Returns:
        tuple: ``(M, D, class_size)`` where ``M`` and ``D`` come from
        ``calc_class_features`` and ``class_size`` is the fraction of nodes
        per class.
    """
    # extract class preference matrix from the given graph
    M, D = calc_class_features(S, k, Label)

    counts = [0] * k
    for lab in Label:
        counts[lab] += 1
    class_size = np.array(counts) / sum(counts)
    return M, D, class_size
def gencat_reproduction(S, Label, X=None, H=None, n=0, m=0, max_deg=0,
                        omega=0.2, r=50, step=100, att_type="Bernoulli"):
    """Generate a graph that reproduces the statistics of a given graph.

    Args:
        S: Adjacency matrix of the given graph.
        Label: Class id of every node.
        X: Optional attribute matrix; when given, ``H`` is computed from it.
        H: Optional attribute-class correlation, used when ``X`` is omitted.
        n: Number of nodes to generate; ``0`` reuses the degree sequence of ``S``.
        m: Expected number of edges (only used when ``n != 0``).
        max_deg: Upper bound of node degree (only used when ``n != 0``).
        omega: Deviation of the normal noise used for attributes.
        r: Number of iterations for generating edges.
        step: Number of lookup steps per node for inverse transform sampling.
        att_type: ``"Bernoulli"`` or ``"normal"``.

    Returns:
        tuple: ``(S_gen, X, C)``; ``X`` is ``[]`` when neither ``X`` nor ``H``
        was supplied.
    """
    # `is None` checks: `X == None` on an array is element-wise and cannot be
    # used as a condition.
    if X is not None:
        from func import calc_attr_cor
        H = calc_attr_cor(X, Label)
        d = H.shape[0]
    elif H is not None:
        d = H.shape[0]
    else:
        d = 0

    # node degree generation
    if n == 0:
        # Degree of every node from the upper-triangular non-zeros of S.
        theta = np.zeros(len(Label))
        rows, cols = S.nonzero()
        rows, cols = np.asarray(rows), np.asarray(cols)
        upper = rows < cols
        np.add.at(theta, rows[upper], 1)
        np.add.at(theta, cols[upper], 1)
    else:
        # NOTE(review): other call sites pass a fourth argument `p` to
        # node_deg; confirm it has a default.
        theta = node_deg(n, m, max_deg)
    n = len(theta)
    m = count_node_degree(S)
    k = len(set(Label))
    its_step = 1 / (n * step)

    # class feature extraction
    M, D, class_size = class_reproduction(k, S, Label)

    # latent factor generation
    U, C, density = latent_factor_gen(n, k, M, D, class_size)

    # adjusting phase
    U, U_prime = adjust(n, k, U, C, M)

    # Inverse Transform Sampling
    U_prime_CDF = ITS_U_prime(n, k, U_prime, its_step)

    # Edge generation
    S_gen, count_list = edge_construction(n, U, k, U_prime_CDF, its_step, theta, r)
    print("number of generated edges : " + str(count_node_degree(S_gen)))

    # Attribute
    if d != 0:
        att_scale = np.sum(H, axis=1)
        H_ = copy.deepcopy(H)
        for i in range(d):
            H_[i] = H_[i] / sum(H_[i])

        V = adjust_att(n, k, d, U, C, H_)

        # Attribute generation
        X = attribute_generation(n, d, k, U, V, C, omega, att_type, att_scale)
    else:
        X = []

    return S_gen, X, C


##########################################################################################
############################## only attribute ############################################
##########################################################################################

def gencat_only_att(n, M, D, H, phi_c=1, omega=0.2, r=50, step=100,
                    att_type="normal", woAP=False, woITS=False):
    """Generate node attributes and classes without generating edges.

    Args:
        n: Number of nodes.
        M: Class-preference mean, ``k x k``.
        D: Class-preference deviation, ``k x k``.
        H: Attribute-class correlation, ``d x k``.
        phi_c: Parameter of the class-size distribution.
        omega: Deviation of the normal noise used for attributes.
        r: Unused, kept for interface compatibility.
        step: Unused, kept for interface compatibility.
        att_type: ``"Bernoulli"`` or ``"normal"``.
        woAP: Skip the adjusting phase (ablation).
        woITS: Unused, kept for interface compatibility.

    Returns:
        tuple: ``(S_gen, X, X_not, C)`` where ``S_gen`` is an empty list and
        ``X_not`` is ``U @ V.T`` min-max scaled per column, without the
        user-specified distribution applied.
    """
    k = M.shape[0]
    d = H.shape[0]
    # class generation
    class_size = class_size_gen(k, phi_c)
    U, C, density = latent_factor_gen(n, k, M, D, class_size)

    # adjusting phase
    if not woAP:
        U, U_prime = adjust(n, k, U, C, M)
    else:
        print("woAP")
        # adjust_woAP expects the per-class density vector, not the matrix M.
        U_prime = adjust_woAP(n, k, U, C, density)

    S_gen = []

    # Attribute
    att_scale = np.sum(H, axis=1)
    H_ = copy.deepcopy(H)
    for i in range(d):
        H_[i] = H_[i] / sum(H_[i])
    V = adjust_att(n, k, d, U, C, H_)

    # Attribute generation
    X = attribute_generation(n, d, k, U, V, C, omega, att_type, att_scale)

    # not applying user-specified distribution
    X_not = U @ V.T
    for i in range(d):
        X_not[:, i] -= min(X_not[:, i])
        X_not[:, i] /= max(X_not[:, i])

    return S_gen, X, X_not, C
def stochastic_blockmodel_graph(block_sizes, edge_probs, directed=False):
    """Sample the edge index of a stochastic block model graph.

    For every block pair ``(i, j)``, ``int(size_i * size_j * p_ij)`` endpoint
    pairs are drawn uniformly with replacement. In the undirected case only
    pairs with ``i <= j`` are sampled, pairs inside a block are kept when
    ``src < dst``, and the reversed edges are appended at the end.

    Args:
        block_sizes: Number of nodes in each block.
        edge_probs: Square array (or nested sequence) of block-pair
            probabilities.
        directed: Sample every ordered block pair and do not mirror edges.

    Returns:
        numpy.ndarray: Integer array of shape ``(2, num_edges)``. It has shape
        ``(2, 0)`` when no edge was sampled (previously a ``ValueError``).
    """
    # Start offset of every block, computed once instead of per pair.
    offsets = np.concatenate([[0], np.cumsum(block_sizes)]).astype(int)
    edges = []
    for i, size_i in enumerate(block_sizes):
        start_i, end_i = offsets[i], offsets[i + 1]
        for j, size_j in enumerate(block_sizes):
            if not directed and i > j:
                continue
            start_j, end_j = offsets[j], offsets[j + 1]
            p = edge_probs[i, j] if isinstance(edge_probs, np.ndarray) else edge_probs[i][j]
            n_edges = int(size_i * size_j * p)
            if n_edges <= 0:
                # Nothing to draw; also keeps empty blocks away from the RNG.
                continue
            if directed:
                idx_i = np.random.choice(np.arange(start_i, end_i), size=n_edges)
                idx_j = np.random.choice(np.arange(start_j, end_j), size=n_edges)
            else:
                idx_i = np.random.randint(start_i, end_i, n_edges)
                idx_j = np.random.randint(start_j, end_j, n_edges)
                if i == j:
                    mask = idx_i < idx_j
                    idx_i, idx_j = idx_i[mask], idx_j[mask]
            if len(idx_i) > 0:
                edges.append(np.stack([idx_i, idx_j]))
    if not edges:
        return np.empty((2, 0), dtype=np.int64)
    edge_index = np.concatenate(edges, axis=1)
    if not directed:
        edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    return edge_index


def to_scipy_sparse_matrix(edge_index, num_nodes=None):
    """Convert a ``(2, E)`` edge index into a SciPy COO adjacency matrix.

    Args:
        edge_index: Integer array whose rows are source and target ids.
        num_nodes: Matrix size; defaults to ``edge_index.max() + 1``
            (``0`` for an empty edge index).

    Returns:
        scipy.sparse.coo_matrix: ``num_nodes x num_nodes`` matrix with a 1 for
        every listed edge (duplicates are summed by SciPy on conversion).
    """
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1 if edge_index.size else 0
    return sp.coo_matrix((np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])),
                         shape=(num_nodes, num_nodes))


class DummySBM:
    """Minimal graph container holding an edge index and a node count."""

    def __init__(self, edge_index, num_nodes):
        self.edge_index = edge_index
        self.num_nodes = num_nodes


def to_list(value: Any) -> Sequence:
    """Return ``value`` unchanged if it is a non-string sequence, else ``[value]``."""
    if isinstance(value, Sequence) and not isinstance(value, str):
        return value
    return [value]
# ===== Deprecated GammaGL dataset inheritance removed entirely; implemented by hand =====
def sbm_mixture_of_Gaussians(n, n_features, sizes, probs, std_, means, nodes_per_mean):
    """Build an SBM graph whose features are a smoothed mixture of Gaussians.

    Node features are drawn per group as ``std_ * randn + mean``, multiplied
    by ``Dinv @ A`` (adjacency with 1 added to the diagonal, scaled by the
    inverse column sums) and extended by a constant column of ones.

    Args:
        n: Half the number of nodes; feature and label arrays have ``2 * n`` rows.
        n_features: Number of Gaussian feature columns.
        sizes: Block sizes passed to ``stochastic_blockmodel_graph``.
        probs: Block-pair edge probabilities.
        std_: Standard deviation of the Gaussian features.
        means: One mean per node group.
        nodes_per_mean: Node ids of each group, aligned with ``means``.

    Returns:
        DummySBM: Graph with ``edge_index``, ``num_nodes``, ``x`` (float32,
        ``2n x (n_features + 1)``) and ``y`` (float32; 1 for the first ``n``
        nodes, 0 otherwise).
    """

    edge_index = stochastic_blockmodel_graph(sizes, probs, directed=False)
    g_ = DummySBM(edge_index=edge_index, num_nodes=sum(sizes))

    A = to_scipy_sparse_matrix(g_.edge_index, num_nodes=g_.num_nodes)
    # add 1 to every diagonal entry (self-loops) before normalising
    A.setdiag(A.diagonal() + 1)
    d_mat = np.sum(A, axis=0)
    dinv_mat = 1/d_mat
    dinv = dinv_mat.tolist()[0]
    Dinv = sp.diags(dinv, 0)

    X = np.zeros((2*n,n_features))
    for ct, mean_ in enumerate(means):
        X[nodes_per_mean[ct],:] = std_ * np.random.randn(len(nodes_per_mean[ct]), n_features) + mean_

    data = np.zeros((2*n,n_features+1))
    data[:,:-1] = Dinv@A@X
    # last column is a constant 1 feature
    data[:,-1] = np.ones(2*n)
    g_.x = tlx.convert_to_tensor(data, dtype=tlx.float32)

    y = np.zeros(2*n)
    y[0:n] = 1
    g_.y = tlx.convert_to_tensor(y, dtype=tlx.float32)

    return g_


def load_csbm(datastr: str, datapath: str="./data/",
              inductive: bool=False, multil: bool=False,
              seed: int=0, **kwargs):
    """Generate a two-block cSBM dataset described by ``datastr``.

    ``datastr`` must split on ``'-'`` into exactly three parts; the second is
    the intra-block edge probability ``p`` and the third is used as the
    feature standard deviation. ``datapath``, ``inductive``, ``multil`` and
    ``kwargs`` are not read here.

    Returns:
        tuple: ``(adj, feat, labels, idx, nfeat, nclass)``. ``adj`` and
        ``feat`` are dicts keyed by ``'train'``/``'test'``; ``labels`` and
        ``idx`` are keyed by ``'train'``/``'val'``/``'test'`` (a random
        60/20/20 split of the ``2 * n`` nodes).
    """

    n = 500
    n_features = 127
    q = 0.1
    _, p, mu = datastr.split('-')
    p, mu = float(p), float(mu)

    probs = [[p, q], [q, p]]
    sizes = [n, n]
    std_ = mu
    means = [100, -100]
    nodes_per_mean = [list(range(n)),list(range(n,2*n))]

    g = sbm_mixture_of_Gaussians(n, n_features, sizes, probs, std_, means, nodes_per_mean)

    # NOTE(review): the permutation uses NumPy's global RNG; `seed` is only
    # read for the debug print below.
    idx_rnd = np.random.permutation(2*n)
    adj = {'train': g.edge_index,
           'test': g.edge_index}
    feat = {'train': tlx.convert_to_tensor(g.x, tlx.float32),
            'test': tlx.convert_to_tensor(g.x, tlx.float32)}
    idx = {'train': tlx.convert_to_tensor(idx_rnd[:int(0.6*2*n)], tlx.int64),
           'val': tlx.convert_to_tensor(idx_rnd[int(0.6*2*n):int(0.8*2*n)], tlx.int64),
           'test': tlx.convert_to_tensor(idx_rnd[int(0.8*2*n):], tlx.int64)}
    y = tlx.cast(g.y, tlx.int64)
    labels = {'train': y[idx['train']],
              'val': y[idx['val']],
              'test': y[idx['test']]}
    # +1 for the constant feature column
    nfeat = n_features + 1
    nclass = 2
    if seed >= 15:
        print(g)
    return adj, feat, labels, idx, nfeat, nclass


def load_gencat(datastr: str, datapath: str="./data/",
                inductive: bool=False, multil: bool=False,
                seed: int=0, **kwargs):
    """Generate a GenCAT dataset whose class statistics are taken from cora.

    ``datastr`` must split on ``'-'`` into exactly three parts; the second is
    the power-law parameter ``p`` and the third is ``omega``.

    Returns:
        tuple: ``(adj, feat, labels, idx, nfeat, nclass)`` in the same layout
        as ``load_csbm``.
    """
    # NOTE(review): the source dataset, path and seed are hard-coded here;
    # `datapath` and `seed` are not forwarded — confirm this is intended.
    dp = DataProcess('cora', path="./data/", seed=0)
    dp.input(['adjnpz', 'labels', 'attr_matrix'])
    M, D, class_size, H, node_degree = feature_extraction(dp.adj_matrix, dp.attr_matrix, dp.labels.tolist())

    _, p, omega = datastr.split('-')
    p, omega = float(p), float(omega)
    adj, feat, labels = gencat(M, D, H, n=1000, m=5000, p=p, omega=omega)
    # replace the loaded graph by the generated one
    dp.adj_matrix = adj
    dp.attr_matrix = feat
    dp.labels = np.array(labels)

    dp.calculate(['idx_train'])
    idx = {'train': tlx.convert_to_tensor(dp.idx_train, tlx.int64),
           'val': tlx.convert_to_tensor(dp.idx_val, tlx.int64),
           'test': tlx.convert_to_tensor(dp.idx_test, tlx.int64)}
    labels = tlx.convert_to_tensor(dp.labels.flatten(), tlx.int64)
    labels = {'train': labels[idx['train']],
              'val': labels[idx['val']],
              'test': labels[idx['test']]}
    dp.calculate(['edge_idx'])
    adj = {'test': tlx.convert_to_tensor(dp.edge_idx, tlx.int64),
           'train': tlx.convert_to_tensor(dp.edge_idx, tlx.int64)}
    feat = {'test': tlx.convert_to_tensor(dp.attr_matrix, tlx.float32),
            'train': tlx.convert_to_tensor(dp.attr_matrix, tlx.float32)}
    n, m = dp.n, dp.m
    nfeat, nclass = dp.nfeat, dp.nclass
    if seed >= 15:
        print(dp)
    return adj, feat, labels, idx, nfeat, nclass
def load_edgelist(datastr: str, datapath: str="./data/",
                  inductive: bool=False, multil: bool=False,
                  seed: int=0, **kwargs):
    """Load a dataset by name and return tensors split into train/val/test.

    Names starting with ``'csbm'`` or ``'gencat'`` are dispatched to the
    synthetic generators; everything else is read through ``DataProcess``.

    Args:
        datastr: Dataset name.
        datapath: Directory handed to ``DataProcess``.
        inductive: Load a separate training graph via ``DataProcess_inductive``.
        multil: Use one-hot (multi-label) float labels instead of class ids.
        seed: Seed handed to ``DataProcess``; values ``>= 15`` also print the
            processor.

    Returns:
        tuple: ``(adj, feat, labels, idx, nfeat, nclass)``. ``adj`` and
        ``feat`` are dicts keyed by ``'train'``/``'test'``; ``labels`` and
        ``idx`` are keyed by ``'train'``/``'val'``/``'test'``.
    """
    if datastr.startswith('csbm'):
        return load_csbm(datastr, datapath, inductive, multil, seed, **kwargs)
    elif datastr.startswith('gencat'):
        return load_gencat(datastr, datapath, inductive, multil, seed, **kwargs)

    dp = DataProcess(datastr, path=datapath, seed=seed)
    dp.input(['adjnpz', 'labels', 'attr_matrix'])
    if inductive:
        dpi = DataProcess_inductive(datastr, path=datapath, seed=seed)
        dpi.input(['adjnpz', 'attr_matrix'])
    else:
        dpi = dp

    # cora/citeseer/pubmed splits are computed; other datasets read stored splits
    if (datastr.startswith('cora') or datastr.startswith('citeseer') or datastr.startswith('pubmed')):
        dp.calculate(['idx_train'])
    else:
        dp.input(['idx_train', 'idx_val', 'idx_test'])
    idx = {'train': tlx.convert_to_tensor(dp.idx_train, tlx.int64),
           'val': tlx.convert_to_tensor(dp.idx_val, tlx.int64),
           'test': tlx.convert_to_tensor(dp.idx_test, tlx.int64)}

    if multil:
        dp.calculate(['labels_oh'])
        # negative label entries are mapped to 0
        dp.labels_oh[dp.labels_oh < 0] = 0
        labels = tlx.convert_to_tensor(dp.labels_oh, tlx.float32)
    else:
        # negative labels are mapped to class 0
        dp.labels[dp.labels < 0] = 0
        labels = tlx.convert_to_tensor(dp.labels.flatten(), tlx.int64)
    labels = {'train': labels[idx['train']],
              'val': labels[idx['val']],
              'test': labels[idx['test']]}

    dp.calculate(['edge_idx'])
    adj = {'test': tlx.convert_to_tensor(dp.edge_idx, tlx.int64)}
    if inductive:
        dpi.calculate(['edge_idx'])
        adj['train'] = tlx.convert_to_tensor(dpi.edge_idx, tlx.int64)
    else:
        # transductive setting: train and test share the same graph
        adj['train'] = adj['test']

    feat = dp.attr_matrix
    feati = dpi.attr_matrix if inductive else feat
    feat = {'test': tlx.convert_to_tensor(feat, tlx.float32)}
    feat['train'] = tlx.convert_to_tensor(feati, tlx.float32)

    n, m = dp.n, dp.m
    nfeat, nclass = dp.nfeat, dp.nclass
    if seed >= 15:
        print(dp)
    return adj, feat, labels, idx, nfeat, nclass
def prepare_opt(parser) -> DotMap:
    """Merge a JSON config file with command-line options.

    The ``config`` option names the file. If it is not an existing file it is
    resolved as ``./config/<config>.json``. Command-line values that are not
    ``None`` override the values read from the file.

    Args:
        parser: Argument parser providing ``parse_args()``; the parsed
            namespace must contain ``config``.

    Returns:
        DotMap: The merged options.
    """
    cli_args = vars(parser.parse_args())
    path = cli_args['config']
    if not os.path.isfile(path):
        path = os.path.join('./config/', path + '.json')
    with open(path, 'r') as handle:
        merged = json.load(handle)
    merged.update(
        {name: value for name, value in cli_args.items() if value is not None}
    )
    return DotMap(**merged)
class Logger(object):
    """Console/file logger whose verbosity is derived from the run flag.

    The part of ``flag_run`` before the first ``'-'`` is read as a number
    (11 when it is not numeric) and mapped to two levels in ``0..3``:
    ``lvl_log`` (thresholds 10/20/30) and ``lvl_config`` (thresholds 5/15/25).
    A level above 0 prints to the console, a level above 2 also writes files.
    """

    def __init__(self, data: str, algo: str, flag_run: str='', dir: tuple=None):
        """Set up paths and log levels.

        Args:
            data: Dataset name, used in the default save directory.
            algo: Algorithm name, used in the default save directory.
            flag_run: Run identifier. Empty means ``<MMDD>-<random hex>``;
                every occurrence of ``'date'`` is replaced by today's ``MMDD``.
            dir: Optional path components overriding the save directory.
        """
        super(Logger, self).__init__()

        # init random seed
        self.seed_str = str(uuid.uuid4())[:6]
        self.seed = int(self.seed_str, 16)
        today = datetime.now().strftime("%m%d")
        if not flag_run:
            flag_run = today + '-' + self.seed_str
        elif 'date' in flag_run:
            # str.replace returns a new string; the result used to be discarded.
            flag_run = flag_run.replace('date', today)

        # init paths
        if dir is None:
            self.dir_save = os.path.join("./save/", data, algo, flag_run)
        else:
            self.dir_save = os.path.join(*dir)
        self.path_exists = os.path.exists(self.dir_save)

        self.flag_run = flag_run
        self.file_log = self.path_join('log.txt')
        self.file_config = self.path_join('config.json')

        # init log levels
        head = flag_run.split('-')[0]
        seed = int(head) if head.isdigit() else 11
        self.lvl_log = min(seed // 10, 3)           # <10:0, <20:1, <30:2, else 3
        self.lvl_config = min((seed + 5) // 10, 3)  # <5:0, <15:1, <25:2, else 3

    def path_join(self, *args) -> str:
        """Join ``args`` onto the save directory."""
        return os.path.join(self.dir_save, *args)

    def print(self, s, sf=None, lvl=None) -> None:
        """Print ``s`` (level > 0) and append ``sf`` or ``s`` to the log (level > 2)."""
        lvl = self.lvl_log if lvl is None else lvl
        if lvl > 0:
            print(s, flush=True)
        if lvl > 2:
            sf = s if sf is None else sf
            # The log directory may not exist yet on the first write.
            os.makedirs(os.path.dirname(self.file_log) or '.', exist_ok=True)
            with open(self.file_log, 'a') as f:
                f.write(str(sf) + '\n')

    def print_on_top(self, s) -> None:
        """Print ``s`` and, at log level > 2, insert it as first line of the log."""
        if self.lvl_log > 0:
            print(s)
        if self.lvl_log > 2:
            os.makedirs(os.path.dirname(self.file_log) or '.', exist_ok=True)
            with open(self.file_log, 'a') as f:
                pass  # make sure the file exists before opening it with r+
            with open(self.file_log, 'r+') as f:
                temp = f.read()
                f.seek(0, 0)
                f.write(str(s) + '\n')
                f.write(temp)

    def print_header(self, hs, s) -> None:
        """Print header ``hs`` and row ``s``; the file gets the header only once."""
        if self.lvl_log > 0:
            if os.path.isfile(self.file_log):
                print(hs)
            else:
                self.print(hs, hs.replace('|', ','), lvl=self.lvl_config)
        # NOTE(review): the source was flattened, so the nesting of this call
        # is reconstructed — confirm the row should be emitted outside the
        # lvl_log guard.
        self.print(s, lvl=self.lvl_config)

    def _opt_to_dict(self, opt):
        """Return ``opt`` as a plain dict (``DotMap`` or namespace-like object)."""
        if isinstance(opt, DotMap):
            return opt.toDict()
        return vars(opt)

    def save_opt(self, opt) -> None:
        """Write ``opt`` to ``config.json`` when the log level is above 2."""
        if self.lvl_log > 2:
            os.makedirs(self.dir_save, exist_ok=True)
            opt_dict = self._opt_to_dict(opt)
            with open(self.file_config, 'w') as f:
                json.dump(opt_dict, fp=f, indent=4, sort_keys=False)
                f.write('\n')
            print("Option saved.")
            print("Config path: {}".format(self.file_config))
            print("Option dict: {}\n".format(opt_dict))

    def load_opt(self) -> "DotMap":
        """Read ``config.json`` from the save directory and return it as DotMap."""
        with open(self.file_config, 'r') as config_file:
            opt = DotMap(json.load(config_file))
        print("Option loaded.")
        print("Config path: {}".format(self.file_config))
        print("Option dict: {}\n".format(opt.toDict()))
        return opt

    def str_csv(self, data, algo, seed, thr_a, thr_w,
                acc_test, conv_epoch, epoch, time_train, macs_train,
                time_test, macs_test, numel_a, numel_w):
        """Return ``(header, row)`` strings for the result table."""
        hstr, cstr = '', ''
        hstr += f" Data| Model| Seed| ThA| ThW| "
        cstr += f"{data:10s},{algo:10s},{seed:6d},{thr_a:7.2e},{thr_w:7.2e},"
        hstr += f" Acc| Cn| EP| "
        cstr += f"{acc_test:7.5f},{conv_epoch:4d},{epoch:4d},"
        hstr += f" Ttrain| Ctrain| "
        cstr += f"{time_train:8.4f},{macs_train:8.3f},"
        hstr += f" Ttest| CTest| NumelA| NumelW"
        cstr += f"{time_test:8.4f},{macs_test:8.4f},{numel_a:8.3f},{numel_w:8.3f}"
        return hstr, cstr

    def str_csvg(self, data, algo, seed, thr_a, thr_w,
                 acc_test, conv_epoch, epoch, time_train, macs_train,
                 macs_a, macs_wtr, macs_wte,
                 time_test, macs_test, numel_a, numel_w, hop, layer, time_pre):
        """Return ``(header, row)`` strings including the extra cost columns."""
        hstr, cstr = '', ''
        hstr += f" Data| Model| Seed| ThA| ThW| "
        cstr += f"{data:10s},{algo:10s},{seed:6d},{thr_a:7.1e},{thr_w:7.1e},"
        hstr += f" Acc| Cn| EP| "
        cstr += f"{acc_test:7.5f},{conv_epoch:4d},{epoch:4d},"
        hstr += f" Ttrain| Ctrain| "
        cstr += f"{time_train:8.4f},{macs_train:8.3f},"
        hstr += f" Ttest| CTest| NumelA| NumelW| "
        cstr += f"{time_test:8.4f},{macs_test:8.4f},{numel_a:8.3f},{numel_w:8.3f},"
        hstr += f" CPre| CTr| CTe| Hop| Lay| TPre "
        cstr += f"{macs_a:8.4f},{macs_wtr:8.4f},{macs_wte:8.4f},{hop:4d},{layer:4d},{time_pre:8.4f}"
        return hstr, cstr
class ModelLogger(object):
    """Save/load model weights and track the best validation score."""

    def __init__(self, logger: "Logger", patience: int=99999,
                 prefix: str='model', storage: str='state',
                 cmp: Union[Callable[[float, float], bool], str]='>'):
        """Configure checkpoint naming and the score comparison.

        Args:
            logger: Logger providing ``path_join`` and ``dir_save``.
            patience: Epochs without improvement before early stopping.
            prefix: File name prefix of saved checkpoints.
            storage: ``'model'``, ``'state'`` or ``'state_gpu'``.
            cmp: ``'>'``/``'max'``, ``'<'``/``'min'`` or a callable
                ``cmp(new, best) -> bool``.
        """
        super(ModelLogger, self).__init__()
        self.logger = logger
        self.patience = patience
        self.prefix = prefix
        self.model = None

        assert storage in ['model', 'state', 'state_gpu']
        self.storage = storage

        if cmp in ['>', 'max']:
            self.cmp = lambda x, y: x > y
        elif cmp in ['<', 'min']:
            self.cmp = lambda x, y: x < y
        else:
            self.cmp = cmp

    def __set_model(self, model: "Module") -> "Module":
        self.model = model
        return self.model

    def register(self, model: "Module", save_init: bool=True) -> None:
        """Attach ``model``; optionally save its initial weights with suffix '0'."""
        self.__set_model(model)
        if save_init:
            self.save('0')

    def load(self, *suffix, model: "Module"=None) -> "Module":
        """Load the checkpoint ``<prefix>_<suffix...>.npz`` into the model."""
        name = '_'.join((self.prefix,) + suffix)
        path = self.logger.path_join(name + '.npz')

        if self.storage in ('state', 'state_gpu'):
            if model is None:
                model = self.model
            loaded = np.load(path)
            # Weights were stored positionally by np.savez as arr_0, arr_1, ...
            for i, w in enumerate(model.trainable_weights):
                new_val = tlx.convert_to_tensor(loaded[f'arr_{i}'], dtype=w.dtype)
                if hasattr(w, 'assign'):
                    w.assign(new_val)
                else:
                    w.data = new_val
        elif self.storage == 'model':
            model = tlx.load(path)

        return self.__set_model(model)

    def save(self, *suffix) -> None:
        """Save the registered model as ``<prefix>_<suffix...>.npz``."""
        name = '_'.join((self.prefix,) + suffix)
        path = self.logger.path_join(name + '.npz')
        os.makedirs(os.path.dirname(path), exist_ok=True)

        if self.storage in ('state', 'state_gpu'):
            weights = [tlx.convert_to_numpy(w) for w in self.model.trainable_weights]
            np.savez(path, *weights)
        elif self.storage == 'model':
            tlx.save(self.model, path)

    def get_last_epoch(self) -> int:
        """Return the last saved epoch (0-based), ``-1`` for init only, ``-2`` for none."""
        name_pre = '_'.join((self.prefix,) + ('',))
        last_epoch = -2

        for fname in os.listdir(self.logger.dir_save):
            fname = str(fname)
            if fname.startswith(name_pre) and fname.endswith('.npz'):
                suffix = fname.replace(name_pre, '').replace('.npz', '')
                if suffix == 'init':
                    this_epoch = -1
                elif suffix.isdigit():
                    this_epoch = int(suffix) - 1
                else:
                    this_epoch = -2
                if this_epoch > last_epoch:
                    last_epoch = this_epoch
        return last_epoch

    def save_epoch(self, epoch: int, period: int=1) -> None:
        """Save a checkpoint named ``epoch + 1`` every ``period`` epochs."""
        if (epoch + 1) % period == 0:
            self.save(str(epoch+1))

    def save_best(self, score: float, epoch: int=-1,
                  print_log: bool=False) -> float:
        """Save as 'best' when ``score`` improves; return the best score so far."""
        if self.is_best(score, epoch):
            self.save('best')
            if print_log:
                self.logger.print('[best saved] {:>.4f}'.format(self.score_best))
        return self.score_best

    def is_best(self, score: float, epoch: int=-1) -> bool:
        """Record ``score`` if it is the first or beats the best so far."""
        res = (not hasattr(self, 'score_best'))
        if res or self.cmp(score, self.score_best):
            self.score_best = score
            self.epoch_best = epoch
            res = True
        return res

    def is_early_stop(self, epoch: int=-1) -> bool:
        """Return True once ``patience`` epochs passed since the best epoch."""
        if not hasattr(self, 'epoch_best'):
            # No score recorded yet; previously this raised AttributeError.
            return False
        return epoch - self.epoch_best >= self.patience


class LayerNumLogger(object):
    """Hold the element counts of a layer before and after pruning."""

    def __init__(self, name: str=None):
        self.name = name
        self.numel_before = None
        self.numel_after = None

    @property
    def ratio(self) -> float:
        """Return ``numel_after / numel_before``, or NaN while it is undefined."""
        if not self.numel_before or self.numel_after is None:
            # Counts unset or zero: avoid TypeError / ZeroDivisionError.
            return float('nan')
        return self.numel_after / self.numel_before

    def __str__(self) -> str:
        s = f"{self.numel_after}/{self.numel_before} ({1-self.ratio:6.2%})"
        return s
class F1Calculator(object):
    """Accumulate per-class TP/FP/FN counts and compute micro/macro F1."""

    def __init__(self, num_classes: int):
        self.num_classes = num_classes
        # Floats until the first update() turns them into per-class tensors.
        self.TP = 0.0
        self.FP = 0.0
        self.FN = 0.0

    def update(self, y_true, y_pred):
        """Add one batch of labels and predictions to the running counts.

        Inputs of shape ``(N,)`` or ``(N, 1)`` are treated as class ids and
        one-hot encoded; anything else is cast to float32 and used as is.
        """
        def _to_one_hot(tensor, num_classes):
            if len(tensor.shape) == 1 or tensor.shape[1] == 1:
                idx = tlx.cast(tlx.reshape(tensor, (-1,)), tlx.int64)
                eye_np = np.eye(num_classes, dtype=np.float32)
                eye_tlx = tlx.convert_to_tensor(eye_np)
                # Row lookup in the identity matrix yields the one-hot rows.
                return tlx.gather(eye_tlx, idx)
            return tlx.cast(tensor, tlx.float32)

        y_true = _to_one_hot(y_true, self.num_classes)
        y_pred = _to_one_hot(y_pred, self.num_classes)

        self.TP += tlx.reduce_sum(y_true * y_pred, axis=0)
        self.FP += tlx.reduce_sum((1 - y_true) * y_pred, axis=0)
        self.FN += tlx.reduce_sum(y_true * (1 - y_pred), axis=0)

    def compute(self, average: str=None):
        """Return the F1 score; ``0.0`` if ``update`` was never called.

        Raises:
            ValueError: If ``average`` is neither ``'micro'`` nor ``'macro'``.
        """
        eps = 1e-10

        if isinstance(self.TP, float):
            return 0.0

        TP = self.TP
        FP = self.FP
        FN = self.FN

        if average == 'micro':
            f1 = 2 * tlx.reduce_sum(TP) / (2 * tlx.reduce_sum(TP) + tlx.reduce_sum(FP) + tlx.reduce_sum(FN) + eps)
            return float(f1)
        elif average == 'macro':
            f1 = 2 * TP / (2 * TP + FP + FN + eps)
            return float(tlx.reduce_mean(f1))
        else:
            raise ValueError('average must be "micro" or "macro"')


class Stopwatch(object):
    """Accumulating wall-clock timer with start/pause/lap."""

    def __init__(self):
        self.reset()

    def start(self):
        """Start (or restart) the current interval."""
        self.start_time = time.time()

    def pause(self) -> float:
        """Stop the current interval and return the total elapsed seconds.

        Calling it while the watch is not running leaves the total unchanged
        (previously a ``TypeError``).
        """
        if self.start_time is not None:
            self.elapsed_sec += time.time() - self.start_time
            self.start_time = None
        return self.elapsed_sec

    def lap(self) -> float:
        """Return elapsed seconds including the running interval, without stopping."""
        if self.start_time is None:
            return self.elapsed_sec
        return time.time() - self.start_time + self.elapsed_sec

    def reset(self):
        """Stop the watch and clear the accumulated time."""
        self.start_time = None
        self.elapsed_sec = 0

    @property
    def time(self) -> float:
        """Accumulated seconds of all completed intervals."""
        return self.elapsed_sec


class Accumulator(object):
    """Running sum and count with an average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the sum and the count."""
        self.val = 0
        self.count = 0

    def update(self, val: float, count: int=1):
        """Add ``val`` to the sum and ``count`` to the count; return the sum."""
        self.val += val
        self.count += count
        return self.val

    @property
    def avg(self) -> float:
        """Return ``val / count``, or ``0.0`` while nothing was accumulated."""
        if self.count == 0:
            return 0.0
        return self.val / self.count


def get_ram() -> float:
    """Return this process's ``ru_maxrss`` divided by ``2**20``."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 2**20


def get_cuda_mem(dev) -> float:
    """Placeholder: always returns ``0.0``; ``dev`` is ignored."""
    return 0.0
def get_num_params(model: "Module") -> float:
    """Return the number of trainable parameters of ``model`` in millions.

    Args:
        model: Object whose ``trainable_weights`` items expose ``shape``.
    """
    num_paramst = sum(np.prod(param.shape) for param in model.trainable_weights)
    return num_paramst / 1e6


def get_mem_params(model: "Module", bytes_per_param: int = 4) -> float:
    """Return the size of all weights of ``model`` in MiB.

    Args:
        model: Object whose ``all_weights`` items expose ``shape``.
        bytes_per_param: Storage size of one parameter; the default of 4
            keeps the previous hard-coded behaviour.
    """
    mem_params = sum(
        np.prod(param.shape) * bytes_per_param for param in model.all_weights
    )
    return mem_params / (1024**2)