diff --git a/graphlearn_torch/csrc/cpu/graph.cc b/graphlearn_torch/csrc/cpu/graph.cc index cd8b4ba6..64dbd66e 100644 --- a/graphlearn_torch/csrc/cpu/graph.cc +++ b/graphlearn_torch/csrc/cpu/graph.cc @@ -22,7 +22,8 @@ void Graph::InitCPUGraphFromCSR( const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& edge_ids, - const torch::Tensor& edge_weights) { + const torch::Tensor& edge_weights, + int64_t col_count) { CheckEq(indptr.dim(), 1); CheckEq(indices.dim(), 1); @@ -30,7 +31,8 @@ void Graph::InitCPUGraphFromCSR( col_idx_ = indices.data_ptr(); row_count_ = indptr.size(0) - 1; edge_count_ = indices.size(0); - col_count_ = std::get<0>(at::_unique(indices)).size(0); + col_count_ = col_count >= 0 ? col_count : + std::get<0>(at::_unique(indices)).size(0); if (edge_ids.numel()) { CheckEq(edge_ids.dim(), 1); diff --git a/graphlearn_torch/include/graph.h b/graphlearn_torch/include/graph.h index 85d1fd81..2a585309 100644 --- a/graphlearn_torch/include/graph.h +++ b/graphlearn_torch/include/graph.h @@ -51,7 +51,8 @@ class Graph { void InitCPUGraphFromCSR(const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& edge_ids=torch::empty(0), - const torch::Tensor& edge_weights=torch::empty(0)); + const torch::Tensor& edge_weights=torch::empty(0), + int64_t col_count=-1); #ifdef WITH_CUDA virtual ~Graph(); void LookupDegree(const int64_t* nodes, diff --git a/graphlearn_torch/python/data/graph.py b/graphlearn_torch/python/data/graph.py index 4654b071..0e523666 100644 --- a/graphlearn_torch/python/data/graph.py +++ b/graphlearn_torch/python/data/graph.py @@ -200,13 +200,22 @@ class Graph(object): operations. Note that this parameter will be ignored if the graph mode set to 'CPU'. The value of ``torch.cuda.current_device()`` will be used if set to ``None``. (Default: ``None``). + col_count (int, optional): Cached number of unique column ids for CPU + graphs. If provided, CPU graph init will reuse it instead of + recomputing the value from indices. (Default: ``None``). """ def __init__(self, topo: Topology, mode = 'ZERO_COPY', - device: Optional[int] = None): + device: Optional[int] = None, + col_count: Optional[int] = None): self.topo = topo self.topo.share_memory_() self.mode = mode.upper() self.device = device + if col_count is not None: + print(f"Reusing cached col_count: {col_count}") + else: + print("col count not provided, will compute it from indices") + self._cached_col_count = None if col_count is None else int(col_count) if self.mode != 'CPU' and self.device is not None: self.device = int(self.device) @@ -234,7 +243,14 @@ def lazy_init(self): edge_weights = torch.empty(0) if self.mode == 'CPU': - self._graph.init_cpu_from_csr(indptr, indices, edge_ids, edge_weights) + self._graph.init_cpu_from_csr( + indptr, + indices, + edge_ids, + edge_weights, + -1 if self._cached_col_count is None else self._cached_col_count + ) + self._cached_col_count = self._graph.get_col_count() else: if self.device is None: self.device = torch.cuda.current_device() @@ -258,16 +274,23 @@ def share_ipc(self): r""" Create ipc handle for multiprocessing. Returns: - A tuple of topo and graph mode. + A tuple of topo, graph mode and cached CPU col_count metadata. """ - return self.topo, self.mode + if self.mode == 'CPU' and self._cached_col_count is None and \ + self._graph is not None: + self._cached_col_count = self._graph.get_col_count() + return self.topo, self.mode, self._cached_col_count @classmethod def from_ipc_handle(cls, ipc_handle): - r""" Create from ipc handle. + r""" Create from an old or new ipc handle. """ - topo, mode = ipc_handle - return cls(topo, mode, device=None) + if len(ipc_handle) == 2: + topo, mode = ipc_handle + cached_col_count = None + else: + topo, mode, cached_col_count = ipc_handle + return cls(topo, mode, device=None, col_count=cached_col_count) @property def row_count(self): diff --git a/graphlearn_torch/python/py_export_glt.cc b/graphlearn_torch/python/py_export_glt.cc index e70ab36f..c1ad2b36 100644 --- a/graphlearn_torch/python/py_export_glt.cc +++ b/graphlearn_torch/python/py_export_glt.cc @@ -58,7 +58,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def(py::init<>()) .def("init_cpu_from_csr", &Graph::InitCPUGraphFromCSR, py::arg("indptr"), py::arg("indices"), - py::arg("edge_ids"), py::arg("edge_weights")) + py::arg("edge_ids"), py::arg("edge_weights"), + py::arg("col_count") = -1) #ifdef WITH_CUDA .def("init_cuda_from_csr", py::overload_cast 'CSC' row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3], dtype=torch.int64) @@ -197,4 +220,4 @@ def test_topo_with_layout(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()