diff --git a/ext/DeviceSparseArraysJLArraysExt.jl b/ext/DeviceSparseArraysJLArraysExt.jl index 523756f..a71257d 100644 --- a/ext/DeviceSparseArraysJLArraysExt.jl +++ b/ext/DeviceSparseArraysJLArraysExt.jl @@ -5,5 +5,7 @@ import DeviceSparseArrays DeviceSparseArrays._sortperm_AK(x::JLArray) = JLArray(sortperm(collect(x))) DeviceSparseArrays._cumsum_AK(x::JLArray) = JLArray(cumsum(collect(x))) +DeviceSparseArrays._searchsortedfirst_AK(v::JLArray, x::JLArray) = + JLArray(searchsortedfirst.(Ref(collect(v)), collect(x))) end diff --git a/src/conversions/conversion_kernels.jl b/src/conversions/conversion_kernels.jl index 22c96a4..a2fd734 100644 --- a/src/conversions/conversion_kernels.jl +++ b/src/conversions/conversion_kernels.jl @@ -55,17 +55,3 @@ end i = @index(Global) keys[i] = rowind[i] * n + colind[i] end - -# Kernel for counting entries per column (for COO → CSC) -@kernel inbounds=true function kernel_count_per_col!(colptr, @Const(colind_sorted)) - i = @index(Global) - col = colind_sorted[i] - @atomic colptr[col+1] += 1 -end - -# Kernel for counting entries per row (for COO → CSR) -@kernel inbounds=true function kernel_count_per_row!(rowptr, @Const(rowind_sorted)) - i = @index(Global) - row = rowind_sorted[i] - @atomic rowptr[row+1] += 1 -end diff --git a/src/conversions/conversions.jl b/src/conversions/conversions.jl index d90bdbd..f98f620 100644 --- a/src/conversions/conversions.jl +++ b/src/conversions/conversions.jl @@ -165,21 +165,38 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti} colind_sorted = A.colind[perm] nzval_sorted = A.nzval[perm] - # Build colptr on device using a histogram approach - colptr = similar(A.colind, Ti, n + 1) - fill!(colptr, zero(Ti)) - - # Count entries per column - kernel! = kernel_count_per_col!(backend) - kernel!(colptr, colind_sorted; ndrange = (nnz_count,)) + # Build colptr on device using searchsortedfirst approach + # Since colind_sorted is sorted, find where each column starts + col_indices = similar(A.colind, Ti, n) + col_indices .= Ti(1):Ti(n) - # Compute cumulative sum - @allowscalar colptr[1] = 1 # TODO: Is there a better way to do this? - colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1 + # Find start positions for each column + colptr = similar(A.colind, Ti, n + 1) + colptr[1:n] .= _searchsortedfirst_AK(colind_sorted, col_indices) + @allowscalar colptr[n+1] = Ti(nnz_count + 1) return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted) end +# Transpose and Adjoint conversions for COO to CSC +DeviceSparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCOO}) where {Tv} = + DeviceSparseMatrixCSC(DeviceSparseMatrixCOO( + size(A, 1), + size(A, 2), + A.parent.colind, + A.parent.rowind, + A.parent.nzval, + )) + +DeviceSparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCOO}) where {Tv} = + DeviceSparseMatrixCSC(DeviceSparseMatrixCOO( + size(A, 1), + size(A, 2), + A.parent.colind, + A.parent.rowind, + conj.(A.parent.nzval), + )) + # ============================================================================ # CSR ↔ COO Conversions # ============================================================================ @@ -223,17 +240,15 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti} colind_sorted = A.colind[perm] nzval_sorted = A.nzval[perm] - # Build rowptr on device using a histogram approach - rowptr = similar(A.rowind, Ti, m + 1) - fill!(rowptr, zero(Ti)) - - # Count entries per row - kernel! = kernel_count_per_row!(backend) - kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,)) + # Build rowptr on device using searchsortedfirst approach + # Since rowind_sorted is sorted, find where each row starts + row_indices = similar(A.rowind, Ti, m) + row_indices .= Ti(1):Ti(m) - # Compute cumulative sum - @allowscalar rowptr[1] = 1 # TODO: Is there a better way to do this? - rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1 + # Find start positions for each row + rowptr = similar(A.rowind, Ti, m + 1) + rowptr[1:m] .= _searchsortedfirst_AK(rowind_sorted, row_indices) + @allowscalar rowptr[m+1] = Ti(nnz_count + 1) return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted) end diff --git a/src/helpers.jl b/src/helpers.jl index 025b005..770132c 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -1,3 +1,4 @@ # Helper functions to call AcceleratedKernels methods _sortperm_AK(x) = AcceleratedKernels.sortperm(x) _cumsum_AK(x) = AcceleratedKernels.cumsum(x) +_searchsortedfirst_AK(v, x) = AcceleratedKernels.searchsortedfirst(v, x) diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index 56c8bdb..e9f1202 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -385,7 +385,13 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO) # Mark unique entries (first occurrence of each (row, col) pair) keep_mask = similar(rowind_sorted, Bool, nnz_concat) kernel_mark! = kernel_mark_unique_coo!(backend) - kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,)) + kernel_mark!( + keep_mask, + rowind_sorted, + colind_sorted, + nnz_concat; + ndrange = (nnz_concat,), + ) # Compute write indices using cumsum write_indices = _cumsum_AK(keep_mask) @@ -415,42 +421,43 @@ end # Addition with transpose/adjoint support for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO) - for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO) + for (wrapb, transb, conjb, unwrapb, whereT2) in + trans_adj_wrappers(:DeviceSparseMatrixCOO) # Skip the case where both are not transposed (already handled above) (transa == false && transb == false) && continue - + TypeA = wrapa(:(T1)) TypeB = wrapb(:(T2)) - + @eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))} size(A) == size(B) || throw( DimensionMismatch( "dimensions must match: A has dims $(size(A)), B has dims $(size(B))", ), ) - + _A = $(unwrapa(:A)) _B = $(unwrapb(:B)) - + backend_A = get_backend(_A) backend_B = get_backend(_B) backend_A == backend_B || throw(ArgumentError("Both matrices must have the same backend")) - + m, n = size(A) Ti = eltype(getrowind(_A)) Tv = promote_type(eltype(nonzeros(_A)), eltype(nonzeros(_B))) - + # For transposed COO, swap row and column indices nnz_A = nnz(_A) nnz_B = nnz(_B) nnz_concat = nnz_A + nnz_B - + # Allocate concatenated arrays rowind_concat = similar(getrowind(_A), nnz_concat) colind_concat = similar(getcolind(_A), nnz_concat) nzval_concat = similar(nonzeros(_A), Tv, nnz_concat) - + # Copy entries from A (potentially swapping row/col for transpose) if $transa rowind_concat[1:nnz_A] .= getcolind(_A) # Swap for transpose @@ -464,7 +471,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse else nzval_concat[1:nnz_A] .= nonzeros(_A) end - + # Copy entries from B (potentially swapping row/col for transpose) if $transb rowind_concat[(nnz_A+1):end] .= getcolind(_B) # Swap for transpose @@ -478,29 +485,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse else nzval_concat[(nnz_A+1):end] .= nonzeros(_B) end - + # Sort and compact (same as before) backend = backend_A keys = similar(rowind_concat, Ti, nnz_concat) kernel_make_keys! = kernel_make_csc_keys!(backend) - kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,)) - + kernel_make_keys!( + keys, + rowind_concat, + colind_concat, + m; + ndrange = (nnz_concat,), + ) + perm = _sortperm_AK(keys) rowind_sorted = rowind_concat[perm] colind_sorted = colind_concat[perm] nzval_sorted = nzval_concat[perm] - + keep_mask = similar(rowind_sorted, Bool, nnz_concat) kernel_mark! = kernel_mark_unique_coo!(backend) - kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,)) - + kernel_mark!( + keep_mask, + rowind_sorted, + colind_sorted, + nnz_concat; + ndrange = (nnz_concat,), + ) + write_indices = _cumsum_AK(keep_mask) nnz_final = @allowscalar write_indices[nnz_concat] - + rowind_C = similar(getrowind(_A), nnz_final) colind_C = similar(getcolind(_A), nnz_final) nzval_C = similar(nonzeros(_A), Tv, nnz_final) - + kernel_compact! = kernel_compact_coo!(backend) kernel_compact!( rowind_C, @@ -513,7 +532,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse nnz_concat; ndrange = (nnz_concat,), ) - + return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C) end end @@ -587,3 +606,86 @@ function LinearAlgebra.kron( return DeviceSparseMatrixCOO(m_C, n_C, rowind_C, colind_C, nzval_C) end + +""" + *(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO) + +Multiply two sparse matrices in COO format. Both matrices must have compatible dimensions +(number of columns of A equals number of rows of B) and be on the same backend (device). + +The multiplication converts to CSC format, performs the multiplication with GPU-compatible +kernels, and converts back to COO format. This approach is used for all cases including +transpose/adjoint since COO doesn't have an efficient direct multiplication algorithm. + +# Examples +```jldoctest +julia> using DeviceSparseArrays, SparseArrays + +julia> A = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2)); + +julia> B = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2)); + +julia> C = A * B; + +julia> collect(C) +2×2 Matrix{Float64}: + 8.0 0.0 + 0.0 15.0 +``` +""" +function Base.:(*)(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch( + "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))", + ), + ) + + backend_A = get_backend(A) + backend_B = get_backend(B) + backend_A == backend_B || + throw(ArgumentError("Both matrices must have the same backend")) + + # Convert to CSC, multiply, convert back to COO + # This is acceptable as COO doesn't have an efficient direct multiplication algorithm + # and CSC provides the sorted structure needed for efficient SpGEMM + A_csc = DeviceSparseMatrixCSC(A) + B_csc = DeviceSparseMatrixCSC(B) + C_csc = A_csc * B_csc + return DeviceSparseMatrixCOO(C_csc) +end + +# Multiplication with transpose/adjoint support - all cases use the same approach +for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO) + for (wrapb, transb, conjb, unwrapb, whereT2) in + trans_adj_wrappers(:DeviceSparseMatrixCOO) + # Skip the case where both are not transposed (already handled above) + (transa == false && transb == false) && continue + + TypeA = wrapa(:(T1)) + TypeB = wrapb(:(T2)) + + @eval function Base.:(*)( + A::$TypeA, + B::$TypeB, + ) where {$(whereT1(:T1)),$(whereT2(:T2))} + size(A, 2) == size(B, 1) || throw( + DimensionMismatch( + "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))", + ), + ) + + backend_A = get_backend($(unwrapa(:A))) + backend_B = get_backend($(unwrapb(:B))) + backend_A == backend_B || + throw(ArgumentError("Both matrices must have the same backend")) + + # Convert to CSC (handles transpose/adjoint), multiply, convert back to COO + # Same approach as the base case since COO doesn't have an efficient + # direct multiplication algorithm + A_csc = DeviceSparseMatrixCSC(A) + B_csc = DeviceSparseMatrixCSC(B) + C_csc = A_csc * B_csc + return DeviceSparseMatrixCOO(C_csc) + end + end +end diff --git a/src/matrix_coo/matrix_coo_kernels.jl b/src/matrix_coo/matrix_coo_kernels.jl index e412fe2..511a40e 100644 --- a/src/matrix_coo/matrix_coo_kernels.jl +++ b/src/matrix_coo/matrix_coo_kernels.jl @@ -216,16 +216,18 @@ end if i <= nnz_in out_idx = write_indices[i] - + # If this is a new entry (or first of duplicates), write it if i == 1 || (rowind_in[i] != rowind_in[i-1] || colind_in[i] != colind_in[i-1]) rowind_out[out_idx] = rowind_in[i] colind_out[out_idx] = colind_in[i] - + # Sum all duplicates val_sum = nzval_in[i] j = i + 1 - while j <= nnz_in && rowind_in[j] == rowind_in[i] && colind_in[j] == colind_in[i] + while j <= nnz_in && + rowind_in[j] == rowind_in[i] && + colind_in[j] == colind_in[i] val_sum += nzval_in[j] j += 1 end diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 36ff448..0de50ce 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -365,7 +365,7 @@ function Base.:+(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC) colptr_C[1:1] .= one(Ti) # Allocate result arrays - nnz_total = @allowscalar colptr_C[n+1] - one(Ti) + nnz_total = @allowscalar colptr_C[n+1] - one(Ti) rowval_C = similar(getrowval(A), nnz_total) nzval_C = similar(nonzeros(A), Tv, nnz_total) @@ -391,27 +391,28 @@ end # Addition with transpose/adjoint support for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSC) - for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCSC) + for (wrapb, transb, conjb, unwrapb, whereT2) in + trans_adj_wrappers(:DeviceSparseMatrixCSC) # Skip the case where both are not transposed (already handled above) (transa == false && transb == false) && continue - + TypeA = wrapa(:(T1)) TypeB = wrapb(:(T2)) - + @eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))} size(A) == size(B) || throw( DimensionMismatch( "dimensions must match: A has dims $(size(A)), B has dims $(size(B))", ), ) - + # Convert both to CSR (transpose/adjoint of CSC has CSR structure) # and use existing CSR + CSR addition. The conversion methods # already handle transpose/adjoint correctly. A_csr = DeviceSparseMatrixCSR(A) B_csr = DeviceSparseMatrixCSR(B) result_csr = A_csr + B_csr - + # Convert back to CSC return DeviceSparseMatrixCSC(result_csr) end @@ -450,3 +451,144 @@ function LinearAlgebra.kron(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC) C_coo = kron(A_coo, B_coo) return DeviceSparseMatrixCSC(C_coo) end + +""" + *(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC) + +Multiply two sparse matrices in CSC format. Both matrices must have compatible dimensions +(number of columns of A equals number of rows of B) and be on the same backend (device). + +The multiplication uses GPU-compatible kernels for efficient sparse-sparse matrix +multiplication (SpGEMM). + +# Examples +```jldoctest +julia> using DeviceSparseArrays, SparseArrays + +julia> A = DeviceSparseMatrixCSC(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2)); + +julia> B = DeviceSparseMatrixCSC(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2)); + +julia> C = A * B; + +julia> collect(C) +2×2 Matrix{Float64}: + 8.0 0.0 + 0.0 15.0 +``` +""" +function Base.:(*)(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch( + "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))", + ), + ) + + backend_A = get_backend(A) + backend_B = get_backend(B) + backend_A == backend_B || + throw(ArgumentError("Both matrices must have the same backend")) + + m, k, n = size(A, 1), size(A, 2), size(B, 2) + Ti = eltype(getcolptr(A)) + Tv = promote_type(eltype(nonzeros(A)), eltype(nonzeros(B))) + + backend = backend_A + + # Allocate workspace for counting (one flag per row per column of B) + row_seen = similar(nonzeros(A), Bool, m * n) + + # Count non-zeros per column of C + nnz_per_col = similar(getcolptr(A), n) + fill!(nnz_per_col, zero(Ti)) + + kernel_count! = kernel_count_nnz_spgemm_csc!(backend) + kernel_count!( + nnz_per_col, + row_seen, + getcolptr(A), + getrowval(A), + getcolptr(B), + getrowval(B), + m; + ndrange = (n,), + ) + + # Build colptr for result matrix + cumsum_nnz = _cumsum_AK(nnz_per_col) + colptr_C = similar(getcolptr(A), n + 1) + colptr_C[2:end] .= cumsum_nnz + colptr_C[2:end] .+= one(Ti) + colptr_C[1:1] .= one(Ti) + + # Allocate result arrays + nnz_total = @allowscalar colptr_C[n + 1] - one(Ti) + rowval_C = similar(getrowval(A), nnz_total) + nzval_C = similar(nonzeros(A), Tv, nnz_total) + + # Allocate workspace for accumulation + row_accum = similar(nonzeros(A), Tv, m * n) + row_flags = similar(nonzeros(A), Bool, m * n) + + # Compute the product + kernel_mult! = kernel_spgemm_csc!(backend) + kernel_mult!( + rowval_C, + nzval_C, + colptr_C, + getcolptr(A), + getrowval(A), + nonzeros(A), + getcolptr(B), + getrowval(B), + nonzeros(B), + row_accum, + row_flags, + m, + Val{false}(), + Val{false}(); + ndrange = (n,), + ) + + return DeviceSparseMatrixCSC(m, n, colptr_C, rowval_C, nzval_C) +end + +# Multiplication with transpose/adjoint support +for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSC) + for (wrapb, transb, conjb, unwrapb, whereT2) in + trans_adj_wrappers(:DeviceSparseMatrixCSC) + # Skip the case where both are not transposed (already handled above) + (transa == false && transb == false) && continue + + TypeA = wrapa(:(T1)) + TypeB = wrapb(:(T2)) + + @eval function Base.:(*)( + A::$TypeA, + B::$TypeB, + ) where {$(whereT1(:T1)),$(whereT2(:T2))} + size(A, 2) == size(B, 1) || throw( + DimensionMismatch( + "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))", + ), + ) + + _A = $(unwrapa(:A)) + _B = $(unwrapb(:B)) + + backend_A = get_backend(_A) + backend_B = get_backend(_B) + backend_A == backend_B || + throw(ArgumentError("Both matrices must have the same backend")) + + # For transpose/adjoint, convert to CSR format (which is CSC transposed structurally) + # This follows the same pattern as addition with transpose/adjoint + A_csr = DeviceSparseMatrixCSR(A) + B_csr = DeviceSparseMatrixCSR(B) + result_csr = A_csr * B_csr + + # Convert back to CSC + return DeviceSparseMatrixCSC(result_csr) + end + end +end diff --git a/src/matrix_csc/matrix_csc_kernels.jl b/src/matrix_csc/matrix_csc_kernels.jl index a74a1a0..b6e277a 100644 --- a/src/matrix_csc/matrix_csc_kernels.jl +++ b/src/matrix_csc/matrix_csc_kernels.jl @@ -246,3 +246,100 @@ end i_C += 1 end end + +# Kernels for sparse-sparse matrix multiplication (SpGEMM) + +# Kernel for counting non-zeros per column in C = A * B (CSC format) +# For each column j of B, we accumulate contributions from all nonzeros B[k,j] +# Each B[k,j] contributes (column k of A) to column j of C +@kernel inbounds=true function kernel_count_nnz_spgemm_csc!( + nnz_per_col, + row_seen, + @Const(colptr_A), + @Const(rowval_A), + @Const(colptr_B), + @Const(rowval_B), + @Const(m), +) + col_B = @index(Global) + + # For column col_B of B, find all rows that will have nonzeros in column col_B of C + # Use row_seen array to mark rows (needs to be cleared for each column) + offset = (col_B - 1) * m + + # Clear the seen flags for this column + for i = 1:m + row_seen[offset + i] = false + end + + count = 0 + # For each nonzero B[k, col_B] + for idx_B = colptr_B[col_B]:(colptr_B[col_B + 1] - 1) + k = rowval_B[idx_B] # row index in B (column index in A) + + # Add all rows from column k of A + for idx_A = colptr_A[k]:(colptr_A[k + 1] - 1) + i = rowval_A[idx_A] # row index + if !row_seen[offset + i] + row_seen[offset + i] = true + count += 1 + end + end + end + + nnz_per_col[col_B] = count +end + +# Kernel for computing C = A * B (CSC format) +# This assumes nnz counts and colptr_C are already computed +@kernel inbounds=true function kernel_spgemm_csc!( + rowval_C, + nzval_C, + @Const(colptr_C), + @Const(colptr_A), + @Const(rowval_A), + @Const(nzval_A), + @Const(colptr_B), + @Const(rowval_B), + @Const(nzval_B), + row_accum, + row_flags, + @Const(m), + ::Val{CONJA}, + ::Val{CONJB}, +) where {CONJA,CONJB} + col_B = @index(Global) + + # Offset for this column's workspace + offset = (col_B - 1) * m + + # Clear accumulator and flags for this column + for i = 1:m + row_accum[offset + i] = zero(eltype(nzval_C)) + row_flags[offset + i] = false + end + + # Accumulate: C[:, col_B] = sum over k of A[:, k] * B[k, col_B] + for idx_B = colptr_B[col_B]:(colptr_B[col_B + 1] - 1) + k = rowval_B[idx_B] + val_B = CONJB ? conj(nzval_B[idx_B]) : nzval_B[idx_B] + + # Add val_B * A[:, k] to accumulator + for idx_A = colptr_A[k]:(colptr_A[k + 1] - 1) + i = rowval_A[idx_A] + val_A = CONJA ? conj(nzval_A[idx_A]) : nzval_A[idx_A] + row_accum[offset + i] += val_A * val_B + row_flags[offset + i] = true + end + end + + # Write out results in sorted order + write_pos = colptr_C[col_B] + for i = 1:m + if row_flags[offset + i] + rowval_C[write_pos] = i + nzval_C[write_pos] = row_accum[offset + i] + write_pos += 1 + end + end +end diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index 8523f68..3800118 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -389,27 +389,28 @@ end # Addition with transpose/adjoint support for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSR) - for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCSR) + for (wrapb, transb, conjb, unwrapb, whereT2) in + trans_adj_wrappers(:DeviceSparseMatrixCSR) # Skip the case where both are not transposed (already handled above) (transa == false && transb == false) && continue - + TypeA = wrapa(:(T1)) TypeB = wrapb(:(T2)) - + @eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))} size(A) == size(B) || throw( DimensionMismatch( "dimensions must match: A has dims $(size(A)), B has dims $(size(B))", ), ) - + # Convert both to CSC (transpose/adjoint of CSR has CSC structure) # and use existing CSC + CSC addition. The conversion methods # already handle transpose/adjoint correctly. A_csc = DeviceSparseMatrixCSC(A) B_csc = DeviceSparseMatrixCSC(B) result_csc = A_csc + B_csc - + # Convert back to CSR return DeviceSparseMatrixCSR(result_csc) end @@ -452,3 +453,141 @@ function LinearAlgebra.kron(A::DeviceSparseMatrixCSR, B::DeviceSparseMatrixCSR) C_coo = kron(A_coo, B_coo) return DeviceSparseMatrixCSR(C_coo) end + +""" + *(A::DeviceSparseMatrixCSR, B::DeviceSparseMatrixCSR) + +Multiply two sparse matrices in CSR format. Both matrices must have compatible dimensions +(number of columns of A equals number of rows of B) and be on the same backend (device). + +The multiplication uses GPU-compatible kernels for efficient sparse-sparse matrix +multiplication (SpGEMM). + +# Examples +```jldoctest +julia> using DeviceSparseArrays, SparseArrays + +julia> A = DeviceSparseMatrixCSR(DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2))); + +julia> B = DeviceSparseMatrixCSR(DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2))); + +julia> C = A * B; + +julia> collect(C) +2×2 Matrix{Float64}: + 8.0 0.0 + 0.0 15.0 +``` +""" +function Base.:(*)(A::DeviceSparseMatrixCSR, B::DeviceSparseMatrixCSR) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch( + "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))", + ), + ) + + backend_A = get_backend(A) + backend_B = get_backend(B) + backend_A == backend_B || + throw(ArgumentError("Both matrices must have the same backend")) + + m, k, n = size(A, 1), size(A, 2), size(B, 2) + Ti = eltype(getrowptr(A)) + Tv = promote_type(eltype(nonzeros(A)), eltype(nonzeros(B))) + + backend = backend_A + + # Allocate workspace for counting (one flag per column per row of A) + col_seen = similar(nonzeros(A), Bool, m * n) + + # Count non-zeros per row of C + nnz_per_row = similar(getrowptr(A), m) + fill!(nnz_per_row, zero(Ti)) + + kernel_count! = kernel_count_nnz_spgemm_csr!(backend) + kernel_count!( + nnz_per_row, + col_seen, + getrowptr(A), + getcolval(A), + getrowptr(B), + getcolval(B), + n; + ndrange = (m,), + ) + + # Build rowptr for result matrix + cumsum_nnz = _cumsum_AK(nnz_per_row) + rowptr_C = similar(getrowptr(A), m + 1) + rowptr_C[2:end] .= cumsum_nnz + rowptr_C[2:end] .+= one(Ti) + rowptr_C[1:1] .= one(Ti) + + # Allocate result arrays + nnz_total = @allowscalar rowptr_C[m + 1] - one(Ti) + colval_C = similar(getcolval(A), nnz_total) + nzval_C = similar(nonzeros(A), Tv, nnz_total) + + # Allocate workspace for accumulation + col_accum = similar(nonzeros(A), Tv, m * n) + col_flags = similar(nonzeros(A), Bool, m * n) + + # Compute the product + kernel_mult! = kernel_spgemm_csr!(backend) + kernel_mult!( + colval_C, + nzval_C, + rowptr_C, + getrowptr(A), + getcolval(A), + nonzeros(A), + getrowptr(B), + getcolval(B), + nonzeros(B), + col_accum, + col_flags, + n, + Val{false}(), + Val{false}(); + ndrange = (m,), + ) + + return DeviceSparseMatrixCSR(m, n, rowptr_C, colval_C, nzval_C) +end + +# Multiplication with transpose/adjoint support +for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSR) + for (wrapb, transb, conjb, unwrapb, whereT2) in + trans_adj_wrappers(:DeviceSparseMatrixCSR) + # Skip the case where both are not transposed (already handled above) + (transa == false && transb == false) && continue + + TypeA = wrapa(:(T1)) + TypeB = wrapb(:(T2)) + + @eval function Base.:(*)( + A::$TypeA, + B::$TypeB, + ) where {$(whereT1(:T1)),$(whereT2(:T2))} + size(A, 2) == size(B, 1) || throw( + DimensionMismatch( + "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))", + ), + ) + + backend_A = get_backend($(unwrapa(:A))) + backend_B = get_backend($(unwrapb(:B))) + backend_A == backend_B || + throw(ArgumentError("Both matrices must have the same backend")) + + # For transpose/adjoint, convert to CSC format (which is CSR transposed structurally) + # This follows the same pattern as addition with transpose/adjoint + A_csc = DeviceSparseMatrixCSC(A) + B_csc = DeviceSparseMatrixCSC(B) + result_csc = A_csc * B_csc + + # Convert back to CSR + return DeviceSparseMatrixCSR(result_csc) + end + end +end diff --git a/src/matrix_csr/matrix_csr_kernels.jl b/src/matrix_csr/matrix_csr_kernels.jl index 6b48df9..79ff74a 100644 --- a/src/matrix_csr/matrix_csr_kernels.jl +++ b/src/matrix_csr/matrix_csr_kernels.jl @@ -246,3 +246,98 @@ end i_C += 1 end end + +# Kernels for sparse-sparse matrix multiplication (SpGEMM) in CSR format + +# Kernel for counting non-zeros per row in C = A * B (CSR format) +# For each row i of A, we find all columns that will have nonzeros in row i of C +@kernel inbounds=true function kernel_count_nnz_spgemm_csr!( + nnz_per_row, + col_seen, + @Const(rowptr_A), + @Const(colval_A), + @Const(rowptr_B), + @Const(colval_B), + @Const(n), +) + row_A = @index(Global) + + # For row row_A of A, find all columns that will have nonzeros in row row_A of C + # Use col_seen array to mark columns (needs to be cleared for each row) + offset = (row_A - 1) * n + + # Clear the seen flags for this row + for j = 1:n + col_seen[offset + j] = false + end + + count = 0 + # For each nonzero A[row_A, k] + for idx_A = rowptr_A[row_A]:(rowptr_A[row_A + 1] - 1) + k = colval_A[idx_A] # column index in A (row index in B) + + # Add all columns from row k of B + for idx_B = rowptr_B[k]:(rowptr_B[k + 1] - 1) + j = colval_B[idx_B] # column index + if !col_seen[offset + j] + col_seen[offset + j] = true + count += 1 + end + end + end + + nnz_per_row[row_A] = count +end + +# Kernel for computing C = A * B (CSR format) +@kernel inbounds=true function kernel_spgemm_csr!( + colval_C, + nzval_C, + @Const(rowptr_C), + @Const(rowptr_A), + @Const(colval_A), + @Const(nzval_A), + @Const(rowptr_B), + @Const(colval_B), + @Const(nzval_B), + col_accum, + col_flags, + @Const(n), + ::Val{CONJA}, + ::Val{CONJB}, +) where {CONJA,CONJB} + row_A = @index(Global) + + # Offset for this row's workspace + offset = (row_A - 1) * n + + # Clear accumulator and flags for this row + for j = 1:n + col_accum[offset + j] = zero(eltype(nzval_C)) + col_flags[offset + j] = false + end + + # Accumulate: C[row_A, :] = sum over k of A[row_A, k] * B[k, :] + for idx_A = rowptr_A[row_A]:(rowptr_A[row_A + 1] - 1) + k = colval_A[idx_A] + val_A = CONJA ? conj(nzval_A[idx_A]) : nzval_A[idx_A] + + # Add val_A * B[k, :] to accumulator + for idx_B = rowptr_B[k]:(rowptr_B[k + 1] - 1) + j = colval_B[idx_B] + val_B = CONJB ? conj(nzval_B[idx_B]) : nzval_B[idx_B] + col_accum[offset + j] += val_A * val_B + col_flags[offset + j] = true + end + end + + # Write out results in sorted order + write_pos = rowptr_C[row_A] + for j = 1:n + if col_flags[offset + j] + colval_C[write_pos] = j + nzval_C[write_pos] = col_accum[offset + j] + write_pos += 1 + end + end +end diff --git a/test/shared/matrix_coo.jl b/test/shared/matrix_coo.jl index ad5163d..d9bdfa9 100644 --- a/test/shared/matrix_coo.jl +++ b/test/shared/matrix_coo.jl @@ -284,24 +284,24 @@ function shared_test_linearalgebra_matrix_coo( (identity, transpose, adjoint), (identity, transpose, adjoint), ) - + # Use rectangular matrices for identity+identity, square for transpose/adjoint m, n = (op_A === identity && op_B === identity) ? (50, 40) : (30, 30) dims_A = op_A === identity ? (m, n) : (n, m) dims_B = op_B === identity ? (m, n) : (n, m) - + A = sprand(T, dims_A..., 0.1) B = sprand(T, dims_B..., 0.15) - + dA = adapt(op, DeviceSparseMatrixCOO(A)) dB = adapt(op, DeviceSparseMatrixCOO(B)) - + # Test sparse + sparse result = op_A(dA) + op_B(dB) expected = op_A(A) + op_B(B) @test collect(result) ≈ Matrix(expected) @test result isa DeviceSparseMatrixCOO - + # Additional tests only for identity + identity if op_A === identity && op_B === identity # Test with overlapping entries @@ -322,6 +322,34 @@ function shared_test_linearalgebra_matrix_coo( end end + @testset "Sparse * Sparse Matrix Multiplication" begin + for T in (int_types..., float_types..., complex_types...) + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + + # Use rectangular matrices for identity*identity, square for transpose/adjoint + m, k, n = + (op_A === identity && op_B === identity) ? (50, 40, 30) : (30, 30, 30) + dims_A = op_A === identity ? (m, k) : (k, m) + dims_B = op_B === identity ? (k, n) : (n, k) + + A = sprand(T, dims_A..., 0.1) + B = sprand(T, dims_B..., 0.15) + + dA = adapt(op, DeviceSparseMatrixCOO(A)) + dB = adapt(op, DeviceSparseMatrixCOO(B)) + + # Test sparse * sparse + result = op_A(dA) * op_B(dB) + expected = op_A(A) * op_B(B) + @test collect(result) ≈ Matrix(expected) + @test result isa DeviceSparseMatrixCOO + end + end + end + @testset "Kronecker Product" begin for T in (int_types..., float_types..., complex_types...) # Test with rectangular matrices diff --git a/test/shared/matrix_csc.jl b/test/shared/matrix_csc.jl index f24cfd4..659f585 100644 --- a/test/shared/matrix_csc.jl +++ b/test/shared/matrix_csc.jl @@ -282,24 +282,24 @@ function shared_test_linearalgebra_matrix_csc( (identity, transpose, adjoint), (identity, transpose, adjoint), ) - + # Use rectangular matrices for identity+identity, square for transpose/adjoint m, n = (op_A === identity && op_B === identity) ? (50, 40) : (30, 30) dims_A = op_A === identity ? (m, n) : (n, m) dims_B = op_B === identity ? (m, n) : (n, m) - + A = sprand(T, dims_A..., 0.1) B = sprand(T, dims_B..., 0.15) - + dA = adapt(op, DeviceSparseMatrixCSC(A)) dB = adapt(op, DeviceSparseMatrixCSC(B)) - + # Test sparse + sparse result = op_A(dA) + op_B(dB) expected = op_A(A) + op_B(B) @test collect(result) ≈ Matrix(expected) @test result isa DeviceSparseMatrixCSC - + # Additional tests only for identity + identity if op_A === identity && op_B === identity # Test with overlapping entries @@ -320,6 +320,34 @@ function shared_test_linearalgebra_matrix_csc( end end + @testset "Sparse * Sparse Matrix Multiplication" begin + for T in (int_types..., float_types..., complex_types...) + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + + # Use rectangular matrices for identity*identity, square for transpose/adjoint + m, k, n = + (op_A === identity && op_B === identity) ? (50, 40, 30) : (30, 30, 30) + dims_A = op_A === identity ? (m, k) : (k, m) + dims_B = op_B === identity ? (k, n) : (n, k) + + A = sprand(T, dims_A..., 0.1) + B = sprand(T, dims_B..., 0.15) + + dA = adapt(op, DeviceSparseMatrixCSC(A)) + dB = adapt(op, DeviceSparseMatrixCSC(B)) + + # Test sparse * sparse + result = op_A(dA) * op_B(dB) + expected = op_A(A) * op_B(B) + @test collect(result) ≈ Matrix(expected) + @test result isa DeviceSparseMatrixCSC + end + end + end + @testset "Kronecker Product" begin if array_type != "JLArray" for T in (int_types..., float_types..., complex_types...) diff --git a/test/shared/matrix_csr.jl b/test/shared/matrix_csr.jl index c7b882b..fc19325 100644 --- a/test/shared/matrix_csr.jl +++ b/test/shared/matrix_csr.jl @@ -281,24 +281,24 @@ function shared_test_linearalgebra_matrix_csr( (identity, transpose, adjoint), (identity, transpose, adjoint), ) - + # Use rectangular matrices for identity+identity, square for transpose/adjoint m, n = (op_A === identity && op_B === identity) ? (50, 40) : (30, 30) dims_A = op_A === identity ? (m, n) : (n, m) dims_B = op_B === identity ? (m, n) : (n, m) - + A = sprand(T, dims_A..., 0.1) B = sprand(T, dims_B..., 0.15) - + dA = adapt(op, DeviceSparseMatrixCSR(A)) dB = adapt(op, DeviceSparseMatrixCSR(B)) - + # Test sparse + sparse result = op_A(dA) + op_B(dB) expected = op_A(A) + op_B(B) @test collect(result) ≈ Matrix(expected) @test result isa DeviceSparseMatrixCSR - + # Additional tests only for identity + identity if op_A === identity && op_B === identity # Test with overlapping entries @@ -319,6 +319,34 @@ function shared_test_linearalgebra_matrix_csr( end end + @testset "Sparse * Sparse Matrix Multiplication" begin + for T in (int_types..., float_types..., complex_types...) + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + + # Use rectangular matrices for identity*identity, square for transpose/adjoint + m, k, n = + (op_A === identity && op_B === identity) ? (50, 40, 30) : (30, 30, 30) + dims_A = op_A === identity ? (m, k) : (k, m) + dims_B = op_B === identity ? (k, n) : (n, k) + + A = sprand(T, dims_A..., 0.1) + B = sprand(T, dims_B..., 0.15) + + dA = adapt(op, DeviceSparseMatrixCSR(A)) + dB = adapt(op, DeviceSparseMatrixCSR(B)) + + # Test sparse * sparse + result = op_A(dA) * op_B(dB) + expected = op_A(A) * op_B(B) + @test collect(result) ≈ Matrix(expected) + @test result isa DeviceSparseMatrixCSR + end + end + end + @testset "Kronecker Product" begin if array_type != "JLArray" for T in (int_types..., float_types..., complex_types...)