Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/GenericSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module GenericSparseArrays

using LinearAlgebra
import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron, Diagonal
import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron, Diagonal, issymmetric, ishermitian
using SparseArrays
import SparseArrays: SparseVector, SparseMatrixCSC
import SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds
Expand Down
5 changes: 5 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ end

Base.:+(B::DenseMatrix, A::AbstractGenericSparseMatrix) = A + B

LinearAlgebra.issymmetric(A::Transpose{<:Any, <:AbstractGenericSparseMatrix}) = issymmetric(parent(A))
LinearAlgebra.issymmetric(A::Adjoint{<:Any, <:AbstractGenericSparseMatrix}) = issymmetric(parent(A))
LinearAlgebra.ishermitian(A::Transpose{<:Any, <:AbstractGenericSparseMatrix}) = ishermitian(parent(A))
LinearAlgebra.ishermitian(A::Adjoint{<:Any, <:AbstractGenericSparseMatrix}) = ishermitian(parent(A))

# Keep this at the end of the file
trans_adj_wrappers(fmt) = (
(T -> :($fmt{$T}), false, false, identity, T -> :($T)),
Expand Down
12 changes: 12 additions & 0 deletions src/matrix_coo/matrix_coo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,15 @@ for (wrap, trans, conj, unwrap, whereT) in trans_adj_wrappers(:GenericSparseMatr
return kron(A_coo, D)
end
end

function LinearAlgebra.issymmetric(A::GenericSparseMatrixCOO)
m, n = size(A)
m == n || return false
return issymmetric(GenericSparseMatrixCSC(A))
end

function LinearAlgebra.ishermitian(A::GenericSparseMatrixCOO)
m, n = size(A)
m == n || return false
return ishermitian(GenericSparseMatrixCSC(A))
end
38 changes: 38 additions & 0 deletions src/matrix_csc/matrix_csc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars
end
end
end

function LinearAlgebra.issymmetric(A::GenericSparseMatrixCSC)
m, n = size(A)
m == n || return false

# Empty matrix is symmetric
nnz(A) == 0 && return true

backend = get_backend(A)

# Result array (initialize to true)
result = similar(nonzeros(A), Bool, 1)
fill!(result, true)

kernel! = kernel_check_symmetry_csc!(backend)
kernel!(result, getcolptr(A), rowvals(A), nonzeros(A), Val{false}(); ndrange = (n,))

return @allowscalar result[1]
end

function LinearAlgebra.ishermitian(A::GenericSparseMatrixCSC)
m, n = size(A)
m == n || return false

# Empty matrix is hermitian
nnz(A) == 0 && return true

backend = get_backend(A)

# Result array (initialize to true)
result = similar(nonzeros(A), Bool, 1)
fill!(result, true)

kernel! = kernel_check_symmetry_csc!(backend)
kernel!(result, getcolptr(A), rowvals(A), nonzeros(A), Val{true}(); ndrange = (n,))

return @allowscalar result[1]
end
64 changes: 64 additions & 0 deletions src/matrix_csc/matrix_csc_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,67 @@ end
end
end
end

# Kernel for checking symmetry/hermitianity in CSC format
# For each entry A[row, col], we need to check if A[col, row] exists and has the correct value
# This kernel checks all entries in a given column
@kernel inbounds = true function kernel_check_symmetry_csc!(
result,
@Const(colptr),
@Const(rowval),
@Const(nzval),
::Val{HERMITIAN},
) where {HERMITIAN}
col = @index(Global)

is_valid = true

# Iterate over all entries in this column
for idx in colptr[col]:(colptr[col + 1] - 1)
is_valid || break

row = rowval[idx]
val = nzval[idx]

# For diagonal elements, check self-conjugate property for hermitian
if row == col
if HERMITIAN && val != conj(val)
is_valid = false
end
else
# For off-diagonal: need to find A[col, row] (i.e., in column 'row', find entry at row 'col')
# Binary search in column 'row' for row index 'col'
lo = colptr[row]
hi = colptr[row + 1] - 1

found = false
while lo <= hi
mid = (lo + hi) ÷ 2
mid_row = rowval[mid]
if mid_row == col
# Found the transpose entry
trans_val = nzval[mid]
expected = HERMITIAN ? conj(trans_val) : trans_val
if val != expected
is_valid = false
end
found = true
break
elseif mid_row < col
lo = mid + 1
else
hi = mid - 1
end
end

# If transpose entry not found, matrix is not symmetric/hermitian
if !found
is_valid = false
end
end
end

if !is_valid
result[1] = false
end
end
38 changes: 38 additions & 0 deletions src/matrix_csr/matrix_csr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars
end
end
end

function LinearAlgebra.issymmetric(A::GenericSparseMatrixCSR)
m, n = size(A)
m == n || return false

# Empty matrix is symmetric
nnz(A) == 0 && return true

backend = get_backend(A)

# Result array (initialize to true)
result = similar(nonzeros(A), Bool, 1)
fill!(result, true)

kernel! = kernel_check_symmetry_csr!(backend)
kernel!(result, getrowptr(A), colvals(A), nonzeros(A), Val{false}(); ndrange = (m,))

return @allowscalar result[1]
end

function LinearAlgebra.ishermitian(A::GenericSparseMatrixCSR)
m, n = size(A)
m == n || return false

# Empty matrix is hermitian
nnz(A) == 0 && return true

backend = get_backend(A)

# Result array (initialize to true)
result = similar(nonzeros(A), Bool, 1)
fill!(result, true)

kernel! = kernel_check_symmetry_csr!(backend)
kernel!(result, getrowptr(A), colvals(A), nonzeros(A), Val{true}(); ndrange = (m,))

return @allowscalar result[1]
end
64 changes: 64 additions & 0 deletions src/matrix_csr/matrix_csr_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,67 @@ end
end
end
end

# Kernel for checking symmetry/hermitianity in CSR format
# For each entry A[row, col], we need to check if A[col, row] exists and has the correct value
# This kernel checks all entries in a given row
@kernel inbounds = true function kernel_check_symmetry_csr!(
result,
@Const(rowptr),
@Const(colval),
@Const(nzval),
::Val{HERMITIAN},
) where {HERMITIAN}
row = @index(Global)

is_valid = true

# Iterate over all entries in this row
for idx in rowptr[row]:(rowptr[row + 1] - 1)
is_valid || break

col = colval[idx]
val = nzval[idx]

# For diagonal elements, check self-conjugate property for hermitian
if row == col
if HERMITIAN && val != conj(val)
is_valid = false
end
else
# For off-diagonal: need to find A[col, row] (i.e., in row 'col', find entry at column 'row')
# Binary search in row 'col' for column index 'row'
lo = rowptr[col]
hi = rowptr[col + 1] - 1

found = false
while lo <= hi
mid = (lo + hi) ÷ 2
mid_col = colval[mid]
if mid_col == row
# Found the transpose entry
trans_val = nzval[mid]
expected = HERMITIAN ? conj(trans_val) : trans_val
if val != expected
is_valid = false
end
found = true
break
elseif mid_col < row
lo = mid + 1
else
hi = mid - 1
end
end

# If transpose entry not found, matrix is not symmetric/hermitian
if !found
is_valid = false
end
end
end

if !is_valid
result[1] = false
end
end
27 changes: 27 additions & 0 deletions test/shared/matrix_coo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,33 @@ function shared_test_linearalgebra_matrix_coo(
end
end

@testset "issymmetric and ishermitian" begin
for T in (complex_types...,)
n = 50
# Non-symmetric/non-hermitian matrix
A_nonsym = sprand(T, n, n, 0.1)
A_nonsym[1, 2] = 1.0 + 0.0im
A_nonsym[2, 1] = 2.0 + 1.0im
dA_nonsym = adapt(op, GenericSparseMatrixCOO(A_nonsym))
@test issymmetric(dA_nonsym) == false
@test ishermitian(dA_nonsym) == false
@test issymmetric(transpose(dA_nonsym)) == false
@test ishermitian(adjoint(dA_nonsym)) == false

# Symmetric matrix (complex symmetric is NOT hermitian)
A_sym = sparse(A_nonsym + transpose(A_nonsym))
dA_sym = adapt(op, GenericSparseMatrixCOO(A_sym))
@test issymmetric(dA_sym) == true
@test issymmetric(transpose(dA_sym)) == true

# Hermitian matrix (complex)
A_herm = sparse(A_nonsym + adjoint(A_nonsym))
dA_herm = adapt(op, GenericSparseMatrixCOO(A_herm))
@test ishermitian(dA_herm) == true
@test ishermitian(adjoint(dA_herm)) == true
end
end

@testset "Three-argument dot" begin
for T in (int_types..., float_types..., complex_types...)
if T in (Int32,)
Expand Down
27 changes: 27 additions & 0 deletions test/shared/matrix_csc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,33 @@ function shared_test_linearalgebra_matrix_csc(
end
end

@testset "issymmetric and ishermitian" begin
for T in (complex_types...,)
n = 50
# Non-symmetric/non-hermitian matrix
A_nonsym = sprand(T, n, n, 0.1)
A_nonsym[1, 2] = 1.0 + 0.0im
A_nonsym[2, 1] = 2.0 + 1.0im
dA_nonsym = adapt(op, GenericSparseMatrixCSC(A_nonsym))
@test issymmetric(dA_nonsym) == false
@test ishermitian(dA_nonsym) == false
@test issymmetric(transpose(dA_nonsym)) == false
@test ishermitian(adjoint(dA_nonsym)) == false

# Symmetric matrix (complex symmetric is NOT hermitian)
A_sym = sparse(A_nonsym + transpose(A_nonsym))
dA_sym = adapt(op, GenericSparseMatrixCSC(A_sym))
@test issymmetric(dA_sym) == true
@test issymmetric(transpose(dA_sym)) == true

# Hermitian matrix (complex)
A_herm = sparse(A_nonsym + adjoint(A_nonsym))
dA_herm = adapt(op, GenericSparseMatrixCSC(A_herm))
@test ishermitian(dA_herm) == true
@test ishermitian(adjoint(dA_herm)) == true
end
end

@testset "Three-argument dot" begin
for T in (int_types..., float_types..., complex_types...)
for op_A in (identity, transpose, adjoint)
Expand Down
27 changes: 27 additions & 0 deletions test/shared/matrix_csr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ function shared_test_linearalgebra_matrix_csr(
end
end

@testset "issymmetric and ishermitian" begin
for T in (complex_types...,)
n = 50
# Non-symmetric/non-hermitian matrix
A_nonsym = sprand(T, n, n, 0.1)
A_nonsym[1, 2] = 1.0 + 0.0im
A_nonsym[2, 1] = 2.0 + 1.0im
dA_nonsym = adapt(op, GenericSparseMatrixCSR(A_nonsym))
@test issymmetric(dA_nonsym) == false
@test ishermitian(dA_nonsym) == false
@test issymmetric(transpose(dA_nonsym)) == false
@test ishermitian(adjoint(dA_nonsym)) == false

# Symmetric matrix (complex symmetric is NOT hermitian)
A_sym = sparse(A_nonsym + transpose(A_nonsym))
dA_sym = adapt(op, GenericSparseMatrixCSR(A_sym))
@test issymmetric(dA_sym) == true
@test issymmetric(transpose(dA_sym)) == true

# Hermitian matrix (complex)
A_herm = sparse(A_nonsym + adjoint(A_nonsym))
dA_herm = adapt(op, GenericSparseMatrixCSR(A_herm))
@test ishermitian(dA_herm) == true
@test ishermitian(adjoint(dA_herm)) == true
end
end

@testset "Three-argument dot" begin
for T in (int_types..., float_types..., complex_types...)
for op_A in (identity, transpose, adjoint)
Expand Down
Loading