Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -795,18 +795,14 @@ function _unbatch_nodemasks(graph_indicator, num_graphs)
end

function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
edgemasks = []
for i in 1:(num_graphs - 1)
lastedgeid = findfirst(s) do x
x > cumnum_nodes[i + 1] && x <= cumnum_nodes[i + 2]
end
firstedgeid = i == 1 ? 1 : last(edgemasks[i - 1]) + 1
# if nothing make empty range
lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1
edgemasks = [Int[] for _ in 1:num_graphs]

push!(edgemasks, firstedgeid:lastedgeid)
for (eid, src) in enumerate(s)
graph_idx = searchsortedfirst(cumnum_nodes, src) - 1
@assert 1 <= graph_idx <= num_graphs
push!(edgemasks[graph_idx], eid)
end
push!(edgemasks, (last(edgemasks[end]) + 1):length(s))

return edgemasks
end

Expand Down
37 changes: 36 additions & 1 deletion GNNGraphs/test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,49 @@ end
g1 = rand_graph(10, 20, graph_type = GRAPH_T)
g2 = rand_graph(5, 10, graph_type = GRAPH_T)
g12 = MLUtils.batch([g1, g2])
gs = MLUtils.unbatch([g1, g2])
gs = MLUtils.unbatch(g12)
@test length(gs) == 2
@test gs[1].num_nodes == 10
@test gs[1].num_edges == 20
@test gs[1].num_graphs == 1
@test gs[2].num_nodes == 5
@test gs[2].num_edges == 10
@test gs[2].num_graphs == 1

if GRAPH_T == :coo
@test gs[1] == g1
@test gs[2] == g2

@testset "coo zero-edge graphs" begin
gempty = GNNGraph(2)
gedge1 = GNNGraph(([1], [2]), num_nodes = 2)
gedge2 = GNNGraph(([2], [1]), num_nodes = 2)

for graphs in ([gempty, gedge1, gedge2],
[gedge1, gempty, gedge2],
[gedge1, gedge2, gempty])
@test MLUtils.unbatch(MLUtils.batch(graphs)) == collect(graphs)
end
end

@testset "coo zero-edge graphs preserve features" begin
g1f = GNNGraph(([1], [2]), num_nodes = 2,
ndata = (x = Float32[1 2; 3 4],),
edata = (e = Float32[10; 11;;],),
gdata = 100f0)
g2f = GNNGraph(2,
ndata = (x = Float32[5 6; 7 8],),
edata = (e = zeros(Float32, 2, 0),),
gdata = 200f0)
g3f = GNNGraph(([2], [1]), num_nodes = 2,
ndata = (x = Float32[9 10; 11 12],),
edata = (e = Float32[12; 13;;],),
gdata = 300f0)

gs_feat = [g1f, g2f, g3f]
@test MLUtils.unbatch(MLUtils.batch(gs_feat)) == gs_feat
end
end
end
end

Expand Down
Loading