Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ext/DeviceSparseArraysJLArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ import DeviceSparseArrays

# JLArray method: gather `x` with `collect`, compute the sorting permutation on the
# gathered copy, and wrap the permutation back into a `JLArray`.
function DeviceSparseArrays._sortperm_AK(x::JLArray)
    gathered = collect(x)
    return JLArray(sortperm(gathered))
end
# JLArray method: gather `x` with `collect`, take the cumulative sum of the gathered
# copy, and wrap the result back into a `JLArray`.
function DeviceSparseArrays._cumsum_AK(x::JLArray)
    gathered = collect(x)
    return JLArray(cumsum(gathered))
end
# JLArray method: for every element of `x`, look up its `searchsortedfirst` position
# in `v`. Both inputs are gathered with `collect` first; `Ref` keeps the haystack from
# being broadcast over, so only the queries are iterated.
function DeviceSparseArrays._searchsortedfirst_AK(v::JLArray, x::JLArray)
    haystack = Ref(collect(v))
    queries = collect(x)
    return JLArray(searchsortedfirst.(haystack, queries))
end

end
14 changes: 0 additions & 14 deletions src/conversions/conversion_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,3 @@ end
i = @index(Global)
keys[i] = rowind[i] * n + colind[i]
end

# Per-column entry counter (for COO → CSC).
# Work-item `idx` reads the column index of entry `idx` and atomically adds one to the
# `colptr` slot located one position past that column.
@kernel inbounds=true function kernel_count_per_col!(colptr, @Const(colind_sorted))
    idx = @index(Global)
    slot = colind_sorted[idx] + 1
    @atomic colptr[slot] += 1
end

# Per-row entry counter (for COO → CSR).
# Work-item `idx` reads the row index of entry `idx` and atomically adds one to the
# `rowptr` slot located one position past that row.
@kernel inbounds=true function kernel_count_per_row!(rowptr, @Const(rowind_sorted))
    idx = @index(Global)
    slot = rowind_sorted[idx] + 1
    @atomic rowptr[slot] += 1
end
55 changes: 35 additions & 20 deletions src/conversions/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,38 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
colind_sorted = A.colind[perm]
nzval_sorted = A.nzval[perm]

# Build colptr on device using a histogram approach
colptr = similar(A.colind, Ti, n + 1)
fill!(colptr, zero(Ti))

# Count entries per column
kernel! = kernel_count_per_col!(backend)
kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
# Build colptr on device using searchsortedfirst approach
# Since colind_sorted is sorted, find where each column starts
col_indices = similar(A.colind, Ti, n)
col_indices .= Ti(1):Ti(n)

# Compute cumulative sum
@allowscalar colptr[1] = 1 # TODO: Is there a better way to do this?
colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1
# Find start positions for each column
colptr = similar(A.colind, Ti, n + 1)
colptr[1:n] .= _searchsortedfirst_AK(colind_sorted, col_indices)
@allowscalar colptr[n+1] = Ti(nnz_count + 1)

return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
end

# Transpose and Adjoint conversions for COO to CSC
function DeviceSparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCOO}) where {Tv}
    coo = A.parent
    # Passing the parent's column indices as rows (and vice versa) expresses the
    # transpose as a COO matrix, which then goes through the COO → CSC conversion.
    swapped = DeviceSparseMatrixCOO(
        size(A, 1),
        size(A, 2),
        coo.colind,
        coo.rowind,
        coo.nzval,
    )
    return DeviceSparseMatrixCSC(swapped)
end

function DeviceSparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCOO}) where {Tv}
    coo = A.parent
    # Same index swap as the transpose case, with the stored values conjugated
    # element-wise before the COO → CSC conversion.
    swapped = DeviceSparseMatrixCOO(
        size(A, 1),
        size(A, 2),
        coo.colind,
        coo.rowind,
        conj.(coo.nzval),
    )
    return DeviceSparseMatrixCSC(swapped)
end

# ============================================================================
# CSR ↔ COO Conversions
# ============================================================================
Expand Down Expand Up @@ -223,17 +240,15 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
colind_sorted = A.colind[perm]
nzval_sorted = A.nzval[perm]

# Build rowptr on device using a histogram approach
rowptr = similar(A.rowind, Ti, m + 1)
fill!(rowptr, zero(Ti))

# Count entries per row
kernel! = kernel_count_per_row!(backend)
kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
# Build rowptr on device using searchsortedfirst approach
# Since rowind_sorted is sorted, find where each row starts
row_indices = similar(A.rowind, Ti, m)
row_indices .= Ti(1):Ti(m)

# Compute cumulative sum
@allowscalar rowptr[1] = 1 # TODO: Is there a better way to do this?
rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1
# Find start positions for each row
rowptr = similar(A.rowind, Ti, m + 1)
rowptr[1:m] .= _searchsortedfirst_AK(rowind_sorted, row_indices)
@allowscalar rowptr[m+1] = Ti(nnz_count + 1)

return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
end
1 change: 1 addition & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Helper functions to call AcceleratedKernels methods
# Generic method: forward to `AcceleratedKernels.sortperm`.
function _sortperm_AK(x)
    return AcceleratedKernels.sortperm(x)
end
# Generic method: forward to `AcceleratedKernels.cumsum`.
function _cumsum_AK(x)
    return AcceleratedKernels.cumsum(x)
end
# Generic method: forward to `AcceleratedKernels.searchsortedfirst`, passing the
# haystack `v` first and the queries `x` second.
function _searchsortedfirst_AK(v, x)
    return AcceleratedKernels.searchsortedfirst(v, x)
end
142 changes: 122 additions & 20 deletions src/matrix_coo/matrix_coo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,13 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
# Mark unique entries (first occurrence of each (row, col) pair)
keep_mask = similar(rowind_sorted, Bool, nnz_concat)
kernel_mark! = kernel_mark_unique_coo!(backend)
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
kernel_mark!(
keep_mask,
rowind_sorted,
colind_sorted,
nnz_concat;
ndrange = (nnz_concat,),
)

# Compute write indices using cumsum
write_indices = _cumsum_AK(keep_mask)
Expand Down Expand Up @@ -415,42 +421,43 @@ end

# Addition with transpose/adjoint support
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
for (wrapb, transb, conjb, unwrapb, whereT2) in
trans_adj_wrappers(:DeviceSparseMatrixCOO)
# Skip the case where both are not transposed (already handled above)
(transa == false && transb == false) && continue

TypeA = wrapa(:(T1))
TypeB = wrapb(:(T2))

@eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
size(A) == size(B) || throw(
DimensionMismatch(
"dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
),
)

_A = $(unwrapa(:A))
_B = $(unwrapb(:B))

backend_A = get_backend(_A)
backend_B = get_backend(_B)
backend_A == backend_B ||
throw(ArgumentError("Both matrices must have the same backend"))

m, n = size(A)
Ti = eltype(getrowind(_A))
Tv = promote_type(eltype(nonzeros(_A)), eltype(nonzeros(_B)))

# For transposed COO, swap row and column indices
nnz_A = nnz(_A)
nnz_B = nnz(_B)
nnz_concat = nnz_A + nnz_B

# Allocate concatenated arrays
rowind_concat = similar(getrowind(_A), nnz_concat)
colind_concat = similar(getcolind(_A), nnz_concat)
nzval_concat = similar(nonzeros(_A), Tv, nnz_concat)

# Copy entries from A (potentially swapping row/col for transpose)
if $transa
rowind_concat[1:nnz_A] .= getcolind(_A) # Swap for transpose
Expand All @@ -464,7 +471,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
else
nzval_concat[1:nnz_A] .= nonzeros(_A)
end

# Copy entries from B (potentially swapping row/col for transpose)
if $transb
rowind_concat[(nnz_A+1):end] .= getcolind(_B) # Swap for transpose
Expand All @@ -478,29 +485,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
else
nzval_concat[(nnz_A+1):end] .= nonzeros(_B)
end

# Sort and compact (same as before)
backend = backend_A
keys = similar(rowind_concat, Ti, nnz_concat)
kernel_make_keys! = kernel_make_csc_keys!(backend)
kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,))

kernel_make_keys!(
keys,
rowind_concat,
colind_concat,
m;
ndrange = (nnz_concat,),
)

perm = _sortperm_AK(keys)
rowind_sorted = rowind_concat[perm]
colind_sorted = colind_concat[perm]
nzval_sorted = nzval_concat[perm]

keep_mask = similar(rowind_sorted, Bool, nnz_concat)
kernel_mark! = kernel_mark_unique_coo!(backend)
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))

kernel_mark!(
keep_mask,
rowind_sorted,
colind_sorted,
nnz_concat;
ndrange = (nnz_concat,),
)

write_indices = _cumsum_AK(keep_mask)
nnz_final = @allowscalar write_indices[nnz_concat]

rowind_C = similar(getrowind(_A), nnz_final)
colind_C = similar(getcolind(_A), nnz_final)
nzval_C = similar(nonzeros(_A), Tv, nnz_final)

kernel_compact! = kernel_compact_coo!(backend)
kernel_compact!(
rowind_C,
Expand All @@ -513,7 +532,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
nnz_concat;
ndrange = (nnz_concat,),
)

return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
end
end
Expand Down Expand Up @@ -587,3 +606,86 @@ function LinearAlgebra.kron(

return DeviceSparseMatrixCOO(m_C, n_C, rowind_C, colind_C, nzval_C)
end

"""
    *(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)

Compute the product of two sparse matrices stored in COO format.

The number of columns of `A` must equal the number of rows of `B`, and both operands
must report the same backend (device).

Both operands are converted to CSC format, multiplied there, and the CSC result is
converted back to COO format.

# Throws
- `DimensionMismatch`: if `size(A, 2) != size(B, 1)`.
- `ArgumentError`: if the backends of `A` and `B` differ.

# Examples
```jldoctest
julia> using DeviceSparseArrays, SparseArrays

julia> A = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2));

julia> B = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2));

julia> C = A * B;

julia> collect(C)
2×2 Matrix{Float64}:
 8.0   0.0
 0.0  15.0
```
"""
function Base.:(*)(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
    if size(A, 2) != size(B, 1)
        throw(
            DimensionMismatch(
                "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
            ),
        )
    end

    if get_backend(A) != get_backend(B)
        throw(ArgumentError("Both matrices must have the same backend"))
    end

    # COO has no direct product routine here: go through CSC for the multiplication
    # and convert the CSC result back to COO.
    product_csc = DeviceSparseMatrixCSC(A) * DeviceSparseMatrixCSC(B)
    return DeviceSparseMatrixCOO(product_csc)
end

# Generate `*` methods for every pairing of wrapper variants returned by
# `trans_adj_wrappers(:DeviceSparseMatrixCOO)`. Every generated method takes the same
# convert-to-CSC, multiply, convert-back-to-COO route.
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO),
    (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)

    # The pairing where neither operand is transposed already has its own method.
    if transa == false && transb == false
        continue
    end

    lhs_type = wrapa(:(T1))
    rhs_type = wrapb(:(T2))

    @eval function Base.:(*)(
        A::$lhs_type,
        B::$rhs_type,
    ) where {$(whereT1(:T1)),$(whereT2(:T2))}
        if size(A, 2) != size(B, 1)
            throw(
                DimensionMismatch(
                    "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
                ),
            )
        end

        # Backends are queried on the unwrapped operands.
        if get_backend($(unwrapa(:A))) != get_backend($(unwrapb(:B)))
            throw(ArgumentError("Both matrices must have the same backend"))
        end

        # The CSC constructors receive the (possibly wrapped) operands directly, so
        # any transpose/adjoint is resolved during the conversion itself.
        product_csc = DeviceSparseMatrixCSC(A) * DeviceSparseMatrixCSC(B)
        return DeviceSparseMatrixCOO(product_csc)
    end
end
Comment thread
albertomercurio marked this conversation as resolved.
8 changes: 5 additions & 3 deletions src/matrix_coo/matrix_coo_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,18 @@ end

if i <= nnz_in
out_idx = write_indices[i]

# If this is a new entry (or first of duplicates), write it
if i == 1 || (rowind_in[i] != rowind_in[i-1] || colind_in[i] != colind_in[i-1])
rowind_out[out_idx] = rowind_in[i]
colind_out[out_idx] = colind_in[i]

# Sum all duplicates
val_sum = nzval_in[i]
j = i + 1
while j <= nnz_in && rowind_in[j] == rowind_in[i] && colind_in[j] == colind_in[i]
while j <= nnz_in &&
rowind_in[j] == rowind_in[i] &&
colind_in[j] == colind_in[i]
val_sum += nzval_in[j]
j += 1
end
Expand Down
Loading
Loading