Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "GenericSparseArrays"
uuid = "da3fe0eb-88a8-4d14-ae1a-857c283e9c70"
authors = ["Alberto Mercurio <alberto.mercurio96@gmail.com> and contributors"]
version = "0.1.0"
authors = ["Alberto Mercurio <alberto.mercurio96@gmail.com> and contributors"]

[deps]
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -15,11 +16,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"

[extensions]
GenericSparseArraysFillArraysExt = "FillArrays"
GenericSparseArraysJLArraysExt = "JLArrays"

[compat]
AcceleratedKernels = "0.4"
Adapt = "4"
FillArrays = "1"
GPUArraysCore = "0.2.0"
JLArrays = "0.3"
KernelAbstractions = "0.9"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Package extension for GenericSparseArrays that is tied to FillArrays (see the
# `[extensions]` table in Project.toml). It adds `kron` methods for `Diagonal` matrices
# whose diagonal is an `AbstractFill`, i.e. every diagonal entry has the same value.
module GenericSparseArraysFillArraysExt

using GenericSparseArrays: GenericSparseMatrixCOO, GenericSparseMatrixCSC, GenericSparseMatrixCSR
using LinearAlgebra: Diagonal, Transpose, Adjoint, kron

import FillArrays: AbstractFill, getindex_value
import KernelAbstractions: @kernel, @index, get_backend, synchronize
# `kron` is `import`ed (not only `using`-ed) so the included files can add methods to it.
import LinearAlgebra: kron
import SparseArrays: nnz

# Kernel definitions first, then one file of `kron` methods per storage format.
include("kernels.jl")
include("kron_coo.jl")
include("kron_csc.jl")
include("kron_csr.jl")

end # module
61 changes: 61 additions & 0 deletions ext/GenericSparseArraysFillArraysExt/kernels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Optimized kernels for kron with Diagonal{AbstractFill}

# Optimized kernel for kron(D, B) where D has uniform diagonal value (e.g., scaled identity)
# Each work-item writes exactly one stored entry of C = kron(D, B). Output entry `out`
# belongs to diagonal block `block` (zero-based) and copies stored entry `src` of B,
# shifted by `block * m_B` rows and `block * n_B` columns and scaled by `scale_val`.
@kernel inbounds = true function kernel_kron_scaled_identity_coo!(
        rowind_C,
        colind_C,
        nzval_C,
        @Const(scale_val), # uniform diagonal value of D
        @Const(n_D::Int), # number of rows (= columns) of D
        @Const(rowind_B),
        @Const(colind_B),
        @Const(nzval_B),
        @Const(m_B::Int),
        @Const(n_B::Int),
    )
    out = @index(Global, Linear)

    stored_B = length(nzval_B)

    # Work-items past the last output entry do nothing.
    if out <= n_D * stored_B
        k = out - 1
        block = div(k, stored_B)     # zero-based diagonal block
        src = mod(k, stored_B) + 1   # one-based position in B's entry list

        rowind_C[out] = block * m_B + rowind_B[src]
        colind_C[out] = block * n_B + colind_B[src]
        nzval_C[out] = scale_val * nzval_B[src]
    end
end

# Optimized kernel for kron(A, D) where D has uniform diagonal value (e.g., scaled identity)
# Each work-item writes exactly one stored entry of C = kron(A, D). Every stored entry
# (i, j, v) of A expands into `p` output entries, one per diagonal position `d` of D,
# located at ((i - 1) * p + d, (j - 1) * p + d) with value `v * scale_val`.
@kernel inbounds = true function kernel_kron_coo_scaled_identity!(
        rowind_C,
        colind_C,
        nzval_C,
        @Const(rowind_A),
        @Const(colind_A),
        @Const(nzval_A),
        @Const(scale_val), # uniform diagonal value of D
        @Const(p::Int), # number of rows (= columns) of D
    )
    out = @index(Global, Linear)

    stored_A = length(nzval_A)

    # Work-items past the last output entry do nothing.
    if out <= stored_A * p
        k = out - 1
        src = div(k, p) + 1   # one-based position in A's entry list
        d = mod(k, p) + 1     # one-based position on D's diagonal

        rowind_C[out] = (rowind_A[src] - 1) * p + d
        colind_C[out] = (colind_A[src] - 1) * p + d
        nzval_C[out] = nzval_A[src] * scale_val
    end
end
113 changes: 113 additions & 0 deletions ext/GenericSparseArraysFillArraysExt/kron_coo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Kronecker product with Diagonal{AbstractFill} for COO format

using GenericSparseArrays: trans_adj_wrappers

# kron(D, B) for a `Diagonal` D whose diagonal is an `AbstractFill` (all diagonal
# entries equal) and a COO matrix B.
#
# The result is a COO matrix of size (n_D * m_B) × (n_D * n_B) holding n_D * nnz(B)
# stored entries: one scaled copy of B per diagonal position of D. Index vectors keep
# B's index type `Ti`; values use `promote_type(Tv1, Tv2)`. The output arrays are
# allocated with `similar` from B's arrays.
function kron(
        D::Diagonal{Tv1, <:AbstractFill{Tv1}},
        B::GenericSparseMatrixCOO{Tv2, Ti},
    ) where {Tv1, Tv2, Ti}
    n_D = size(D, 1)
    m_B, n_B = size(B)

    # Result dimensions
    m_C = n_D * m_B
    n_C = n_D * n_B
    nnz_C = n_D * nnz(B)

    # Determine result types
    Tv = promote_type(Tv1, Tv2)

    # Allocate output arrays
    rowind_C = similar(B.rowind, Ti, nnz_C)
    colind_C = similar(B.colind, Ti, nnz_C)
    nzval_C = similar(B.nzval, Tv, nnz_C)

    # Nothing to compute when the result has no stored entries (empty D or empty B).
    # Skipping the launch avoids handing a zero-sized ndrange to the backend.
    if nnz_C > 0
        backend = get_backend(B)

        # Get the uniform fill value
        fill_value = getindex_value(D.diag)

        # Launch optimized kernel: one work-item per stored entry of the result
        kernel! = kernel_kron_scaled_identity_coo!(backend)
        kernel!(
            rowind_C,
            colind_C,
            nzval_C,
            fill_value,
            n_D,
            B.rowind,
            B.colind,
            B.nzval,
            m_B,
            n_B;
            ndrange = nnz_C,
        )
    end

    return GenericSparseMatrixCOO(m_C, n_C, rowind_C, colind_C, nzval_C)
end

# kron(A, D) for a COO matrix A and a `Diagonal` D whose diagonal is an `AbstractFill`
# (all diagonal entries equal).
#
# The result is a COO matrix of size (m_A * p) × (n_A * p) holding nnz(A) * p stored
# entries, where p is the size of D: every stored entry of A expands into p entries.
# Index vectors keep A's index type `Ti`; values use `promote_type(Tv1, Tv2)`. The
# output arrays are allocated with `similar` from A's arrays.
function kron(
        A::GenericSparseMatrixCOO{Tv1, Ti},
        D::Diagonal{Tv2, <:AbstractFill{Tv2}},
    ) where {Tv1, Ti, Tv2}
    m_A, n_A = size(A)
    p = size(D, 1) # D is p×p

    # Result dimensions
    m_C = m_A * p
    n_C = n_A * p
    nnz_C = nnz(A) * p

    # Determine result types
    Tv = promote_type(Tv1, Tv2)

    # Allocate output arrays
    rowind_C = similar(A.rowind, Ti, nnz_C)
    colind_C = similar(A.colind, Ti, nnz_C)
    nzval_C = similar(A.nzval, Tv, nnz_C)

    # Nothing to compute when the result has no stored entries (empty A or empty D).
    # Skipping the launch avoids handing a zero-sized ndrange to the backend.
    if nnz_C > 0
        backend = get_backend(A)

        # Get the uniform fill value
        fill_value = getindex_value(D.diag)

        # Launch optimized kernel: one work-item per stored entry of the result
        kernel! = kernel_kron_coo_scaled_identity!(backend)
        kernel!(
            rowind_C,
            colind_C,
            nzval_C,
            A.rowind,
            A.colind,
            A.nzval,
            fill_value,
            p;
            ndrange = nnz_C,
        )
    end

    return GenericSparseMatrixCOO(m_C, n_C, rowind_C, colind_C, nzval_C)
end

# kron with Diagonal{AbstractFill} and transpose/adjoint wrappers for COO
for (wrap, trans, conj, unwrap, whereT) in trans_adj_wrappers(:GenericSparseMatrixCOO)
    # Entries whose `trans` flag is false are skipped; only the remaining wrapper
    # types get methods generated here.
    if trans == false
        continue
    end

    # Type expression of the wrapped COO matrix, shared by both generated methods.
    WrappedCOO = wrap(:T)

    # kron(D, wrapped): materialize the wrapper as a COO matrix and reuse the COO method.
    @eval function kron(
            D::Diagonal{Tv1, <:AbstractFill{Tv1}},
            B::$WrappedCOO,
        ) where {Tv1, $(whereT(:T))}
        return kron(D, GenericSparseMatrixCOO(B))
    end

    # kron(wrapped, D): same strategy with the operands in the other order.
    @eval function kron(
            A::$WrappedCOO,
            D::Diagonal{Tv2, <:AbstractFill{Tv2}},
        ) where {$(whereT(:T)), Tv2}
        return kron(GenericSparseMatrixCOO(A), D)
    end
end
46 changes: 46 additions & 0 deletions ext/GenericSparseArraysFillArraysExt/kron_csc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Kronecker product with Diagonal{AbstractFill} for CSC format

using GenericSparseArrays: trans_adj_wrappers

# kron(D, B) for a CSC matrix B: convert B to COO, take the COO Kronecker product,
# and convert the result back to CSC.
function kron(
        D::Diagonal{Tv, <:AbstractFill{Tv}},
        B::GenericSparseMatrixCSC,
    ) where {Tv}
    return GenericSparseMatrixCSC(kron(D, GenericSparseMatrixCOO(B)))
end

# kron(A, D) for a CSC matrix A: convert A to COO, take the COO Kronecker product,
# and convert the result back to CSC.
function kron(
        A::GenericSparseMatrixCSC,
        D::Diagonal{Tv, <:AbstractFill{Tv}},
    ) where {Tv}
    return GenericSparseMatrixCSC(kron(GenericSparseMatrixCOO(A), D))
end

# kron with Diagonal{AbstractFill} for wrapped CSC operands: go through COO and
# return the result in CSC format.
for (wrap, trans, conj, unwrap, whereT) in trans_adj_wrappers(:GenericSparseMatrixCSC)
    # Entries whose `trans` flag is false are skipped.
    if trans == false
        continue
    end

    # Type expression of the wrapped CSC matrix, shared by both generated methods.
    WrappedCSC = wrap(:T)

    @eval function kron(
            D::Diagonal{Tv1, <:AbstractFill{Tv1}},
            B::$WrappedCSC,
        ) where {Tv1, $(whereT(:T))}
        return GenericSparseMatrixCSC(kron(D, GenericSparseMatrixCOO(B)))
    end

    @eval function kron(
            A::$WrappedCSC,
            D::Diagonal{Tv2, <:AbstractFill{Tv2}},
        ) where {$(whereT(:T)), Tv2}
        return GenericSparseMatrixCSC(kron(GenericSparseMatrixCOO(A), D))
    end
end
46 changes: 46 additions & 0 deletions ext/GenericSparseArraysFillArraysExt/kron_csr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Kronecker product with Diagonal{AbstractFill} for CSR format

using GenericSparseArrays: trans_adj_wrappers

# kron(D, B) for a CSR matrix B: convert B to COO, take the COO Kronecker product,
# and convert the result back to CSR.
function kron(
        D::Diagonal{Tv, <:AbstractFill{Tv}},
        B::GenericSparseMatrixCSR,
    ) where {Tv}
    return GenericSparseMatrixCSR(kron(D, GenericSparseMatrixCOO(B)))
end

# kron(A, D) for a CSR matrix A: convert A to COO, take the COO Kronecker product,
# and convert the result back to CSR.
function kron(
        A::GenericSparseMatrixCSR,
        D::Diagonal{Tv, <:AbstractFill{Tv}},
    ) where {Tv}
    return GenericSparseMatrixCSR(kron(GenericSparseMatrixCOO(A), D))
end

# kron with Diagonal{AbstractFill} for wrapped CSR operands: go through COO and
# return the result in CSR format.
for (wrap, trans, conj, unwrap, whereT) in trans_adj_wrappers(:GenericSparseMatrixCSR)
    # Entries whose `trans` flag is false are skipped.
    if trans == false
        continue
    end

    # Type expression of the wrapped CSR matrix, shared by both generated methods.
    WrappedCSR = wrap(:T)

    @eval function kron(
            D::Diagonal{Tv1, <:AbstractFill{Tv1}},
            B::$WrappedCSR,
        ) where {Tv1, $(whereT(:T))}
        return GenericSparseMatrixCSR(kron(D, GenericSparseMatrixCOO(B)))
    end

    @eval function kron(
            A::$WrappedCSR,
            D::Diagonal{Tv2, <:AbstractFill{Tv2}},
        ) where {$(whereT(:T)), Tv2}
        return GenericSparseMatrixCSR(kron(GenericSparseMatrixCOO(A), D))
    end
end
2 changes: 1 addition & 1 deletion src/GenericSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module GenericSparseArrays

using LinearAlgebra
import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron
import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron, Diagonal
using SparseArrays
import SparseArrays: SparseVector, SparseMatrixCSC
import SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds
Expand Down
Loading
Loading