From 0ea92ea3c847f4c1f7b972378c5c83135edac5b0 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sun, 1 Feb 2026 02:26:25 +0100 Subject: [PATCH] Add `ishermitian` and `issymmetric` --- src/GenericSparseArrays.jl | 2 +- src/core.jl | 5 +++ src/matrix_coo/matrix_coo.jl | 12 ++++++ src/matrix_csc/matrix_csc.jl | 38 +++++++++++++++++ src/matrix_csc/matrix_csc_kernels.jl | 64 ++++++++++++++++++++++++++++ src/matrix_csr/matrix_csr.jl | 38 +++++++++++++++++ src/matrix_csr/matrix_csr_kernels.jl | 64 ++++++++++++++++++++++++++++ test/shared/matrix_coo.jl | 27 ++++++++++++ test/shared/matrix_csc.jl | 27 ++++++++++++ test/shared/matrix_csr.jl | 27 ++++++++++++ 10 files changed, 303 insertions(+), 1 deletion(-) diff --git a/src/GenericSparseArrays.jl b/src/GenericSparseArrays.jl index 2d8bf2c..aeae93f 100644 --- a/src/GenericSparseArrays.jl +++ b/src/GenericSparseArrays.jl @@ -1,7 +1,7 @@ module GenericSparseArrays using LinearAlgebra -import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron, Diagonal +import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron, Diagonal, issymmetric, ishermitian using SparseArrays import SparseArrays: SparseVector, SparseMatrixCSC import SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds diff --git a/src/core.jl b/src/core.jl index 88d6663..41ed659 100644 --- a/src/core.jl +++ b/src/core.jl @@ -83,6 +83,11 @@ end Base.:+(B::DenseMatrix, A::AbstractGenericSparseMatrix) = A + B +LinearAlgebra.issymmetric(A::Transpose{<:Any, <:AbstractGenericSparseMatrix}) = issymmetric(parent(A)) +LinearAlgebra.issymmetric(A::Adjoint{<:Any, <:AbstractGenericSparseMatrix}) = issymmetric(parent(A)) +LinearAlgebra.ishermitian(A::Transpose{<:Any, <:AbstractGenericSparseMatrix}) = ishermitian(parent(A)) +LinearAlgebra.ishermitian(A::Adjoint{<:Any, <:AbstractGenericSparseMatrix}) = ishermitian(parent(A)) + # Keep this at the end of the file trans_adj_wrappers(fmt) = ( (T -> :($fmt{$T}), false, false, identity, T -> :($T)), diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index 0ff2afb..458b049 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -739,3 +739,15 @@ for (wrap, trans, conj, unwrap, whereT) in trans_adj_wrappers(:GenericSparseMatr return kron(A_coo, D) end end + +function LinearAlgebra.issymmetric(A::GenericSparseMatrixCOO) + m, n = size(A) + m == n || return false + return issymmetric(GenericSparseMatrixCSC(A)) +end + +function LinearAlgebra.ishermitian(A::GenericSparseMatrixCOO) + m, n = size(A) + m == n || return false + return ishermitian(GenericSparseMatrixCSC(A)) +end diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 07bb76e..d7468bb 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -568,3 +568,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars end end end + +function LinearAlgebra.issymmetric(A::GenericSparseMatrixCSC) + m, n = size(A) + m == n || return false + + # Empty matrix is symmetric + nnz(A) == 0 && return true + + backend = get_backend(A) + + # Result array (initialize to true) + result = similar(nonzeros(A), Bool, 1) + fill!(result, true) + + kernel! = kernel_check_symmetry_csc!(backend) + kernel!(result, getcolptr(A), rowvals(A), nonzeros(A), Val{false}(); ndrange = (n,)) + + return @allowscalar result[1] +end + +function LinearAlgebra.ishermitian(A::GenericSparseMatrixCSC) + m, n = size(A) + m == n || return false + + # Empty matrix is hermitian + nnz(A) == 0 && return true + + backend = get_backend(A) + + # Result array (initialize to true) + result = similar(nonzeros(A), Bool, 1) + fill!(result, true) + + kernel! = kernel_check_symmetry_csc!(backend) + kernel!(result, getcolptr(A), rowvals(A), nonzeros(A), Val{true}(); ndrange = (n,)) + + return @allowscalar result[1] +end diff --git a/src/matrix_csc/matrix_csc_kernels.jl b/src/matrix_csc/matrix_csc_kernels.jl index f5d16be..ab3a66f 100644 --- a/src/matrix_csc/matrix_csc_kernels.jl +++ b/src/matrix_csc/matrix_csc_kernels.jl @@ -343,3 +343,67 @@ end end end end + +# Kernel for checking symmetry/hermitianity in CSC format +# For each entry A[row, col], we need to check if A[col, row] exists and has the correct value +# This kernel checks all entries in a given column +@kernel inbounds = true function kernel_check_symmetry_csc!( + result, + @Const(colptr), + @Const(rowval), + @Const(nzval), + ::Val{HERMITIAN}, + ) where {HERMITIAN} + col = @index(Global) + + is_valid = true + + # Iterate over all entries in this column + for idx in colptr[col]:(colptr[col + 1] - 1) + is_valid || break + + row = rowval[idx] + val = nzval[idx] + + # For diagonal elements, check self-conjugate property for hermitian + if row == col + if HERMITIAN && val != conj(val) + is_valid = false + end + else + # For off-diagonal: need to find A[col, row] (i.e., in column 'row', find entry at row 'col') + # Binary search in column 'row' for row index 'col' + lo = colptr[row] + hi = colptr[row + 1] - 1 + + found = false + while lo <= hi + mid = (lo + hi) ÷ 2 + mid_row = rowval[mid] + if mid_row == col + # Found the transpose entry + trans_val = nzval[mid] + expected = HERMITIAN ? conj(trans_val) : trans_val + if val != expected + is_valid = false + end + found = true + break + elseif mid_row < col + lo = mid + 1 + else + hi = mid - 1 + end + end + + # If transpose entry not found, matrix is not symmetric/hermitian + if !found + is_valid = false + end + end + end + + if !is_valid + result[1] = false + end +end diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index aaa5107..1c52daa 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -563,3 +563,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars end end end + +function LinearAlgebra.issymmetric(A::GenericSparseMatrixCSR) + m, n = size(A) + m == n || return false + + # Empty matrix is symmetric + nnz(A) == 0 && return true + + backend = get_backend(A) + + # Result array (initialize to true) + result = similar(nonzeros(A), Bool, 1) + fill!(result, true) + + kernel! = kernel_check_symmetry_csr!(backend) + kernel!(result, getrowptr(A), colvals(A), nonzeros(A), Val{false}(); ndrange = (m,)) + + return @allowscalar result[1] +end + +function LinearAlgebra.ishermitian(A::GenericSparseMatrixCSR) + m, n = size(A) + m == n || return false + + # Empty matrix is hermitian + nnz(A) == 0 && return true + + backend = get_backend(A) + + # Result array (initialize to true) + result = similar(nonzeros(A), Bool, 1) + fill!(result, true) + + kernel! = kernel_check_symmetry_csr!(backend) + kernel!(result, getrowptr(A), colvals(A), nonzeros(A), Val{true}(); ndrange = (m,)) + + return @allowscalar result[1] +end diff --git a/src/matrix_csr/matrix_csr_kernels.jl b/src/matrix_csr/matrix_csr_kernels.jl index bed5860..ef18f5a 100644 --- a/src/matrix_csr/matrix_csr_kernels.jl +++ b/src/matrix_csr/matrix_csr_kernels.jl @@ -341,3 +341,67 @@ end end end end + +# Kernel for checking symmetry/hermitianity in CSR format +# For each entry A[row, col], we need to check if A[col, row] exists and has the correct value +# This kernel checks all entries in a given row +@kernel inbounds = true function kernel_check_symmetry_csr!( + result, + @Const(rowptr), + @Const(colval), + @Const(nzval), + ::Val{HERMITIAN}, + ) where {HERMITIAN} + row = @index(Global) + + is_valid = true + + # Iterate over all entries in this row + for idx in rowptr[row]:(rowptr[row + 1] - 1) + is_valid || break + + col = colval[idx] + val = nzval[idx] + + # For diagonal elements, check self-conjugate property for hermitian + if row == col + if HERMITIAN && val != conj(val) + is_valid = false + end + else + # For off-diagonal: need to find A[col, row] (i.e., in row 'col', find entry at column 'row') + # Binary search in row 'col' for column index 'row' + lo = rowptr[col] + hi = rowptr[col + 1] - 1 + + found = false + while lo <= hi + mid = (lo + hi) ÷ 2 + mid_col = colval[mid] + if mid_col == row + # Found the transpose entry + trans_val = nzval[mid] + expected = HERMITIAN ? conj(trans_val) : trans_val + if val != expected + is_valid = false + end + found = true + break + elseif mid_col < row + lo = mid + 1 + else + hi = mid - 1 + end + end + + # If transpose entry not found, matrix is not symmetric/hermitian + if !found + is_valid = false + end + end + end + + if !is_valid + result[1] = false + end +end diff --git a/test/shared/matrix_coo.jl b/test/shared/matrix_coo.jl index e63b22d..0dc3762 100644 --- a/test/shared/matrix_coo.jl +++ b/test/shared/matrix_coo.jl @@ -93,6 +93,33 @@ function shared_test_linearalgebra_matrix_coo( end end + @testset "issymmetric and ishermitian" begin + for T in (complex_types...,) + n = 50 + # Non-symmetric/non-hermitian matrix + A_nonsym = sprand(T, n, n, 0.1) + A_nonsym[1, 2] = 1.0 + 0.0im + A_nonsym[2, 1] = 2.0 + 1.0im + dA_nonsym = adapt(op, GenericSparseMatrixCOO(A_nonsym)) + @test issymmetric(dA_nonsym) == false + @test ishermitian(dA_nonsym) == false + @test issymmetric(transpose(dA_nonsym)) == false + @test ishermitian(adjoint(dA_nonsym)) == false + + # Symmetric matrix (complex symmetric is NOT hermitian) + A_sym = sparse(A_nonsym + transpose(A_nonsym)) + dA_sym = adapt(op, GenericSparseMatrixCOO(A_sym)) + @test issymmetric(dA_sym) == true + @test issymmetric(transpose(dA_sym)) == true + + # Hermitian matrix (complex) + A_herm = sparse(A_nonsym + adjoint(A_nonsym)) + dA_herm = adapt(op, GenericSparseMatrixCOO(A_herm)) + @test ishermitian(dA_herm) == true + @test ishermitian(adjoint(dA_herm)) == true + end + end + @testset "Three-argument dot" begin for T in (int_types..., float_types..., complex_types...) if T in (Int32,) diff --git a/test/shared/matrix_csc.jl b/test/shared/matrix_csc.jl index 21c6c21..67313d2 100644 --- a/test/shared/matrix_csc.jl +++ b/test/shared/matrix_csc.jl @@ -94,6 +94,33 @@ function shared_test_linearalgebra_matrix_csc( end end + @testset "issymmetric and ishermitian" begin + for T in (complex_types...,) + n = 50 + # Non-symmetric/non-hermitian matrix + A_nonsym = sprand(T, n, n, 0.1) + A_nonsym[1, 2] = 1.0 + 0.0im + A_nonsym[2, 1] = 2.0 + 1.0im + dA_nonsym = adapt(op, GenericSparseMatrixCSC(A_nonsym)) + @test issymmetric(dA_nonsym) == false + @test ishermitian(dA_nonsym) == false + @test issymmetric(transpose(dA_nonsym)) == false + @test ishermitian(adjoint(dA_nonsym)) == false + + # Symmetric matrix (complex symmetric is NOT hermitian) + A_sym = sparse(A_nonsym + transpose(A_nonsym)) + dA_sym = adapt(op, GenericSparseMatrixCSC(A_sym)) + @test issymmetric(dA_sym) == true + @test issymmetric(transpose(dA_sym)) == true + + # Hermitian matrix (complex) + A_herm = sparse(A_nonsym + adjoint(A_nonsym)) + dA_herm = adapt(op, GenericSparseMatrixCSC(A_herm)) + @test ishermitian(dA_herm) == true + @test ishermitian(adjoint(dA_herm)) == true + end + end + @testset "Three-argument dot" begin for T in (int_types..., float_types..., complex_types...) for op_A in (identity, transpose, adjoint) diff --git a/test/shared/matrix_csr.jl b/test/shared/matrix_csr.jl index 5ebc8d5..75aaaba 100644 --- a/test/shared/matrix_csr.jl +++ b/test/shared/matrix_csr.jl @@ -92,6 +92,33 @@ function shared_test_linearalgebra_matrix_csr( end end + @testset "issymmetric and ishermitian" begin + for T in (complex_types...,) + n = 50 + # Non-symmetric/non-hermitian matrix + A_nonsym = sprand(T, n, n, 0.1) + A_nonsym[1, 2] = 1.0 + 0.0im + A_nonsym[2, 1] = 2.0 + 1.0im + dA_nonsym = adapt(op, GenericSparseMatrixCSR(A_nonsym)) + @test issymmetric(dA_nonsym) == false + @test ishermitian(dA_nonsym) == false + @test issymmetric(transpose(dA_nonsym)) == false + @test ishermitian(adjoint(dA_nonsym)) == false + + # Symmetric matrix (complex symmetric is NOT hermitian) + A_sym = sparse(A_nonsym + transpose(A_nonsym)) + dA_sym = adapt(op, GenericSparseMatrixCSR(A_sym)) + @test issymmetric(dA_sym) == true + @test issymmetric(transpose(dA_sym)) == true + + # Hermitian matrix (complex) + A_herm = sparse(A_nonsym + adjoint(A_nonsym)) + dA_herm = adapt(op, GenericSparseMatrixCSR(A_herm)) + @test ishermitian(dA_herm) == true + @test ishermitian(adjoint(dA_herm)) == true + end + end + @testset "Three-argument dot" begin for T in (int_types..., float_types..., complex_types...) for op_A in (identity, transpose, adjoint)