Skip to content
Merged
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include("inner.jl")
include("normalize.jl")
include("expect.jl")
include("environment.jl")
include("initialize_cache.jl")
include("exports.jl")

end
2 changes: 1 addition & 1 deletion src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ function inner_network(
return BilinearFormNetwork(A, x, y; kwargs...)
end

norm_sqr_network(ψ::AbstractITensorNetwork) = inner_network(ψ, ψ)
norm_sqr_network(ψ::AbstractITensorNetwork) = QuadraticFormNetwork(ψ)

#
# Printing
Expand Down
30 changes: 6 additions & 24 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using Adapt: Adapt, adapt, adapt_structure
using DataGraphs: DataGraphs, underlying_graph, vertex_data
using Dictionaries: Dictionary
using Graphs: Graphs, IsDirected, dst, src
using ITensors: commoninds, delta, dir
using ITensors: dir
using LinearAlgebra: diag, dot
using NDTensors: NDTensors
using NamedGraphs.GraphsExtensions: subgraph
Expand Down Expand Up @@ -34,14 +35,6 @@ function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
return 1 - f
end

function default_message(datatype::Type{<:AbstractArray}, inds_e)
return [adapt(datatype, denseblocks(delta(i))) for i in inds_e]
end

function default_message(elt::Type{<:Number}, inds_e)
return default_message(Vector{elt}, inds_e)
end
default_messages(ptn::PartitionedGraph) = Dictionary()
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
Expand All @@ -50,11 +43,6 @@ default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertice

partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_message(
bpc::AbstractBeliefPropagationCache, edge::QuotientEdge; kwargs...
)
return not_implemented()
end
default_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
Expand Down Expand Up @@ -162,11 +150,6 @@ function PartitionedGraphs.quotientedge(
return PartitionedGraphs.quotientedge(partitioned_tensornetwork(bpc), edge)
end

function linkinds(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge)
pitn = partitioned_tensornetwork(bpc)
return commoninds(subgraph(pitn, src(pe)), subgraph(pitn, dst(pe)))
end

NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc))

"""
Expand All @@ -187,12 +170,11 @@ function update_factor(bpc, vertex, factor)
return bpc
end

function message(bpc::AbstractBeliefPropagationCache, edge::QuotientEdge; kwargs...)
mts = messages(bpc)
return get(() -> default_message(bpc, edge; kwargs...), mts, edge)
function message(bpc::AbstractBeliefPropagationCache, edge::QuotientEdge)
return messages(bpc)[edge]
end
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
return map(edge -> message(bpc, edge; kwargs...), edges)
function messages(bpc::AbstractBeliefPropagationCache, edges)
return map(edge -> message(bpc, edge), edges)
end
function set_messages!(bpc::AbstractBeliefPropagationCache, quotientedges_messages)
ms = messages(bpc)
Expand Down
19 changes: 2 additions & 17 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DataGraphs: DataGraphs, set_vertex_data!
using Dictionaries: Dictionary
using Graphs: IsDirected
using ITensors: dir
using LinearAlgebra: diag, dot
Expand All @@ -9,22 +10,14 @@ using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraph,
using SimpleTraits: SimpleTraits, @traitfn, Not
using SplitApplyCombine: group

function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
return (; partitioned_vertices = default_partitioned_vertices(ψ))
end

function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGraph)
return (;)
end

struct BeliefPropagationCache{V, PV, PTN <: AbstractPartitionedGraph{V, PV}, MTS} <:
AbstractBeliefPropagationCache{V, PV}
partitioned_tensornetwork::PTN
messages::MTS
end

#Constructors...
function BeliefPropagationCache(ptn::PartitionedGraph; messages = default_messages(ptn))
function BeliefPropagationCache(ptn::PartitionedGraph; messages = Dictionary())
return BeliefPropagationCache(ptn, messages)
end

Expand All @@ -41,20 +34,12 @@ function BeliefPropagationCache(
return BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
end

function cache(alg::Algorithm"bp", tn; kwargs...)
return BeliefPropagationCache(tn; kwargs...)
end

function partitioned_tensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_tensornetwork
end

messages(bp_cache::BeliefPropagationCache) = bp_cache.messages

function default_message(bp_cache::BeliefPropagationCache, edge::QuotientEdge)
return default_message(datatype(bp_cache), linkinds(bp_cache, edge))
end

function Base.copy(bp_cache::BeliefPropagationCache)
return BeliefPropagationCache(
copy(partitioned_tensornetwork(bp_cache)), copy(messages(bp_cache))
Expand Down
4 changes: 2 additions & 2 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ function logscalar(
alg::Algorithm,
tn::AbstractITensorNetwork;
(cache!) = nothing,
cache_construction_kwargs = default_cache_construction_kwargs(alg, tn),
cache_construction_kwargs = (;),
update_cache = isnothing(cache!),
cache_update_kwargs = (;)
)
if isnothing(cache!)
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
cache! = Ref(initialize_cache(alg, tn; cache_construction_kwargs...))
end

if update_cache
Expand Down
4 changes: 2 additions & 2 deletions src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ function environment(
vertices::Vector;
(cache!) = nothing,
update_cache = isnothing(cache!),
cache_construction_kwargs = default_cache_construction_kwargs(alg, ptn),
cache_construction_kwargs = (;),
cache_update_kwargs = (;)
)
if isnothing(cache!)
cache! = Ref(cache(alg, ptn; cache_construction_kwargs...))
cache! = Ref(initialize_cache(alg, ptn; cache_construction_kwargs...))
end

if update_cache
Expand Down
3 changes: 1 addition & 2 deletions src/expect.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Dictionaries: Dictionary, set!
using ITensors: Op, contract, op, which_op

default_expect_alg() = "bp"
Expand Down Expand Up @@ -31,7 +30,7 @@ function expect(
)
ψIψ = QuadraticFormNetwork(ψ)
if isnothing(cache!)
cache! = Ref(cache(alg, ψIψ; cache_construction_kwargs...))
cache! = Ref(initialize_cache(alg, ψIψ; cache_construction_kwargs...))
end

if update_cache
Expand Down
2 changes: 1 addition & 1 deletion src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Adapt: adapt
using DataGraphs: DataGraphs, set_vertex_data!
using ITensors.NDTensors: datatype, denseblocks
using ITensors: ITensor, Op, delta, prime, sim
using ITensors: ITensor, Index, Op, dag, delta, prime, sim
using NamedGraphs.GraphsExtensions: disjoint_union

default_dual_site_index_map = prime
Expand Down
2 changes: 1 addition & 1 deletion src/formnetworks/linearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DataGraphs: DataGraphs, set_vertex_data!
using ITensors: ITensor, prime
using ITensors: ITensor, dag, prime
using NamedGraphs.GraphsExtensions: disjoint_union

default_dual_link_index_map = prime
Expand Down
38 changes: 38 additions & 0 deletions src/formnetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using DataGraphs: DataGraphs, set_vertex_data!, underlying_graph, vertex_data
using Dictionaries: Dictionary, set!
using ITensors: ITensor, commoninds, dag, delta
using NamedGraphs.PartitionedGraphs: PartitionedGraph, QuotientEdge, quotientedges

default_index_map = prime
default_inv_index_map = noprime
Expand Down Expand Up @@ -85,6 +88,41 @@ function QuadraticFormNetwork(
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

# Build initial BP messages on each quotient edge as `delta(bra, ket)`
# pairs, one per ket link Index crossing the cut. The bra-side counterpart
# of each ket Index is computed explicitly via `dual_index_map(fn)`, so
# the pairing is correct even when multiple link indices share an edge
# (where `commoninds`-zip ordering between layers is not guaranteed).
function identity_messages(
fn::QuadraticFormNetwork;
partitioned_vertices = default_partitioned_vertices(fn)
)
ptn = PartitionedGraph(fn, partitioned_vertices)
messages = Dictionary{QuotientEdge, Vector{ITensor}}()
tn = tensornetwork(fn)
elt = scalartype(tn)
map_idx = dual_index_map(fn)
pv = partitioned_vertices
ket_s = ket_vertex_suffix(fn)
for pe in quotientedges(ptn)
src_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(src(pe))])))
dst_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(dst(pe))])))
for (from_orig, to_orig, e) in (
(src_orig, dst_orig, pe),
(dst_orig, src_orig, reverse(pe)),
)
ms = ITensor[]
for v_from in from_orig, v_to in to_orig
for k in commoninds(tn[ket_vertex(fn, v_from)], tn[ket_vertex(fn, v_to)])
push!(ms, delta(elt, dag(map_idx(k)), k))
end
end
set!(messages, e, ms)
end
end
return messages
end

function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor)
state_inds = inds(ket_state)
bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds))
Expand Down
29 changes: 29 additions & 0 deletions src/initialize_cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Dictionaries: Dictionary
using Graphs: is_tree
using ITensors.NDTensors: @Algorithm_str, Algorithm
using NamedGraphs.PartitionedGraphs: PartitionedGraph, quotient_graph

# Build a cache for algorithm `alg` on `tn`. The fallback constructs a
# plain `BeliefPropagationCache` with no message defaults; the
# `QuadraticFormNetwork` specialization injects `identity_messages` on
# loopy quotient graphs (canonical for the structurally ψ-vs-ψ case).
function initialize_cache(alg::Algorithm"bp", tn; kwargs...)
return BeliefPropagationCache(tn; kwargs...)
end

function initialize_cache(
alg::Algorithm"bp",
fn::QuadraticFormNetwork;
partitioned_vertices = default_partitioned_vertices(fn),
messages = nothing
)
ptn = PartitionedGraph(fn, partitioned_vertices)
if isnothing(messages)
messages = if is_tree(quotient_graph(ptn))
Dictionary()
else
identity_messages(fn; partitioned_vertices)
end
end
return BeliefPropagationCache(ptn; messages)
end
6 changes: 4 additions & 2 deletions src/inner.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ITensors: inner, scalar
using ITensors: inner, scalar, sim
using LinearAlgebra: norm, norm_sqr

default_contract_alg(tns::Tuple) = "bp"
Expand Down Expand Up @@ -173,7 +173,9 @@ end

# TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,
# or `norm_sqr` to match `LinearAlgebra.norm_sqr`
LinearAlgebra.norm_sqr::AbstractITensorNetwork; kwargs...) = inner(ψ, ψ; kwargs...)
function LinearAlgebra.norm_sqr::AbstractITensorNetwork; kwargs...)
return scalar(norm_sqr_network(ψ); kwargs...)
end

function LinearAlgebra.norm::AbstractITensorNetwork; kwargs...)
return sqrt(abs(real(norm_sqr(ψ; kwargs...))))
Expand Down
13 changes: 5 additions & 8 deletions src/normalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ function rescale(
tn::AbstractITensorNetwork,
args...;
(cache!) = nothing,
cache_construction_kwargs = default_cache_construction_kwargs(alg, tn),
cache_construction_kwargs = (;),
update_cache = isnothing(cache!),
cache_update_kwargs = (;),
kwargs...
)
if isnothing(cache!)
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
cache! = Ref(initialize_cache(alg, tn; cache_construction_kwargs...))
end

if update_cache
Expand Down Expand Up @@ -55,7 +55,7 @@ end
function LinearAlgebra.normalize(
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
)
logn = logscalar(alg, inner_network(tn, tn); kwargs...)
logn = logscalar(alg, norm_sqr_network(tn); kwargs...)
c = inv(exp(logn / (2 * length(vertices(tn)))))
return map(t -> c * t, tn)
end
Expand All @@ -64,17 +64,14 @@ function LinearAlgebra.normalize(
alg::Algorithm,
tn::AbstractITensorNetwork;
(cache!) = nothing,
cache_construction_function = tn ->
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
update_cache = isnothing(cache!),
cache_update_kwargs = (;),
cache_construction_kwargs = (;)
)
norm_tn = inner_network(tn, tn)
norm_tn = norm_sqr_network(tn)
if isnothing(cache!)
cache! = Ref(cache(alg, norm_tn; cache_construction_kwargs...))
cache! = Ref(initialize_cache(alg, norm_tn; cache_construction_kwargs...))
end

vs = collect(vertices(tn))
verts = vcat([ket_vertex(norm_tn, v) for v in vs], [bra_vertex(norm_tn, v) for v in vs])
norm_tn = rescale(alg, norm_tn; verts, cache!, update_cache, cache_update_kwargs)
Expand Down
35 changes: 23 additions & 12 deletions test/test_apply.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Compat: Compat
using Graphs: vertices
using ITensorNetworks:
BeliefPropagationCache, apply, environment, norm_sqr_network, siteinds, update
using ITensors: ITensors, ITensor, inner, op
using ITensorNetworks: BeliefPropagationCache, apply, environment, initialize_cache,
norm_sqr_network, siteinds, update
using ITensors: ITensors, Algorithm, ITensor, inner, op
using NamedGraphs.NamedGraphGenerators: named_grid
using SplitApplyCombine: group
using StableRNGs: StableRNG
Expand All @@ -18,18 +18,29 @@ include("utils.jl")
ψ = random_tensornetwork(rng, s; link_space = χ)
v1, v2 = (2, 2), (1, 2)
ψψ = norm_sqr_network(ψ)
# Simple Belief Propagation grouping (one bra+ket per partition) gives
# a product environment around `[v1, v2]`, which is what `apply` requires.
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter = 20)
envsSBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")])
# Vertices of `[v1, v2]` across all layers of `ψψ` (bra/ket/operator),
# so the environment around them is just the incoming BP messages —
# the per-site operator tensors aren't pulled in as central tensors.
env_verts(vs) = [
(v, suffix) for v in vs for suffix in ("bra", "ket", "operator")
]
# Simple Belief Propagation grouping (one bra/ket/operator triple per
# partition) gives a product environment around `[v1, v2]`, which is
# what `apply` requires.
pv_SBP = group(v -> v[1], vertices(ψψ))
bp_cache = update(
initialize_cache(Algorithm("bp"), ψψ; partitioned_vertices = pv_SBP);
maxiter = 20
)
envsSBP = environment(bp_cache, env_verts((v1, v2)))
# Column-grouping (one whole column per partition) gives a non-product
# environment; `apply` should reject it.
bp_cache_col = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ)))
bp_cache_col = update(bp_cache_col; maxiter = 20)
envsGBP = environment(
bp_cache_col, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")]
pv_col = group(v -> v[1][1], vertices(ψψ))
bp_cache_col = update(
initialize_cache(Algorithm("bp"), ψψ; partitioned_vertices = pv_col);
maxiter = 20
)
envsGBP = environment(bp_cache_col, env_verts((v1, v2)))
inner_alg = "exact"
ngates = 5
truncerr = 0.0
Expand Down
Loading
Loading