diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 50695e976..fed3b19ca 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -795,18 +795,14 @@ function _unbatch_nodemasks(graph_indicator, num_graphs) end function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes) - edgemasks = [] - for i in 1:(num_graphs - 1) - lastedgeid = findfirst(s) do x - x > cumnum_nodes[i + 1] && x <= cumnum_nodes[i + 2] - end - firstedgeid = i == 1 ? 1 : last(edgemasks[i - 1]) + 1 - # if nothing make empty range - lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1 + edgemasks = [Int[] for _ in 1:num_graphs] - push!(edgemasks, firstedgeid:lastedgeid) + for (eid, src) in enumerate(s) + graph_idx = searchsortedfirst(cumnum_nodes, src) - 1 + @assert 1 <= graph_idx <= num_graphs + push!(edgemasks[graph_idx], eid) end - push!(edgemasks, (last(edgemasks[end]) + 1):length(s)) + return edgemasks end diff --git a/GNNGraphs/test/transform.jl b/GNNGraphs/test/transform.jl index fb7e95bf4..39b415701 100644 --- a/GNNGraphs/test/transform.jl +++ b/GNNGraphs/test/transform.jl @@ -66,7 +66,7 @@ end g1 = rand_graph(10, 20, graph_type = GRAPH_T) g2 = rand_graph(5, 10, graph_type = GRAPH_T) g12 = MLUtils.batch([g1, g2]) - gs = MLUtils.unbatch([g1, g2]) + gs = MLUtils.unbatch(g12) @test length(gs) == 2 @test gs[1].num_nodes == 10 @test gs[1].num_edges == 20 @@ -74,6 +74,41 @@ end @test gs[2].num_nodes == 5 @test gs[2].num_edges == 10 @test gs[2].num_graphs == 1 + + if GRAPH_T == :coo + @test gs[1] == g1 + @test gs[2] == g2 + + @testset "coo zero-edge graphs" begin + gempty = GNNGraph(2) + gedge1 = GNNGraph(([1], [2]), num_nodes = 2) + gedge2 = GNNGraph(([2], [1]), num_nodes = 2) + + for graphs in ([gempty, gedge1, gedge2], + [gedge1, gempty, gedge2], + [gedge1, gedge2, gempty]) + @test MLUtils.unbatch(MLUtils.batch(graphs)) == collect(graphs) + end + end + + @testset "coo zero-edge graphs preserve features" begin + g1f = GNNGraph(([1], [2]), num_nodes = 2, + ndata = (x = Float32[1 2; 3 4],), + edata = (e = Float32[10; 11;;],), + gdata = 100f0) + g2f = GNNGraph(2, + ndata = (x = Float32[5 6; 7 8],), + edata = (e = zeros(Float32, 2, 0),), + gdata = 200f0) + g3f = GNNGraph(([2], [1]), num_nodes = 2, + ndata = (x = Float32[9 10; 11 12],), + edata = (e = Float32[12; 13;;],), + gdata = 300f0) + + gs_feat = [g1f, g2f, g3f] + @test MLUtils.unbatch(MLUtils.batch(gs_feat)) == gs_feat + end + end end end