From 222cd4a3a09b88ec9b7c28a335470bb9a48a1e61 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sun, 1 Feb 2026 15:10:24 +0100 Subject: [PATCH 1/2] Use generic methods for `*` and `/` with a scalar --- src/core.jl | 7 ++++++- src/matrix_coo/matrix_coo.jl | 12 ------------ src/matrix_csc/matrix_csc.jl | 12 ------------ src/matrix_csr/matrix_csr.jl | 12 ------------ 4 files changed, 6 insertions(+), 37 deletions(-) diff --git a/src/core.jl b/src/core.jl index 3582d04..2c57e9c 100644 --- a/src/core.jl +++ b/src/core.jl @@ -33,12 +33,17 @@ function LinearAlgebra.lmul!(x::Number, A::AbstractGenericSparseArray) end function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray, x::Number) - rdiv!(nonzeros(A), x) + nzvals = nonzeros(A) + nzvals ./= x return A end Base.:+(A::AbstractGenericSparseArray) = copy(A) +Base.:*(α::Number, A::AbstractGenericSparseArray) = lmul!(α, copy(A)) +Base.:*(A::AbstractGenericSparseArray, α::Number) = rmul!(copy(A), α) +Base.:(/)(A::AbstractGenericSparseArray, α::Number) = rdiv!(copy(A), α) + Base.:*(A::AbstractGenericSparseArray, J::UniformScaling) = A * J.λ Base.:*(J::UniformScaling, A::AbstractGenericSparseArray) = J.λ * A diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index a6e4232..0ff10f7 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -104,18 +104,6 @@ function Base.zero(A::GenericSparseMatrixCOO) return GenericSparseMatrixCOO(A.m, A.n, rowind, colind, nzval) end -function Base.:(*)(α::Number, A::GenericSparseMatrixCOO) - return GenericSparseMatrixCOO( - A.m, - A.n, - copy(getrowind(A)), - copy(getcolind(A)), - α .* nonzeros(A), - ) -end -Base.:(*)(A::GenericSparseMatrixCOO, α::Number) = α * A -Base.:(/)(A::GenericSparseMatrixCOO, α::Number) = (1 / α) * A - function Base.:-(A::GenericSparseMatrixCOO) return GenericSparseMatrixCOO(A.m, A.n, copy(A.rowind), copy(A.colind), -A.nzval) end diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 33a9ada..fe16010 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -85,18 +85,6 @@ function Base.zero(A::GenericSparseMatrixCSC) return GenericSparseMatrixCSC(A.m, A.n, colptr, rowval, nzval) end -function Base.:(*)(α::Number, A::GenericSparseMatrixCSC) - return GenericSparseMatrixCSC( - A.m, - A.n, - copy(getcolptr(A)), - copy(rowvals(A)), - α .* nonzeros(A), - ) -end -Base.:(*)(A::GenericSparseMatrixCSC, α::Number) = α * A -Base.:(/)(A::GenericSparseMatrixCSC, α::Number) = (1 / α) * A - function Base.:-(A::GenericSparseMatrixCSC) return GenericSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), -A.nzval) end diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index f728cc1..bc34a59 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -85,18 +85,6 @@ function Base.zero(A::GenericSparseMatrixCSR) return GenericSparseMatrixCSR(A.m, A.n, rowptr, rowval, nzval) end -function Base.:(*)(α::Number, A::GenericSparseMatrixCSR) - return GenericSparseMatrixCSR( - A.m, - A.n, - copy(getrowptr(A)), - copy(colvals(A)), - α .* nonzeros(A), - ) -end -Base.:(*)(A::GenericSparseMatrixCSR, α::Number) = α * A -Base.:(/)(A::GenericSparseMatrixCSR, α::Number) = (1 / α) * A - function Base.:-(A::GenericSparseMatrixCSR) return GenericSparseMatrixCSR(A.m, A.n, copy(A.rowptr), copy(A.colval), -A.nzval) end From 165e06ce66980c72d0ffc3f044cfb5edb791187e Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sun, 1 Feb 2026 16:54:58 +0100 Subject: [PATCH 2/2] Fix errors with Metal --- src/core.jl | 5 ++--- src/matrix_coo/matrix_coo.jl | 13 +++++++++---- src/matrix_csc/matrix_csc.jl | 10 ++++++++-- src/matrix_csr/matrix_csr.jl | 10 ++++++++-- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/core.jl b/src/core.jl index 2c57e9c..3ebda9d 100644 --- a/src/core.jl +++ b/src/core.jl @@ -32,9 +32,8 @@ function LinearAlgebra.lmul!(x::Number, A::AbstractGenericSparseArray) return A end -function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray, x::Number) - nzvals = nonzeros(A) - nzvals ./= x +function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray{Tv}, x::Number) where {Tv} + rmul!(A, inv(Tv(x))) return A end diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index 0ff10f7..1441de8 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -380,8 +380,7 @@ function Base.:+(A::GenericSparseMatrixCOO, B::GenericSparseMatrixCOO) ) C = GenericSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end # Addition with transpose/adjoint support @@ -499,8 +498,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars ) C = GenericSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end @eval function Base.:-(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)), $(whereT2(:T2))} @@ -808,6 +806,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCOO) return A end + if total_nnz == 0 + # All elements are zeros - some GPU backends (e.g., Metal) don't support + # resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating) + # which returns a new matrix with properly empty arrays. + return A + end + # Allocate temporary arrays for new data new_rowind = similar(rowind, total_nnz) new_colind = similar(colind, total_nnz) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index fe16010..3052d00 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -353,8 +353,7 @@ function Base.:+(A::GenericSparseMatrixCSC, B::GenericSparseMatrixCSC) ) C = GenericSparseMatrixCSC(m, n, colptr_C, rowval_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end # Addition with transpose/adjoint support @@ -650,6 +649,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCSC) cumsum_nnz = _cumsum_AK(nnz_per_col) total_nnz = @allowscalar cumsum_nnz[end] + if total_nnz == 0 + # All elements are zeros - some GPU backends (e.g., Metal) don't support + # resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating) + # which returns a new matrix with properly empty arrays. + return A + end + # Allocate temporary arrays for new data new_colptr = similar(getcolptr(A)) new_rowval = similar(rowvals(A), total_nnz) diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index bc34a59..7a98738 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -351,8 +351,7 @@ function Base.:+(A::GenericSparseMatrixCSR, B::GenericSparseMatrixCSR) ) C = GenericSparseMatrixCSR(m, n, rowptr_C, colval_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end # Addition with transpose/adjoint support @@ -645,6 +644,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCSR) cumsum_nnz = _cumsum_AK(nnz_per_row) total_nnz = @allowscalar cumsum_nnz[end] + if total_nnz == 0 + # All elements are zeros - some GPU backends (e.g., Metal) don't support + # resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating) + # which returns a new matrix with properly empty arrays. + return A + end + # Allocate temporary arrays for new data new_rowptr = similar(getrowptr(A)) new_colval = similar(colvals(A), total_nnz)