diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index fed3b19ca..c20579d20 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -65,7 +65,7 @@ end function remove_self_loops(g::GNNGraph{<:ADJMAT_T}) @assert isempty(g.edata) - A = g.graph + A = copy(g.graph) A[diagind(A)] .= 0 if A isa AbstractSparseMatrix dropzeros!(A) diff --git a/GNNGraphs/test/transform.jl b/GNNGraphs/test/transform.jl index 39b415701..bcc7d08b5 100644 --- a/GNNGraphs/test/transform.jl +++ b/GNNGraphs/test/transform.jl @@ -376,6 +376,20 @@ end @test size(get_edge_weight(g2)) == (g2.num_edges,) @test size(g2.edata.e1) == (3, g2.num_edges) @test size(g2.edata.e2) == (g2.num_edges,) + else + A = [1 1 0 + 0 1 1 + 1 0 0] + A_no_loops = [0 1 0 + 0 0 1 + 1 0 0] + g = GNNGraph(A; graph_type = GRAPH_T) + g2 = remove_self_loops(g) + + @test Matrix(adjacency_matrix(g)) == A + @test Matrix(adjacency_matrix(g2)) == A_no_loops + @test g2.num_edges == count(!iszero, A_no_loops) + @test g.graph !== g2.graph end end end