Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions graphlearn_torch/csrc/cpu/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@ void Graph::InitCPUGraphFromCSR(
const torch::Tensor& indptr,
const torch::Tensor& indices,
const torch::Tensor& edge_ids,
const torch::Tensor& edge_weights) {
const torch::Tensor& edge_weights,
int64_t col_count) {
CheckEq<int64_t>(indptr.dim(), 1);
CheckEq<int64_t>(indices.dim(), 1);

row_ptr_ = indptr.data_ptr<int64_t>();
col_idx_ = indices.data_ptr<int64_t>();
row_count_ = indptr.size(0) - 1;
edge_count_ = indices.size(0);
col_count_ = std::get<0>(at::_unique(indices)).size(0);
col_count_ = col_count >= 0 ? col_count :
std::get<0>(at::_unique(indices)).size(0);

if (edge_ids.numel()) {
CheckEq<int64_t>(edge_ids.dim(), 1);
Expand Down
3 changes: 2 additions & 1 deletion graphlearn_torch/include/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class Graph {
void InitCPUGraphFromCSR(const torch::Tensor& indptr,
const torch::Tensor& indices,
const torch::Tensor& edge_ids=torch::empty(0),
const torch::Tensor& edge_weights=torch::empty(0));
const torch::Tensor& edge_weights=torch::empty(0),
int64_t col_count=-1);
#ifdef WITH_CUDA
virtual ~Graph();
void LookupDegree(const int64_t* nodes,
Expand Down
37 changes: 30 additions & 7 deletions graphlearn_torch/python/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,22 @@ class Graph(object):
operations. Note that this parameter will be ignored if the graph mode
set to 'CPU'. The value of ``torch.cuda.current_device()`` will be used
if set to ``None``. (Default: ``None``).
col_count (int, optional): Cached number of unique column ids for CPU
graphs. If provided, CPU graph init will reuse it instead of
recomputing the value from indices. (Default: ``None``).
"""
def __init__(self, topo: Topology, mode = 'ZERO_COPY',
device: Optional[int] = None):
device: Optional[int] = None,
col_count: Optional[int] = None):
self.topo = topo
self.topo.share_memory_()
self.mode = mode.upper()
self.device = device
if col_count is not None:
print(f"Reusing cached col_count: {col_count}")
else:
print("col count not provided, will compute it from indices")
self._cached_col_count = None if col_count is None else int(col_count)

if self.mode != 'CPU' and self.device is not None:
self.device = int(self.device)
Expand Down Expand Up @@ -234,7 +243,14 @@ def lazy_init(self):
edge_weights = torch.empty(0)

if self.mode == 'CPU':
self._graph.init_cpu_from_csr(indptr, indices, edge_ids, edge_weights)
self._graph.init_cpu_from_csr(
indptr,
indices,
edge_ids,
edge_weights,
-1 if self._cached_col_count is None else self._cached_col_count
)
self._cached_col_count = self._graph.get_col_count()
else:
if self.device is None:
self.device = torch.cuda.current_device()
Expand All @@ -258,16 +274,23 @@ def share_ipc(self):
r""" Create ipc handle for multiprocessing.

Returns:
A tuple of topo and graph mode.
A tuple of topo, graph mode and cached CPU col_count metadata.
"""
return self.topo, self.mode
if self.mode == 'CPU' and self._cached_col_count is None and \
self._graph is not None:
self._cached_col_count = self._graph.get_col_count()
return self.topo, self.mode, self._cached_col_count

@classmethod
def from_ipc_handle(cls, ipc_handle):
r""" Create from ipc handle.
r""" Create from an old or new ipc handle.
"""
topo, mode = ipc_handle
return cls(topo, mode, device=None)
if len(ipc_handle) == 2:
topo, mode = ipc_handle
cached_col_count = None
else:
topo, mode, cached_col_count = ipc_handle
return cls(topo, mode, device=None, col_count=cached_col_count)

@property
def row_count(self):
Expand Down
3 changes: 2 additions & 1 deletion graphlearn_torch/python/py_export_glt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def(py::init<>())
.def("init_cpu_from_csr", &Graph::InitCPUGraphFromCSR,
py::arg("indptr"), py::arg("indices"),
py::arg("edge_ids"), py::arg("edge_weights"))
py::arg("edge_ids"), py::arg("edge_weights"),
py::arg("col_count") = -1)
#ifdef WITH_CUDA
.def("init_cuda_from_csr",
py::overload_cast<const torch::Tensor&,
Expand Down
25 changes: 24 additions & 1 deletion test/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# limitations under the License.
# ==============================================================================

import pickle
import unittest
from multiprocessing.reduction import ForkingPickler

import torch

from graphlearn_torch.data import Topology, Graph
Expand Down Expand Up @@ -101,6 +104,8 @@ def test_cpu_graph_init(self):
g = Graph(self.csr_topo, mode='CPU')
self.assertEqual(g.edge_count, self.indices_csr.size(0))
self.assertEqual(g.row_count, self.indptr_csr.size(0) - 1)
self.assertEqual(g.col_count, 6)
self.assertEqual(g.share_ipc()[1:], ('CPU', 6))

def test_cuda_graph_init(self):
g = Graph(self.csr_topo, 'CUDA', 0)
Expand All @@ -112,6 +117,24 @@ def test_pin_graph_init(self):
self.assertEqual(g.edge_count, self.indices_csr.size(0))
self.assertEqual(g.row_count, self.indptr_csr.size(0) - 1)

def test_cpu_graph_init_with_cached_col_count(self):
  # A col_count handed to the constructor must be reported back unchanged,
  # both through the property and in the tail of the ipc handle.
  cpu_graph = Graph(self.csr_topo, mode='CPU', col_count=6)
  self.assertEqual(cpu_graph.col_count, 6)
  shared_topo = cpu_graph.share_ipc()[0]
  self.assertIs(shared_topo, self.csr_topo)
  handle_tail = cpu_graph.share_ipc()[1:]
  self.assertEqual(handle_tail, ('CPU', 6))

def test_graph_from_legacy_ipc_handle(self):
  # A two-element handle (topo, mode) carries no col_count entry; the graph
  # built from it must still report a col_count and emit a three-element
  # handle afterwards.
  legacy_handle = (self.csr_topo, 'CPU')
  rebuilt = Graph.from_ipc_handle(legacy_handle)
  self.assertEqual(rebuilt.col_count, 6)
  shared_topo = rebuilt.share_ipc()[0]
  self.assertIs(shared_topo, self.csr_topo)
  handle_tail = rebuilt.share_ipc()[1:]
  self.assertEqual(handle_tail, ('CPU', 6))

def test_graph_forking_pickler_preserves_cached_col_count(self):
  # Round-trip a CPU graph through ForkingPickler and check that the
  # col_count given at construction is still present on the restored object.
  source_graph = Graph(self.csr_topo, mode='CPU', col_count=6)
  payload = ForkingPickler.dumps(source_graph)
  restored = pickle.loads(payload)
  # The ipc handle is inspected before the property, as in the original order.
  restored_tail = restored.share_ipc()[1:]
  self.assertEqual(restored_tail, ('CPU', 6))
  self.assertEqual(restored.col_count, 6)

def test_topo_with_layout(self):
# 'COO' -> 'CSC'
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3], dtype=torch.int64)
Expand Down Expand Up @@ -197,4 +220,4 @@ def test_topo_with_layout(self):


if __name__ == "__main__":
unittest.main()
unittest.main()
Loading