diff --git a/src/tensors/abstractblocktensor/conversion.jl b/src/tensors/abstractblocktensor/conversion.jl index 1108cda..e7bb245 100644 --- a/src/tensors/abstractblocktensor/conversion.jl +++ b/src/tensors/abstractblocktensor/conversion.jl @@ -1,37 +1,50 @@ # Conversion # ---------- -function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap) - S = spacetype(t) - N₁, N₂ = numout(t), numin(t) - cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces)) - dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces)) - tdst = similar(t, cod ← dom) - - issparse(t) && zerovector!(tdst) +function _copy_subblocks!(tdst, tsrc) + S = spacetype(tsrc) + N₁, N₂ = numout(tsrc), numin(tsrc) for ((f₁, f₂), arr) in subblocks(tdst) blockax = ntuple(N₁ + N₂) do i return if i <= N₁ - blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(t, i))) + blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(tsrc, i))) else - blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(t, i)')) + blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(tsrc, i)')) end end - for (k, v) in nonzero_pairs(t) + for (k, v) in nonzero_pairs(tsrc) indices = getindex.(blockax, Block.(Tuple(k))) arr_slice = arr[indices...] # need to check for empty since fusion tree pair might not be present isempty(arr_slice) || copy!(arr_slice, v[f₁, f₂]) end end + return tdst +end +function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap) + S = spacetype(t) + N₁, N₂ = numout(t), numin(t) + cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces)) + dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces)) + tdst = TensorKit.TensorMapWithStorage{scalartype(t), storagetype(t)}(undef, cod, dom) + + issparse(t) && zerovector!(tdst) + _copy_subblocks!(tdst, t) return tdst end -function Base.convert(::Type{T}, t::AbstractBlockTensorMap) where {T <: TensorMap} - tdst = convert(TensorMap, t) - return convert(T, tdst) +function Base.convert(::Type{TT}, t::AbstractBlockTensorMap) where {TT <: TensorMap} + S = spacetype(t) + N₁, N₂ = numout(t), numin(t) + cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces)) + dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces)) + tdst = TT(undef, cod ← dom) + issparse(t) && zerovector!(tdst) + + _copy_subblocks!(tdst, t) + return tdst end function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap} diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 73c0c18..1859af2 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -15,6 +15,13 @@ function TO.tensoradd_type(TC, A::AdjointBlockTensorMap, pA::Index2Tuple, conjA: return TO.tensoradd_type(TC, A', adjointtensorindices(A, pA), !conjA) end +# copy blocks back to CPU/collect them into an array +# seems necessary for GPU-backed BlockTensorMaps but +# maybe not the most efficient approach? +function TO.tensorscalar(t::AbstractBlockTensorMap{T, S, 0, 0}) where {T, S} + return prod(TO.tensorscalar, nonzero_values(t)) +end + # tensoralloc_contract # -------------------- for TTA in (:AbstractTensorMap, :AbstractBlockTensorMap), TTB in (:AbstractTensorMap, :AbstractBlockTensorMap)