diff --git a/src/core.jl b/src/core.jl index 3582d04..3ebda9d 100644 --- a/src/core.jl +++ b/src/core.jl @@ -32,13 +32,17 @@ function LinearAlgebra.lmul!(x::Number, A::AbstractGenericSparseArray) return A end -function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray, x::Number) - rdiv!(nonzeros(A), x) +function LinearAlgebra.rdiv!(A::AbstractGenericSparseArray{Tv}, x::Number) where {Tv} + rmul!(A, inv(Tv(x))) return A end Base.:+(A::AbstractGenericSparseArray) = copy(A) +Base.:*(α::Number, A::AbstractGenericSparseArray) = lmul!(α, copy(A)) +Base.:*(A::AbstractGenericSparseArray, α::Number) = rmul!(copy(A), α) +Base.:(/)(A::AbstractGenericSparseArray, α::Number) = rdiv!(copy(A), α) + Base.:*(A::AbstractGenericSparseArray, J::UniformScaling) = A * J.λ Base.:*(J::UniformScaling, A::AbstractGenericSparseArray) = J.λ * A diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index a6e4232..1441de8 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -104,18 +104,6 @@ function Base.zero(A::GenericSparseMatrixCOO) return GenericSparseMatrixCOO(A.m, A.n, rowind, colind, nzval) end -function Base.:(*)(α::Number, A::GenericSparseMatrixCOO) - return GenericSparseMatrixCOO( - A.m, - A.n, - copy(getrowind(A)), - copy(getcolind(A)), - α .* nonzeros(A), - ) -end -Base.:(*)(A::GenericSparseMatrixCOO, α::Number) = α * A -Base.:(/)(A::GenericSparseMatrixCOO, α::Number) = (1 / α) * A - function Base.:-(A::GenericSparseMatrixCOO) return GenericSparseMatrixCOO(A.m, A.n, copy(A.rowind), copy(A.colind), -A.nzval) end @@ -392,8 +380,7 @@ function Base.:+(A::GenericSparseMatrixCOO, B::GenericSparseMatrixCOO) ) C = GenericSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end # Addition with transpose/adjoint support @@ -511,8 +498,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSpars ) C = GenericSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end @eval function Base.:-(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)), $(whereT2(:T2))} @@ -820,6 +806,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCOO) return A end + if total_nnz == 0 + # All elements are zeros - some GPU backends (e.g., Metal) don't support + # resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating) + # which returns a new matrix with properly empty arrays. + return A + end + # Allocate temporary arrays for new data new_rowind = similar(rowind, total_nnz) new_colind = similar(colind, total_nnz) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 33a9ada..3052d00 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -85,18 +85,6 @@ function Base.zero(A::GenericSparseMatrixCSC) return GenericSparseMatrixCSC(A.m, A.n, colptr, rowval, nzval) end -function Base.:(*)(α::Number, A::GenericSparseMatrixCSC) - return GenericSparseMatrixCSC( - A.m, - A.n, - copy(getcolptr(A)), - copy(rowvals(A)), - α .* nonzeros(A), - ) -end -Base.:(*)(A::GenericSparseMatrixCSC, α::Number) = α * A -Base.:(/)(A::GenericSparseMatrixCSC, α::Number) = (1 / α) * A - function Base.:-(A::GenericSparseMatrixCSC) return GenericSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), -A.nzval) end @@ -365,8 +353,7 @@ function Base.:+(A::GenericSparseMatrixCSC, B::GenericSparseMatrixCSC) ) C = GenericSparseMatrixCSC(m, n, colptr_C, rowval_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end # Addition with transpose/adjoint support @@ -662,6 +649,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCSC) cumsum_nnz = _cumsum_AK(nnz_per_col) total_nnz = @allowscalar cumsum_nnz[end] + if total_nnz == 0 + # All elements are zeros - some GPU backends (e.g., Metal) don't support + # resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating) + # which returns a new matrix with properly empty arrays. + return A + end + # Allocate temporary arrays for new data new_colptr = similar(getcolptr(A)) new_rowval = similar(rowvals(A), total_nnz) diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index f728cc1..7a98738 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -85,18 +85,6 @@ function Base.zero(A::GenericSparseMatrixCSR) return GenericSparseMatrixCSR(A.m, A.n, rowptr, rowval, nzval) end -function Base.:(*)(α::Number, A::GenericSparseMatrixCSR) - return GenericSparseMatrixCSR( - A.m, - A.n, - copy(getrowptr(A)), - copy(colvals(A)), - α .* nonzeros(A), - ) -end -Base.:(*)(A::GenericSparseMatrixCSR, α::Number) = α * A -Base.:(/)(A::GenericSparseMatrixCSR, α::Number) = (1 / α) * A - function Base.:-(A::GenericSparseMatrixCSR) return GenericSparseMatrixCSR(A.m, A.n, copy(A.rowptr), copy(A.colval), -A.nzval) end @@ -363,8 +351,7 @@ function Base.:+(A::GenericSparseMatrixCSR, B::GenericSparseMatrixCSR) ) C = GenericSparseMatrixCSR(m, n, rowptr_C, colval_C, nzval_C) - dropzeros!(C) - return C + return dropzeros(C) end # Addition with transpose/adjoint support @@ -657,6 +644,13 @@ function SparseArrays.dropzeros!(A::GenericSparseMatrixCSR) cumsum_nnz = _cumsum_AK(nnz_per_row) total_nnz = @allowscalar cumsum_nnz[end] + if total_nnz == 0 + # All elements are zeros - some GPU backends (e.g., Metal) don't support + # resize to 0. Keep the stored zeros; users can use dropzeros() (non-mutating) + # which returns a new matrix with properly empty arrays. + return A + end + # Allocate temporary arrays for new data new_rowptr = similar(getrowptr(A)) new_colval = similar(colvals(A), total_nnz)