From b4ae4d7a6fc775f2d73066727fd34833abc854d2 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 2 Oct 2025 13:25:12 -0400 Subject: [PATCH 01/86] Working BP Commit --- src/ITensorNetworksNext.jl | 3 +++ src/abstracttensornetwork.jl | 2 +- test/test_beliefpropagation.jl | 25 +++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 test/test_beliefpropagation.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 19c4109..905d783 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,7 @@ include("abstract_problem.jl") include("iterators.jl") include("adapters.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") + end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e566752..1ecbffa 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -254,4 +254,4 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..4b179fb --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,25 @@ +using Dictionaries: Dictionary +using ITensorBase: Index +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, + partitionfunction +using Graphs: edges, vertices +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using Test: @test, @testset + +@testset "BeliefPropagation" begin + dims = (4, 1) + g = named_grid(dims) + 
l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 +end \ No newline at end of file From d77d0632e6e88a13ab817d9d8a99a90442d37efe Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 23 Oct 2025 18:23:27 -0400 Subject: [PATCH 02/86] BP Code --- .../abstractbeliefpropagationcache.jl | 151 +++++++++++ .../beliefpropagationcache.jl | 237 ++++++++++++++++++ test/test_beliefpropagation.jl | 20 +- 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 src/beliefpropagation/abstractbeliefpropagationcache.jl create mode 100644 src/beliefpropagation/beliefpropagationcache.jl diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..5eae283 --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,151 @@ +abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end + +#Interface +factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() +setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() +messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() +function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) + return not_implemented() +end +function 
deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +function rescale_messages( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... + ) + return not_implemented() +end +function rescale_vertices( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... + ) + return not_implemented() +end + +function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return not_implemented() +end +function edge_scalar( + bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... + ) + return not_implemented() +end + +#Graph functionality needed +Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function NamedGraphs.GraphsExtensions.boundary_edges( + bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... + ) + return not_implemented() +end + +#Functions derived from the interface +function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) + for (e, m) in zip(edges) + setmessage!(bp_cache, e, m) + end + return +end + +function deletemessages!( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) + ) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache +end + +function vertex_scalars( + bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... + ) + return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +end + +function edge_scalars( + bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... 
+ ) + return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] + ) + b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return incoming_messages(bp_cache, [vertex]; kwargs...) +end + +#Adapt interface for changing device +function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) + bp_cache = copy(bp_cache) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end +function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) + bp_cache = copy(bp_cache) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end +function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_messages(adapt(to), bp_cache, args...) +end +function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_factors(adapt(to), bp_cache, args...) 
+end + +function freenergy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + any(iszero, denominator_terms) && return -Inf + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end + +function partitionfunction(bp_cache::AbstractBeliefPropagationCache) + return exp(freenergy(bp_cache)) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return rescale_messages(bp_cache, [edge]) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache) + return rescale_messages(bp_cache, edges(bp_cache)) +end + +function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) + return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) +end + +function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) + return rescale_vertices(bpc, [vertex]; kwargs...) +end + +function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) + bpc = rescale_messages(bpc) + bpc = rescale_partitions(bpc, args...; kwargs...) + return bpc +end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..295502a --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,237 @@ +using DiagonalArrays: delta +using Dictionaries: Dictionary, set!, delete! 
+using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using ITensorBase: ITensor, dim +using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype + +struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: + AbstractBeliefPropagationCache{V} + network::N + messages::Dictionary +end + +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages +network(bp_cache::BeliefPropagationCache) = bp_cache.network +default_messages() = Dictionary() + +BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) + ms = messages(bp_cache) + delete!(ms, e) + return bp_cache +end + +function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) + ms = messages(bp_cache) + set!(ms, e, message) + return bp_cache +end + +function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) + ms = messages(bp_cache) + return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +end + +function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) + return [message(bp_cache, e) for e in edges] +end + +default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing +#Forward onto the network +for f in [ + :(Graphs.vertices), + :(Graphs.edges), + :(Graphs.is_tree), + :(NamedGraphs.GraphsExtensions.boundary_edges), + :(factors), + :(default_bp_maxiter), + :(ITensorNetworksNext.setfactor!), + :(ITensorNetworksNext.linkinds), + :(ITensorNetworksNext.underlying_graph), + ] + @eval begin + function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) + return $f(network(bp_cache), args...; kwargs...) 
+ end + end +end + +#TODO: Get subgraph working on an ITensorNetwork to overload this directly +function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) + return forest_cover_edge_sequence(underlying_graph(bp_cache)) +end + +function factors(tn::AbstractTensorNetwork, vertex) + return [tn[vertex]] +end + +function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::BeliefPropagationCache, vertex) + incoming_ms = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + return (reduce(*, incoming_ms) * reduce(*, state))[] +end + +function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return default_message(network(bp_cache), edge::AbstractEdge) +end + +function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) + t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... 
+ return t +end + +#Algorithmic defaults +default_update_alg(bp_cache::BeliefPropagationCache) = "bp" +default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" +default_normalize(::Algorithm"contract") = true +default_sequence_alg(::Algorithm"contract") = "optimal" +function set_default_kwargs(alg::Algorithm"contract") + normalize = get(alg, :normalize, default_normalize(alg)) + sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) + return Algorithm("contract"; normalize, sequence_alg) +end +function set_default_kwargs(alg::Algorithm"adapt_update") + _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) + return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) +end +default_verbose(::Algorithm"bp") = false +default_tol(::Algorithm"bp") = nothing +function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) + verbose = get(alg, :verbose, default_verbose(alg)) + maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) + edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) + tol = get(alg, :tol, default_tol(alg)) + message_update_alg = set_default_kwargs( + get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) + ) + return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) +end + +#TODO: Update message etc should go here... +function updated_message( + alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + vertex = src(edge) + incoming_ms = incoming_messages( + bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + state = factors(bp_cache, vertex) + #contract_list = ITensor[incoming_ms; state] + #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) + #updated_messages = contract(contract_list; sequence) + updated_message = + !isempty(incoming_ms) ? 
reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +function updated_message( + bp_cache::BeliefPropagationCache, + edge::AbstractEdge; + alg = default_message_update_alg(bpc), + kwargs..., + ) + return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) +end + +function update_message!( + message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edges::Vector; + (update_diff!) = nothing, + ) + bpc = copy(bpc) + for e in edges + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing + update_message!(alg.message_update_alg, bpc, e) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), prev_message) + end + end + return bpc +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:AbstractEdge}}; + (update_diff!) 
= nothing, + ) + new_mts = empty(messages(bpc)) + for edges in edge_groups + bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) + for e in edges + set!(new_mts, e, message(bpc_t, e)) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) + compute_error = !isnothing(alg.tol) + if isnothing(alg.maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:alg.maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) + if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol + if alg.verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end + +function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) + return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) +end + +#Edge sequence stuff +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + forests = forest_cover(g) + edges = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) 
+ end + end + return edges +end \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 4b179fb..81ee722 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,11 +3,13 @@ using ITensorBase: Index using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, partitionfunction using Graphs: edges, vertices -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using Test: @test, @testset @testset "BeliefPropagation" begin + + #Chain of tensors dims = (4, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) @@ -17,6 +19,22 @@ using Test: @test, @testset return randn(Tuple(is)) end + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) From b80e36eaf6aac3a3702bd0403d7858603366b1e7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Oct 2025 15:18:28 -0400 Subject: [PATCH 03/86] Express BP in terms of `SweepIterator` interface Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling. 
--- Project.toml | 2 + src/ITensorNetworksNext.jl | 1 + .../beliefpropagationcache.jl | 126 ++---------------- .../beliefpropagationproblem.jl | 85 ++++++++++++ 4 files changed, 101 insertions(+), 113 deletions(-) create mode 100644 src/beliefpropagation/beliefpropagationproblem.jl diff --git a/Project.toml b/Project.toml index 95b8be0..e0aea23 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -39,6 +40,7 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" +ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8, 0.9" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 905d783..cca4b6d 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -11,5 +11,6 @@ include("adapters.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 295502a..cdae651 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,9 +1,7 @@ -using DiagonalArrays: delta using Dictionaries: Dictionary, set!, delete! 
using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges using ITensorBase: ITensor, dim -using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: AbstractBeliefPropagationCache{V} @@ -13,9 +11,8 @@ end messages(bp_cache::BeliefPropagationCache) = bp_cache.messages network(bp_cache::BeliefPropagationCache) = bp_cache.network -default_messages() = Dictionary() -BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) +BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) @@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end -function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) +function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) end -function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) +function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) return [message(bp_cache, e) for e in edges] end -default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 
1 : nothing #Forward onto the network for f in [ :(Graphs.vertices), @@ -62,11 +58,6 @@ for f in [ end end -#TODO: Get subgraph working on an ITensorNetwork to overload this directly -function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) - return forest_cover_edge_sequence(underlying_graph(bp_cache)) -end - function factors(tn::AbstractTensorNetwork, vertex) return [tn[vertex]] end @@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) return t end -#Algorithmic defaults -default_update_alg(bp_cache::BeliefPropagationCache) = "bp" -default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" -default_normalize(::Algorithm"contract") = true -default_sequence_alg(::Algorithm"contract") = "optimal" -function set_default_kwargs(alg::Algorithm"contract") - normalize = get(alg, :normalize, default_normalize(alg)) - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) - return Algorithm("contract"; normalize, sequence_alg) -end -function set_default_kwargs(alg::Algorithm"adapt_update") - _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) - return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) -end -default_verbose(::Algorithm"bp") = false -default_tol(::Algorithm"bp") = nothing -function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) - verbose = get(alg, :verbose, default_verbose(alg)) - maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) - edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) - tol = get(alg, :tol, default_tol(alg)) - message_update_alg = set_default_kwargs( - get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) - ) - return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) -end - #TODO: Update message etc should go here... 
function updated_message( alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge @@ -141,85 +105,21 @@ function updated_message( return updated_message end -function updated_message( - bp_cache::BeliefPropagationCache, - edge::AbstractEdge; - alg = default_message_update_alg(bpc), - kwargs..., +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" ) - return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) + return Algorithm("contract"; normalize, sequence_alg) end - -function update_message!( - message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") ) - return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) + return Algorithm("adapt_update"; adapt, alg) end -""" -Do a sequential update of the message tensors on `edges` -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edges::Vector; - (update_diff!) = nothing, - ) - bpc = copy(bpc) - for e in edges - prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing - update_message!(alg.message_update_alg, bpc, e) - if !isnothing(update_diff!) - update_diff![] += message_diff(message(bpc, e), prev_message) - end - end - return bpc -end - -""" -Do parallel updates between groups of edges of all message tensors -Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edge_groups::Vector{<:Vector{<:AbstractEdge}}; - (update_diff!) 
= nothing, +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge ) - new_mts = empty(messages(bpc)) - for edges in edge_groups - bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) - for e in edges - set!(new_mts, e, message(bpc_t, e)) - end - end - return set_messages(bpc, new_mts) -end - -""" -More generic interface for update, with default params -""" -function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) - compute_error = !isnothing(alg.tol) - if isnothing(alg.maxiter) - error("You need to specify a number of iterations for BP!") - end - for i in 1:alg.maxiter - diff = compute_error ? Ref(0.0) : nothing - bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) - if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol - if alg.verbose - println("BP converged to desired precision after $i iterations.") - end - break - end - end - return bpc -end - -function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) 
- return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end #Edge sequence stuff @@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root end end return edges -end \ No newline at end of file +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..a497363 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,85 @@ +mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: + AbstractProblem + const cache::Cache + diff::Union{Nothing, Float64} +end + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc::BeliefPropagationCache; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) + prob = iter.problem + + edge_group, kwargs = current_region_plan(iter) + + new_message_tensors = map(edge_group) do edge + old_message = message(prob.cache, edge) + + new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) + + if !isnothing(prob.diff) + # TODO: Define `message_diff` + prob.diff += message_diff(new_message, old_message) + end + + return new_message + end + + foreach(edge_group, new_message_tensors) do edge, new_message + setmessage!(prob.cache, edge, new_message) + end + + return iter +end + +function region_plan( + prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... + ) + edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + plan = map(edges) do e + return [e] => (; sweep_kwargs...) 
+ end + + return plan +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end +function update(alg::Algorithm"bp", bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 0.0 : nothing + + prob = BeliefPropagationProblem(bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + + for _ in iter + if compute_error && prob.diff <= alg.tol + break + end + end + + if alg.verbose && compute_error + if prob.diff <= alg.tol + println("BP converged to desired precision after $(iter.which_sweep) iterations.") + else + println( + "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", + ) + end + end + + return bpc +end From fe44b804f7461106caa3a8dbc6f0dad38ff67ede Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 31 Oct 2025 12:46:03 -0400 Subject: [PATCH 04/86] Add method for `setmessages!` that allows messages from one cache to be set from another cache --- src/beliefpropagation/beliefpropagationcache.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index cdae651..b3a32b1 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -30,6 +30,14 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end +function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) + ms_dst = messages(bpc_dst) + for e in edges + set!(ms_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) 
ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) From 3ce08983b2a9feae9057dc10ca55491bddf08079 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 10 Nov 2025 14:03:59 -0500 Subject: [PATCH 05/86] Network is now passed to `forest_cover_edge_sequence` directly. --- src/beliefpropagation/beliefpropagationproblem.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a497363..967b454 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -9,7 +9,7 @@ function default_algorithm( bpc::BeliefPropagationCache; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + edge_sequence = forest_cover_edge_sequence(network(bpc)), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 1 : nothing, ) @@ -44,7 +44,8 @@ end function region_plan( prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... ) - edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) plan = map(edges) do e return [e] => (; sweep_kwargs...) 
From f6e4fd0ea748f4a3da272dc1011a855fdaee7a9e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:19:31 -0500 Subject: [PATCH 06/86] test file formatting --- test/test_beliefpropagation.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 81ee722..fc657e7 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,7 +1,17 @@ using Dictionaries: Dictionary using ITensorBase: Index -using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, - partitionfunction +using ITensorNetworksNext: + BeliefPropagationCache, + ITensorNetworksNext, + TensorNetwork, + adapt_messages, + default_message, + default_messages, + edge_scalars, + factors, + messages, + partitionfunction, + setmessages! using Graphs: edges, vertices using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges @@ -15,15 +25,15 @@ using Test: @test, @testset l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 + @test abs(z_bp - z_exact) <= 1.0e-14 #Tree of tensors dims = (4, 3) @@ -31,13 +41,14 @@ using Test: @test, @testset l = Dict(e => Index(3) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e 
-> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 -end \ No newline at end of file + @test abs(z_bp - z_exact) <= 1.0e-14 +end + From 63840a90df869893d87c1ce6a6c58e06bb13973c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:25:31 -0500 Subject: [PATCH 07/86] Add `DataGraphsPartitionedGraphsExt` glue for `TensorNetwork` type Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure. --- src/tensornetwork.jl | 79 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 582eec6..11c2e88 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,21 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! 
using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: + AbstractPartitionedGraph, + PartitionedGraphs, + departition, + partitioned_vertices, + partitionedgraph, + quotient_graph, + quotient_graph_type +using .LazyNamedDimsArrays: lazy, Mul +using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -24,8 +35,14 @@ function _TensorNetwork(graph::AbstractGraph, tensors) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end +function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) +end + DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() +DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -70,7 +87,10 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - return tn + for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) + insert_trivial_link!(network, edge) + end + return network end # Determine the graph structure from the tensors. 
@@ -93,3 +113,56 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + end + rem_edge!(underlying_graph(tn), e) + return true +end + +function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) + DT = fieldtype(type, :tensors) + empty_dict = DT() + return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) +end + +## PartitionedGraphs +function PartitionedGraphs.quotient_graph(tn::TensorNetwork) + ug = quotient_graph(underlying_graph(tn)) + return TensorNetwork(ug, vertex_data(QuotientView(tn))) +end +function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) + UG = quotient_graph_type(underlying_graph_type(type)) + VD = Vector{vertex_data_eltype(type)} + V = vertextype(UG) + return TensorNetwork{V, VD, UG, Dictionary{V, VD}} +end + +function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, vertex_data(tn)) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) + return mapreduce(lazy, *, collect(last(data))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = vertex_data(QuotientView(tn)) + return 
TensorNetwork(qview, tensors) +end From ba22ab5b107d2b681a5bd1d29395c0f390f23d56 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:20 -0500 Subject: [PATCH 08/86] Make abstract tensor network interface more generic. --- src/abstracttensornetwork.jl | 106 ++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 1ecbffa..b02c789 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!, - rename_vertices, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs.GraphsExtensions: + ⊔, + directed_graph, + incident_edges, + rem_edges!, + rename_vertices, + vertextype using SplitApplyCombine: flatten +using NamedGraphs.SimilarType: similar_type abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -# TODO: Define a generic fallback for `AbstractDataGraph`? 
-DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) @@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - # AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) 
- return error("Not implemented") -end +DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() +DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,40 +76,37 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function siteaxes(tn::AbstractGraph, edge::AbstractEdge) s = axes(tn[src(edge)]) ∩ 
axes(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function sitenames(tn::AbstractGraph, edge::AbstractEdge) s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) @@ -122,8 +114,8 @@ function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) return s end -function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set!(vertex_data(tn), vertex, value) return tn end @@ -153,7 +145,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +153,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,13 +167,13 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. -function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. 
-function fix_edges!(tn::AbstractTensorNetwork, v) +function fix_edges!(tn::AbstractGraph, v) rem_edges!(tn, incident_edges(tn, v)) add_missing_edges!(tn, v) return tn @@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. 
function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) @@ -254,4 +238,22 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) + +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} + return tensornetwork_induced_subgraph(graph, subvertices) +end +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + subgraph = similar_type(graph)(underlying_subgraph) + for v in vertices(subgraph) + if isassigned(graph, v) + set!(vertex_data(subgraph), v, graph[v]) + end + end + return subgraph, vlist +end From 49b087015955f1865cc7b333e43f35b47e704751 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:50 -0500 Subject: [PATCH 09/86] BP Caching overhauls --- .../abstractbeliefpropagationcache.jl | 184 ++++++++---------- .../beliefpropagationcache.jl | 178 ++++++----------- .../beliefpropagationproblem.jl | 109 ++++++++--- 3 files changed, 226 insertions(+), 245 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 5eae283..8c6b3dd 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,117 +1,124 @@ -abstract type 
AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end +using Graphs: AbstractGraph, AbstractEdge +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -#Interface -factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() -setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() -messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() -function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() -end -default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) - return not_implemented() +messages(::AbstractGraph) = not_implemented() +messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache end -function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() + +function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache end -function rescale_messages( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... 
- ) - return not_implemented() + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + ms = messages(bp_cache) + set!(ms, edge, message) + return bp_cache end -function rescale_vertices( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... - ) - return not_implemented() +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache end -function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return not_implemented() +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache end -function edge_scalar( - bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... - ) - return not_implemented() +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst end -#Graph functionality needed -Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function NamedGraphs.GraphsExtensions.boundary_edges( - bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... 
- ) - return not_implemented() +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + set!(fs, vertex, factor) + return bpc end -#Functions derived from the interface -function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) - for (e, m) in zip(edges) - setmessage!(bp_cache, e, m) - end - return +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return message(bp_cache, edge) * message(bp_cache, reverse(edge)) end -function deletemessages!( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) - ) - for e in edges - deletemessage!(bp_cache, e) - end - return bp_cache +function region_scalar(bp_cache::AbstractGraph, vertex) + + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return reduce(*, messages) * reduce(*, state) end -function vertex_scalars( - bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... - ) - return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars( - bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... 
- ) - return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) + return map(e -> region_scalar(bp_cache, e), edges) end -function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) +function scalar_factors_quotient(bp_cache::AbstractGraph) return vertex_scalars(bp_cache), edge_scalars(bp_cache) end -function incoming_messages( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] - ) - b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges return messages(bp_cache, b_edges) end -function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return incoming_messages(bp_cache, [vertex]; kwargs...) -end +default_messages(::AbstractGraph) = not_implemented() #Adapt interface for changing device -function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) - bp_cache = copy(bp_cache) +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) for e in es setmessage!(bp_cache, e, f(message(bp_cache, e))) end return bp_cache end -function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) - bp_cache = copy(bp_cache) + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) for v in vs setfactor!(bp_cache, v, f(factor(bp_cache, v))) end return bp_cache end -function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_messages(adapt(to), bp_cache, args...) 
-end -function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_factors(adapt(to), bp_cache, args...) -end -function freenergy(bp_cache::AbstractBeliefPropagationCache) +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end + +function free_energy(bp_cache::AbstractBeliefPropagationCache) numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) @@ -123,29 +130,4 @@ function freenergy(bp_cache::AbstractBeliefPropagationCache) any(iszero, denominator_terms) && return -Inf return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end - -function partitionfunction(bp_cache::AbstractBeliefPropagationCache) - return exp(freenergy(bp_cache)) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return rescale_messages(bp_cache, [edge]) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache) - return rescale_messages(bp_cache, edges(bp_cache)) -end - -function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) - return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) -end - -function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) - return rescale_vertices(bpc, [vertex]; kwargs...) -end - -function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) - bpc = rescale_messages(bpc) - bpc = rescale_partitions(bpc, args...; kwargs...) 
- return bpc -end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index b3a32b1..4e441fb 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,145 +1,93 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using NamedGraphs: convert_vertextype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph using ITensorBase: ITensor, dim +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph -struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: - AbstractBeliefPropagationCache{V} +struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: + AbstractBeliefPropagationCache{V, MT} network::N - messages::Dictionary + messages::Dictionary{ET, MT} end -messages(bp_cache::BeliefPropagationCache) = bp_cache.messages -network(bp_cache::BeliefPropagationCache) = bp_cache.network +network(bp_cache) = underlying_graph(bp_cache) -BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) - -function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) +DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) +DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) +function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) + return fieldtype(type, :network) end -function deletemessage!(bp_cache::BeliefPropagationCache, 
e::AbstractEdge) - ms = messages(bp_cache) - delete!(ms, e) - return bp_cache -end +message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT -function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) - ms = messages(bp_cache) - set!(ms, e, message) - return bp_cache +function BeliefPropagationCache(alg, network::AbstractGraph) + es = collect(edges(network)) + es = vcat(es, reverse.(es)) + messages = map(edge -> default_message(alg, network, edge), es) + return BeliefPropagationCache(network, Dictionary(es, messages)) end -function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) - ms_dst = messages(bpc_dst) - for e in edges - set!(ms_dst, e, message(bpc_src, e)) - end - return bpc_dst +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) - ms = messages(bp_cache) - return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +# TODO: This needs to go in DataGraphsGraphsExtensionsExt +# +# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges +# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, +# hence we just strip off any `AbstractDataGraph` data to avoid this. +function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) + return forest_cover_edge_sequence(underlying_graph(g); kwargs...) end - -function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) - return [message(bp_cache, e) for e in edges] +# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt +# +# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the +# data of a data graph to be removed using the above method if `parent_type(g)` is an +# `AbstractDataGraph`. 
+function forest_cover_edge_sequence(g::QuotientView; kwargs...) + return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) end - -#Forward onto the network -for f in [ - :(Graphs.vertices), - :(Graphs.edges), - :(Graphs.is_tree), - :(NamedGraphs.GraphsExtensions.boundary_edges), - :(factors), - :(default_bp_maxiter), - :(ITensorNetworksNext.setfactor!), - :(ITensorNetworksNext.linkinds), - :(ITensorNetworksNext.underlying_graph), - ] - @eval begin - function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) - return $f(network(bp_cache), args...; kwargs...) +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + add_edges!(g, edges(g)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) end end + return rv end -function factors(tn::AbstractTensorNetwork, vertex) - return [tn[vertex]] -end - -function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] -end - -function region_scalar(bp_cache::BeliefPropagationCache, vertex) - incoming_ms = incoming_messages(bp_cache, vertex) - state = factors(bp_cache, vertex) - return (reduce(*, incoming_ms) * reduce(*, state))[] -end - -function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return default_message(network(bp_cache), edge::AbstractEdge) -end - -function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) - t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... - return t -end - -#TODO: Update message etc should go here... 
-function updated_message( - alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge - ) - vertex = src(edge) - incoming_ms = incoming_messages( - bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] - ) - state = factors(bp_cache, vertex) - #contract_list = ITensor[incoming_ms; state] - #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) - #updated_messages = contract(contract_list; sequence) - updated_message = - !isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) - if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) - if !iszero(message_norm) - updated_message /= message_norm +function bpcache_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) + subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + for e in edges(subgraph) + if isassigned(graph, e) + set!(edge_data(subgraph), e, graph[e]) end end - return updated_message + return subgraph, vlist end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" - ) - return Algorithm("contract"; normalize, sequence_alg) +function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) + return bpcache_induced_subgraph(graph, subvertices) end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) +# For method ambiguity +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} + return bpcache_induced_subgraph(graph, subvertices) end -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end +## PartitionedGraphs -#Edge sequence stuff -function 
forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - forests = forest_cover(g) - edges = edgetype(g)[] - for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) - end - end - return edges +function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) + qview = QuotientView(network(bpc)) + messages = edge_data(QuotientView(bpc)) + return BeliefPropagationCache(qview, messages) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 967b454..a05c97a 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,70 +1,121 @@ -mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: - AbstractProblem +using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers +using Graphs: SimpleGraph, vertices, edges, has_edge +using NamedGraphs: AbstractNamedGraph, position_graph +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices + +abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end + +mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} + const alg::Alg const cache::Cache diff::Union{Nothing, Float64} end +BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) + function default_algorithm( ::Type{<:Algorithm"bp"}, - bpc::BeliefPropagationCache; + bpc; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(network(bpc)), + edge_sequence = forest_cover_edge_sequence(bpc), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 
1 : nothing, ) return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) - prob = iter.problem +function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) + edges = prob.alg.edge_sequence - edge_group, kwargs = current_region_plan(iter) + plan = map(edges) do e + return e => (; sweep_kwargs...) + end - new_message_tensors = map(edge_group) do edge - old_message = message(prob.cache, edge) + return plan +end - new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) +function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) + prob = iter.problem - if !isnothing(prob.diff) - # TODO: Define `message_diff` - prob.diff += message_diff(new_message, old_message) - end + edge, _ = current_region_plan(iter) + new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) + setmessage!(prob.cache, edge, new_message) - return new_message - end + return iter +end - foreach(edge_group, new_message_tensors) do edge, new_message - setmessage!(prob.cache, edge, new_message) - end +default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - return iter +default_message(::Type{<:Algorithm}, network, edge) = not_implemented() +function default_message(::Type{<:Algorithm"bp"}, network, edge) + + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + links = linkinds(network, edge) + data = ones(dim.(links)...) + + t = ITensor(data, links) + return t end -function region_plan( - prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... 
+updated_message(alg, bpc, edge) = not_implemented() +function updated_message(alg::Algorithm"contract", bpc, edge) + vertex = src(edge) + + incoming_ms = incoming_messages( + bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] ) - edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) + updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) - plan = map(edges) do e - return [e] => (; sweep_kwargs...) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end end + return updated_message +end - return plan +contract_messages(alg, factors, messages) = not_implemented() +function contract_messages( + alg, + factors::Vector{<:AbstractArray}, + messages::Vector{<:AbstractArray}, + ) + return contract_network(alg, vcat(factors, messages)) +end + +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ) + return Algorithm("contract"; normalize, contraction_alg) +end +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") + ) + return Algorithm("adapt_update"; adapt, alg) +end + +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end function update(bpc::AbstractBeliefPropagationCache; kwargs...) return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) end -function update(alg::Algorithm"bp", bpc) + +function update(alg, bpc) compute_error = !isnothing(alg.tol) diff = compute_error ? 0.0 : nothing - prob = BeliefPropagationProblem(bpc, diff) + prob = BeliefPropagationProblem(alg, bpc, diff) - iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) 
+ iter = SweepIterator(prob, alg.maxiter; compute_error) for _ in iter if compute_error && prob.diff <= alg.tol From db46c04214ed93c05a6bbcc7d88b06c2745f9c34 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:47:19 -0500 Subject: [PATCH 10/86] Remove dead deps --- src/beliefpropagation/beliefpropagationproblem.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a05c97a..f487ccc 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,4 +1,3 @@ -using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices From 400e373b9fbb7205359bfe5914ba8d6e0763cd16 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:05:45 -0500 Subject: [PATCH 11/86] Fix merge --- src/beliefpropagation/beliefpropagationproblem.jl | 2 +- src/tensornetwork.jl | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index f487ccc..61c97df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -87,7 +87,7 @@ function contract_messages( end function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") ) return Algorithm("contract"; normalize, contraction_alg) end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 11c2e88..44b883a 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -4,7 +4,7 @@ using Dictionaries: AbstractDictionary, Indices, 
dictionary using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, @@ -87,10 +87,7 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) - insert_trivial_link!(network, edge) - end - return network + return tn end # Determine the graph structure from the tensors. From b9aafe890f235c0543d7b209a46fbb86ce9f3b70 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:12:01 -0500 Subject: [PATCH 12/86] Fix type inference in TensorNetwork construction --- src/tensornetwork.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 44b883a..0681da5 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -66,8 +66,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end function TensorNetwork(graph::AbstractGraph, tensors) tn = _TensorNetwork(graph, tensors) From 4090e61f0069084ffd64ff53f65095ea3d05353c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:16:04 +0000 Subject: [PATCH 13/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_beliefpropagation.jl | 1 - 1 
file changed, 1 deletion(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index fc657e7..a39e1a6 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -51,4 +51,3 @@ using Test: @test, @testset z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test abs(z_bp - z_exact) <= 1.0e-14 end - From be0750ee8f0ea1323eb94de8c14eec4490ef1995 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 16:45:45 -0500 Subject: [PATCH 14/86] Remove `ITensorBase` dep --- Project.toml | 2 -- src/beliefpropagation/beliefpropagationcache.jl | 1 - src/beliefpropagation/beliefpropagationproblem.jl | 6 ++---- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index e0aea23..95b8be0 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -40,7 +39,6 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8, 0.9" diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 4e441fb..5d8fa35 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -3,7 +3,6 @@ using Dictionaries: Dictionary, set!, delete! 
using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs: convert_vertextype using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using ITensorBase: ITensor, dim using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 61c97df..49d0ef8 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -52,10 +52,8 @@ function default_message(::Type{<:Algorithm"bp"}, network, edge) #TODO: Get datatype working on tensornetworks so we can support GPU, etc... links = linkinds(network, edge) - data = ones(dim.(links)...) - - t = ITensor(data, links) - return t + data = ones(Tuple(links)) + return data end updated_message(alg, bpc, edge) = not_implemented() From b971b89a91954d4175160c9788e2974267dc6fdc Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Dec 2025 17:24:09 -0500 Subject: [PATCH 15/86] `forest_cover_edge_sequence` now constructs a temporary `NamedGraph` instead of trying to operate on existing graphs The reason for this is: - One only cares about the edges of the input graph - A simple graph cannot be used as it "forgets" its edge names resulting in recursion - As shown with `TensorNetwork`, removing edges may not always be defined. 
--- .../beliefpropagationcache.jl | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 5d8fa35..994f480 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -33,25 +33,11 @@ function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -# TODO: This needs to go in DataGraphsGraphsExtensionsExt -# -# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges -# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, -# hence we just strip off any `AbstractDataGraph` data to avoid this. -function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) - return forest_cover_edge_sequence(underlying_graph(g); kwargs...) -end -# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt -# -# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the -# data of a data graph to be removed using the above method if `parent_type(g)` is an -# `AbstractDataGraph`. -function forest_cover_edge_sequence(g::QuotientView; kwargs...) - return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) 
-end # TODO: This needs to go in GraphsExtensions -function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - add_edges!(g, edges(g)) +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = NamedGraph(vertices(gi)) + add_edges!(g, edges(gi)) forests = forest_cover(g) rv = edgetype(g)[] for forest in forests From 9ebf0310c19fdf661cf6afd39c294710f167918b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:42:36 -0500 Subject: [PATCH 16/86] [LazyNamedDimsArrays] Fix `parenttype` method --- src/LazyNamedDimsArrays/lazynameddimsarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/LazyNamedDimsArrays/lazynameddimsarray.jl b/src/LazyNamedDimsArrays/lazynameddimsarray.jl index b0ed86a..c269902 100644 --- a/src/LazyNamedDimsArrays/lazynameddimsarray.jl +++ b/src/LazyNamedDimsArrays/lazynameddimsarray.jl @@ -7,7 +7,7 @@ using WrappedUnions: @wrapped union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end -parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A +parenttype(::Type{LazyNamedDimsArray{T, A}}) where {T, A} = A parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T} parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray From 16fe303b73ab7f9ab3f5a1c46118319063a7af4a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:46:08 -0500 Subject: [PATCH 17/86] BP Cache now uses new `DataGraphs`interface --- .../abstractbeliefpropagationcache.jl | 13 +-- .../beliefpropagationcache.jl | 101 +++++++++++++----- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 8c6b3dd..0cae3fa 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ 
-3,11 +3,13 @@ using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -messages(::AbstractGraph) = not_implemented() -messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] -message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] +function message(bp_cache::AbstractGraph, edge::AbstractEdge) + ms = messages(bp_cache) + return get!(ms, edge, default_message(bp_cache, edge)) +end deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() function deletemessage!(bp_cache::AbstractDataGraph, edge) @@ -25,8 +27,7 @@ end setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() function setmessage!(bp_cache::AbstractDataGraph, edge, message) - ms = messages(bp_cache) - set!(ms, edge, message) + setindex!(bp_cache, message, edge) return bp_cache end function setmessage!(bp_cache::QuotientView, edge, message) @@ -56,7 +57,7 @@ factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() function setfactor!(bpc::AbstractDataGraph, vertex, factor) fs = factors(bpc) - set!(fs, vertex, factor) + setindex!(fs, vertex, factor) return bpc end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 994f480..c9793e6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,32 +1,85 @@ -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using DataGraphs: + DataGraphs, + AbstractDataGraph, + DataGraph, + has_edge_data, + get_vertex_data, + get_edge_data, + set_vertex_data!, + set_edge_data!, + unset_vertex_data!, + unset_edge_data!, + vertex_data_eltype, + 
edge_data_eltype, + underlying_graph, + underlying_graph_type using Dictionaries: Dictionary, set!, delete! -using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs: convert_vertextype -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph +using Graphs: AbstractGraph, is_tree, connected_components, is_directed +using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices +using NamedGraphs.GraphsExtensions: default_root_vertex, + forest_cover, + post_order_dfs_edges, + vertextype, + is_path_graph, + undirected_graph +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges -struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: +struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: AbstractBeliefPropagationCache{V, MT} + underlying_graph::G # we only use this for the edges. network::N messages::Dictionary{ET, MT} + function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) + + V = vertextype(network) + N = typeof(network) + ET = keytype(messages) + MT = eltype(messages) + + # Construct a directed graph version of the underlying graph of the tensor network. 
+ digraph = directed_graph(underlying_graph(network)) + + bpc = new{V, typeof(digraph), N, ET, MT}(digraph, network, messages) + + for edge in edges(bpc) + get!(() -> default_message(bpc, edge), messages, edge) + end + return bpc + end end -network(bp_cache) = underlying_graph(bp_cache) +network(bp_cache) = getfield(bp_cache, :network) + +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) + +DataGraphs.has_vertex_data(bpc::BeliefPropagationCache, vertex) = has_vertex_data(network(bpc), vertex) +DataGraphs.has_edge_data(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) -DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) -DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) -function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) - return fieldtype(type, :network) +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) +DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] + +DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) +DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) + +DataGraphs.unset_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = unset_vertex_data!(network(bpc), val, vertex) +DataGraphs.unset_edge_data!(bpc::BeliefPropagationCache, val, edge) = unset!(bpc.messages, edge, val) + +function DataGraphs.vertex_data_eltype(T::Type{<:BeliefPropagationCache}) + return vertex_data_eltype(fieldtype(T, :network)) +end +function DataGraphs.edge_data_eltype(T::Type{<:BeliefPropagationCache}) + return eltype(fieldtype(T, :messages)) end -message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT 
+message_type(T::Type{<:BeliefPropagationCache}) = edge_data_eltype(T) -function BeliefPropagationCache(alg, network::AbstractGraph) - es = collect(edges(network)) - es = vcat(es, reverse.(es)) - messages = map(edge -> default_message(alg, network, edge), es) - return BeliefPropagationCache(network, Dictionary(es, messages)) +function BeliefPropagationCache(network::AbstractGraph) + MT = vertex_data_eltype(typeof(network)) + return BeliefPropagationCache(MT, network) +end +function BeliefPropagationCache(MT::Type, network::AbstractGraph) + dict = Dictionary{edgetype(network), MT}() + return BeliefPropagationCache(network, dict) end function Base.copy(bp_cache::BeliefPropagationCache) @@ -61,18 +114,14 @@ function bpcache_induced_subgraph(graph, subvertices) return subgraph, vlist end -function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) - return bpcache_induced_subgraph(graph, subvertices) -end -# For method ambiguity -function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::Vector{V}) where {V} return bpcache_induced_subgraph(graph, subvertices) end ## PartitionedGraphs function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - qview = QuotientView(network(bpc)) - messages = edge_data(QuotientView(bpc)) - return BeliefPropagationCache(qview, messages) + inds = Indices(parent_graph_indices(QuotientEdges(underlying_graph(bpc)))) + data = map(e -> bpc[QuotientEdge(e)], inds) + return BeliefPropagationCache(QuotientView(network(bpc)), data) end From 24a4335f61699a2d818f8b75a8b2867f7a16b3b5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:46:49 -0500 Subject: [PATCH 18/86] Adjust `default_message` to take a `message` type as its first argument --- .../beliefpropagationproblem.jl | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git 
a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 49d0ef8..24b024d 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -2,6 +2,9 @@ using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices +using NamedDimsArrays: AbstractNamedDimsArray +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, parenttype, lazy + abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end @@ -45,15 +48,16 @@ function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp return iter end -default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - -default_message(::Type{<:Algorithm}, network, edge) = not_implemented() -function default_message(::Type{<:Algorithm"bp"}, network, edge) - - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... 
- links = linkinds(network, edge) - data = ones(Tuple(links)) - return data +function default_message(bpc::BeliefPropagationCache, edge) + return default_message(message_type(bpc), network(bpc), edge) +end +function default_message(T::Type, network, edge) + array = ones(Tuple(linkinds(network, edge))) + return convert(T, array) +end +function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) + message = default_message(parenttype(T), network, edge) + return convert(T, lazy(message)) end updated_message(alg, bpc, edge) = not_implemented() From c43884ecb5185386ab5acc6c08f4344c0d566811 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:47:44 -0500 Subject: [PATCH 19/86] Remove unnecessary code and fix ambiguities in `AbstractTensorNetwork` --- src/abstracttensornetwork.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index b02c789..b820867 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -53,10 +53,6 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# AbstractDataGraphs overloads -DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() - DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) @@ -240,10 +236,7 @@ end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) -function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} - return tensornetwork_induced_subgraph(graph, subvertices) -end -function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) +function 
Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V} return tensornetwork_induced_subgraph(graph, subvertices) end From dd6f6454f01380e03e609cd60b1d4bfdf5499718 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:48:10 -0500 Subject: [PATCH 20/86] `TensorNetwork` type now uses new DataGraphs interface --- src/tensornetwork.jl | 50 +++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 0681da5..16c80e3 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,9 +1,9 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph -using Dictionaries: AbstractDictionary, Indices, dictionary +using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype +using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, @@ -12,9 +12,13 @@ using NamedGraphs.PartitionedGraphs: partitioned_vertices, partitionedgraph, quotient_graph, - quotient_graph_type + quotient_graph_type, + QuotientVertex, + QuotientVertices, + QuotientVertexVertices, + quotientvertices using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -31,7 +35,7 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. 
-function _TensorNetwork(graph::AbstractGraph, tensors) +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end @@ -39,10 +43,18 @@ function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: Abstra return _TensorNetwork(graph, Tensors()) end -DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) -DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) -DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() -DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) +# DataGraphs interface + +DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph + +DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.has_edge_data(tn::TensorNetwork, e) = false + +DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] + +DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) +DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val) + function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -123,17 +135,23 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) +function GraphsExtensions.similar(type::Type{<:TensorNetwork}) DT = fieldtype(type, :tensors) empty_dict = DT() - return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) + return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) end ## PartitionedGraphs function PartitionedGraphs.quotient_graph(tn::TensorNetwork) ug = quotient_graph(underlying_graph(tn)) - return TensorNetwork(ug, vertex_data(QuotientView(tn))) + + inds = Indices(parent_graph_indices(QuotientVertices(tn))) + data = map(v -> 
tn[QuotientVertex(v)], inds) + + return TensorNetwork(ug, data) end +# TODO: This method should not be required with a better interface with a better +# DataGraphsPartitionedGraphsExt interface. function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) UG = quotient_graph_type(underlying_graph_type(type)) VD = Vector{vertex_data_eltype(type)} @@ -141,9 +159,10 @@ function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) return TensorNetwork{V, VD, UG, Dictionary{V, VD}} end +# Partition the underlying graph of the tensor network; does not affect the data. function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) pg = partitionedgraph(underlying_graph(tn), parts) - return TensorNetwork(pg, vertex_data(tn)) + return TensorNetwork(pg, copy(vertex_data(tn))) end PartitionedGraphs.departition(tn::TensorNetwork) = tn @@ -153,8 +172,9 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end -function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) - return mapreduce(lazy, *, collect(last(data))) +function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) + data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) + return mapreduce(lazy, *, data) end function PartitionedGraphs.quotientview(tn::TensorNetwork) From 7bb579c7037c93e591a09a0c88e3aa489ef39c5d Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Fri, 19 Dec 2025 16:37:59 -0500 Subject: [PATCH 21/86] Sweeping algorithms based on AlgorithmsInterface.jl (#30) --- Project.toml | 4 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- .../AlgorithmsInterfaceExtensions.jl | 306 ++++++++++++ src/ITensorNetworksNext.jl | 6 +- src/abstract_problem.jl | 1 - src/adapters.jl | 45 -- src/iterators.jl | 170 ------- src/sweeping/eigenproblem.jl | 44 ++ src/sweeping/utils.jl | 12 + test/Project.toml | 3 +- 
test/test_algorithmsinterfaceextensions.jl | 472 ++++++++++++++++++ test/test_aqua.jl | 2 +- test/test_dmrg.jl | 34 ++ test/test_iterators.jl | 221 -------- test/test_sweeping.jl | 65 +++ 16 files changed, 944 insertions(+), 445 deletions(-) create mode 100644 src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl delete mode 100644 src/abstract_problem.jl delete mode 100644 src/adapters.jl delete mode 100644 src/iterators.jl create mode 100644 src/sweeping/eigenproblem.jl create mode 100644 src/sweeping/utils.jl create mode 100644 test/test_algorithmsinterfaceextensions.jl create mode 100644 test/test_dmrg.jl delete mode 100644 test/test_iterators.jl create mode 100644 test/test_sweeping.jl diff --git a/Project.toml b/Project.toml index 95b8be0..e6919fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.2.4" +version = "0.3.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -32,6 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" +AlgorithmsInterface = "0.1.0" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" diff --git a/docs/Project.toml b/docs/Project.toml index 15d156a..9e273b0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,5 +8,5 @@ ITensorNetworksNext = {path = ".."} [compat] Documenter = "1" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index a9cd21b..bd688e9 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 
+5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" ITensorNetworksNext = {path = ".."} [compat] -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl new file mode 100644 index 0000000..a8c814e --- /dev/null +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -0,0 +1,306 @@ +module AlgorithmsInterfaceExtensions + +import AlgorithmsInterface as AI + +#========================== Patches for AlgorithmsInterface.jl ============================# + +abstract type Problem <: AI.Problem end +abstract type Algorithm <: AI.Algorithm end +abstract type State <: AI.State end + +function AI.initialize_state!( + problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.initialize_state( + problem::Problem, algorithm::Algorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultState(; stopping_criterion_state, kwargs...) +end + +#============================ DefaultState ================================================# + +@kwdef mutable struct DefaultState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ increment! ==================================================# + +# Custom version of `increment!` that also takes the problem and algorithm as arguments. 
+function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) + return AI.increment!(state) +end + +#============================ solve! ======================================================# + +# Custom version of `solve!` that allows specifying the logger and also overloads +# `increment!` on the problem and algorithm. +function basetypenameof(x) + return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), "."))) +end +default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) +function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) + return Symbol( + default_logging_context_prefix(problem), + default_logging_context_prefix(algorithm), + ) +end +function AI.solve!( + problem::Problem, algorithm::Algorithm, state::State; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + logger = AI.algorithm_logger() + + context_suffixes = [:Start, :PreStep, :PostStep, :Stop] + contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes)) + + # initialize the state and emit message + AI.initialize_state!(problem, algorithm, state; kwargs...) 
+ AI.emit_message(logger, problem, algorithm, state, contexts[:Start]) + + # main body of the algorithm + while !AI.is_finished!(problem, algorithm, state) + AI.increment!(problem, algorithm, state) + + # logging event between convergence check and algorithm step + AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep]) + + # algorithm step + AI.step!(problem, algorithm, state; logging_context_prefix) + + # logging event between algorithm step and convergence check + AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep]) + end + + # emit message about finished state + AI.emit_message(logger, problem, algorithm, state, contexts[:Stop]) + return state +end + +function AI.solve( + problem::Problem, algorithm::Algorithm; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + state = AI.initialize_state(problem, algorithm; kwargs...) + return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) +end + +#============================ AlgorithmIterator ===========================================# + +abstract type AlgorithmIterator end + +function algorithm_iterator( + problem::Problem, algorithm::Algorithm, state::State + ) + return DefaultAlgorithmIterator(problem, algorithm, state) +end + +function AI.is_finished!(iterator::AlgorithmIterator) + return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.is_finished(iterator::AlgorithmIterator) + return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.increment!(iterator::AlgorithmIterator) + return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.step!(iterator::AlgorithmIterator) + return AI.step!(iterator.problem, iterator.algorithm, iterator.state) +end +function Base.iterate(iterator::AlgorithmIterator, init = nothing) + AI.is_finished!(iterator) && return nothing + AI.increment!(iterator) + AI.step!(iterator) + 
return iterator.state, nothing +end + +struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator + problem::Problem + algorithm::Algorithm + state::State +end + +#============================ with_algorithmlogger ========================================# + +# Allow passing functions, not just CallbackActions. +@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) + return AI.with_algorithmlogger(f, args...) +end +@inline function with_algorithmlogger(f, args::Pair{Symbol}...) + return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) +end + +#============================ NestedAlgorithm =============================================# + +abstract type NestedAlgorithm <: Algorithm end + +function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +end + +max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) + +function get_subproblem( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + +function set_substate!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State + ) + state.iterate = substate.iterate + return state +end + +function AI.step!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State; + logging_context_prefix = Symbol() + ) + # Get the subproblem, subalgorithm, and substate. + subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) + + # Solve the subproblem with the subalgorithm. + logging_context_prefix = Symbol( + logging_context_prefix, default_logging_context_prefix(subalgorithm) + ) + AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix) + + # Update the state with the substate. 
+ set_substate!(problem, algorithm, state, substate) + + return state +end + +#= + DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) + +An algorithm that consists of running an algorithm at each iteration +from a list of stored algorithms. +=# +@kwdef struct DefaultNestedAlgorithm{ + ChildAlgorithm <: Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end +function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +#============================ FlattenedAlgorithm ==========================================# + +# Flatten a nested algorithm. +abstract type FlattenedAlgorithm <: Algorithm end +abstract type FlattenedAlgorithmState <: State end + +function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) +end + +function AI.initialize_state( + problem::Problem, algorithm::FlattenedAlgorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) +end +function AI.increment!( + problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState + ) + # Increment the total iteration count. + state.iteration += 1 + # TODO: Use `is_finished!` instead? + if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) + # We're on the last iteration of the child algorithm, so move to the next + # child algorithm. + state.parent_iteration += 1 + state.child_iteration = 1 + else + # Iterate the child algorithm. 
+ state.child_iteration += 1 + end + return state +end +function AI.step!( + problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState; + logging_context_prefix = Symbol() + ) + algorithm_sweep = algorithm.algorithms[state.parent_iteration] + state_sweep = AI.initialize_state( + problem, algorithm_sweep; + state.iterate, iteration = state.child_iteration + ) + AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix) + state.iterate = state_sweep.iterate + return state +end + +@kwdef struct DefaultFlattenedAlgorithm{ + ChildAlgorithm <: Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: FlattenedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = + AI.StopAfterIteration(sum(max_iterations, algorithms)) +end +function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +@kwdef mutable struct DefaultFlattenedAlgorithmState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: FlattenedAlgorithmState + iterate::Iterate + iteration::Int = 0 + parent_iteration::Int = 1 + child_iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ NonIterativeAlgorithm =======================================# + +# Algorithm that only performs a single step. +abstract type NonIterativeAlgorithm <: Algorithm end +abstract type NonIterativeAlgorithmState <: State end + +function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) + return DefaultNonIterativeAlgorithmState(; kwargs...) +end +function AI.solve!( + problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs... 
+ ) + return throw(MethodError(AI.solve!, (problem, algorithm, state))) +end + +@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: + NonIterativeAlgorithmState + iterate::Iterate +end + +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index cca4b6d..d3c5c21 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,13 +1,13 @@ module ITensorNetworksNext +include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl") include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("abstract_problem.jl") -include("iterators.jl") -include("adapters.jl") +include("sweeping/utils.jl") +include("sweeping/eigenproblem.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl deleted file mode 100644 index 5a65e0a..0000000 --- a/src/abstract_problem.jl +++ /dev/null @@ -1 +0,0 @@ -abstract type AbstractProblem end diff --git a/src/adapters.jl b/src/adapters.jl deleted file mode 100644 index 28318fb..0000000 --- a/src/adapters.jl +++ /dev/null @@ -1,45 +0,0 @@ -""" - struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator - -Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the -process. This allows one to manually call a custom `compute!` or insert their own code it in -the loop body in place of `compute!`. 
-""" -struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator - parent::S -end - -islaststep(adapter::IncrementOnly) = islaststep(adapter.parent) -increment!(adapter::IncrementOnly) = increment!(adapter.parent) -compute!(adapter::IncrementOnly) = adapter - -IncrementOnly(adapter::IncrementOnly) = adapter - -""" - struct EachRegion{SweepIterator} <: AbstractNetworkIterator - -Adapter that flattens each region iterator in the parent sweep iterator into a single -iterator. -""" -struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator - parent::SI -end - -# In keeping with Julia convention. -eachregion(iter::SweepIterator) = EachRegion(iter) - -# Essential definitions -function islaststep(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - return islaststep(adapter.parent) && islaststep(region_iter) -end -function increment!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) - return adapter -end -function compute!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - compute!(region_iter) - return adapter -end diff --git a/src/iterators.jl b/src/iterators.jl deleted file mode 100644 index 62d5b21..0000000 --- a/src/iterators.jl +++ /dev/null @@ -1,170 +0,0 @@ -""" - abstract type AbstractNetworkIterator - -A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins -with a call to `increment!` before executing `compute!`, however the initial call to -`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that -this call is implict. Termination of the iterator is controlled by the function `done`. 
-""" -abstract type AbstractNetworkIterator end - -# We use greater than or equals here as we increment the state at the start of the iteration -islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) - -function Base.iterate(iterator::AbstractNetworkIterator, init = true) - # The assumption is that first "increment!" is implicit, therefore we must skip the - # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not - # defined when length < 1, - init || islaststep(iterator) && return nothing - # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* - # define a method for increment! This way we avoid cases where one may wish to nest - # calls to different step! methods accidentaly incrementing multiple times. - init || increment!(iterator) - rv = compute!(iterator) - return rv, false -end - -increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)})) -compute!(iterator::AbstractNetworkIterator) = iterator - -step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) -function step!(f, iterator::AbstractNetworkIterator) - compute!(iterator) - f(iterator) - increment!(iterator) - return iterator -end - -# -# RegionIterator -# -""" - struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator -""" -mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator - problem::Problem - region_plan::RegionPlan - which_region::Int - const which_sweep::Int - function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} - if isempty(region_plan) - throw(ArgumentError("Cannot construct a region iterator with 0 elements.")) - end - return new{P, R}(problem, region_plan, 1, sweep) - end -end - -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep_kwargs...) 
- return RegionIterator(problem, plan, sweep) -end - -state(region_iter::RegionIterator) = region_iter.which_region -Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) - -problem(region_iter::RegionIterator) = region_iter.problem - -function current_region_plan(region_iter::RegionIterator) - return region_iter.region_plan[region_iter.which_region] -end - -function current_region(region_iter::RegionIterator) - region, _ = current_region_plan(region_iter) - return region -end - -function region_kwargs(region_iter::RegionIterator) - _, kwargs = current_region_plan(region_iter) - return kwargs -end -function region_kwargs(f::Function, iter::RegionIterator) - return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) -end - -function prev_region(region_iter::RegionIterator) - state(region_iter) <= 1 && return nothing - prev, _ = region_iter.region_plan[region_iter.which_region - 1] - return prev -end - -function next_region(region_iter::RegionIterator) - islaststep(region_iter) && return nothing - next, _ = region_iter.region_plan[region_iter.which_region + 1] - return next -end - -# -# Functions associated with RegionIterator -# -function increment!(region_iter::RegionIterator) - region_iter.which_region += 1 - return region_iter -end - -function compute!(iter::RegionIterator) - extract!(iter; region_kwargs(extract!, iter)...) - update!(iter; region_kwargs(update!, iter)...) - insert!(iter; region_kwargs(insert!, iter)...) - - return iter -end - -region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) 
- -# -# SweepIterator -# - -mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator - region_iter::RegionIterator{Problem} - sweep_kwargs::Iterators.Stateful{Iter} - which_sweep::Int - function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} - stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - - first_state = Iterators.peel(stateful_sweep_kwargs) - - if isnothing(first_state) - throw(ArgumentError("Cannot construct a sweep iterator with 0 elements.")) - end - - first_kwargs, _ = first_state - region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) - - return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) - end -end - -islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) - -region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter -problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) - -state(sweep_iter::SweepIterator) = sweep_iter.which_sweep -Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) -function increment!(sweep_iter::SweepIterator) - sweep_iter.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) - update_region_iterator!(sweep_iter; sweep_kwargs...) - return sweep_iter -end - -function update_region_iterator!(iterator::SweepIterator; kwargs...) - sweep = state(iterator) - iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) - return iterator -end - -function compute!(sweep_iter::SweepIterator) - for _ in sweep_iter.region_iter - # TODO: Is it sensible to execute the default region callback function? - end - return -end - -# More basic constructor where sweep_kwargs are constant throughout sweeps -function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) 
- # Initialize this to an empty RegionIterator - sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) - return SweepIterator(problem, sweep_kwargs_iter) -end diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl new file mode 100644 index 0000000..36978b2 --- /dev/null +++ b/src/sweeping/eigenproblem.jl @@ -0,0 +1,44 @@ +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE + +function dmrg(operator, algorithm, state) + problem = EigenProblem(operator) + return AI.solve(problem, algorithm; iterate = state).iterate +end +function dmrg(operator, state; kwargs...) + problem = EigenProblem(operator) + algorithm = select_algorithm(dmrg, operator, state; kwargs...) + return AI.solve(problem, algorithm; iterate = state).iterate +end + +# TODO: Allow specifying the region algorithm type? +function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...) + extended_kwargs = extend_columns((; kwargs...), nsweeps) + region_kwargs = rows(extended_kwargs) + return AIE.nested_algorithm(nsweeps) do i + return AIE.nested_algorithm(length(regions)) do j + return EigsolveRegion(regions[j]; region_kwargs[i]...) + end + end +end +#= + EigenProblem(operator) + +Represents the problem we are trying to solve and minimal algorithm-independent +information, so for an eigenproblem it is the operator we want the eigenvector of. +=# +struct EigenProblem{Operator} <: AIE.Problem + operator::Operator +end + +struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm + region::R + kwargs::Kwargs +end +EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...)) + +function AI.solve!( + problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs... 
+ ) + return error("EigsolveRegion step for EigenProblem not implemented yet.") +end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl new file mode 100644 index 0000000..39e09e4 --- /dev/null +++ b/src/sweeping/utils.jl @@ -0,0 +1,12 @@ +# Utility functions for processing keyword arguments. +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] +end diff --git a/test/Project.toml b/test/Project.toml index 4b7dc81..e71e7a4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" @@ -26,7 +27,7 @@ DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.3" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" NamedDimsArrays = "0.8, 0.9" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl new file mode 100644 index 0000000..8e0665c --- /dev/null +++ b/test/test_algorithmsinterfaceextensions.jl @@ -0,0 +1,472 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +# Define test problems, algorithms, and states for testing +struct TestProblem <: AIE.Problem + data::Vector{Float64} +end + +@kwdef struct TestAlgorithm{StoppingCriterion <: 
AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +end + +@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState; + logging_context_prefix = Symbol() + ) + state.iterate .+= 1 # Simple increment step + return state +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; + kwargs... + ) + state.iterate .+= 2 # Different increment step + return state +end + +@testset "AlgorithmsInterfaceExtensions" begin + @testset "DefaultState" begin + # Test DefaultState construction + iterate = [1.0, 2.0, 3.0] + stopping_criterion_state = AI.initialize_state( + TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion + ) + state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) + @test state.iterate == iterate + @test state.iteration == 0 + @test state.stopping_criterion_state isa AI.StoppingCriterionState + + # Test DefaultState with custom iteration + state.iteration = 5 + @test state.iteration == 5 + end + + @testset "initialize_state!" begin + # Test initialize_state! 
with iterate kwarg + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; + iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state + ) + AI.initialize_state!(problem, algorithm, state) + @test state.iterate == [0.0, 0.0] + @test state.iteration == 0 + @test state.stopping_criterion_state == stopping_criterion_state + end + + @testset "initialize_state" begin + # Test initialize_state without exclamation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + + state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) + @test state isa AIE.DefaultState + @test state.iteration == 0 + end + + @testset "increment!" begin + # Test increment! with problem and algorithm + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + # Increment and verify iteration counter increases + AI.increment!(problem, algorithm, state) + @test state.iteration == 1 + + AI.increment!(problem, algorithm, state) + @test state.iteration == 2 + end + + @testset "solve! and solve" begin + # Test solve! 
with simple problem + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) + + initial_iterate = [10.0, 20.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + + # Solve with custom initial iterate + initial_iterate = [5.0, 10.0] + final_state = AI.solve!( + problem, algorithm, state; iterate = copy(initial_iterate) + ) + + @test final_state.iteration == 3 + # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] + @test final_state.iterate ≈ [8.0, 13.0] + + # Test solve without exclamation + problem2 = TestProblem([1.0, 2.0]) + algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate2 = [5.0, 10.0] + + final_state2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + @test final_state2.iteration == 2 + @test final_state2.iterate ≈ [7.0, 12.0] + end + + @testset "DefaultAlgorithmIterator" begin + # Test algorithm iterator creation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + @test iterator isa AIE.DefaultAlgorithmIterator + @test iterator.problem === problem + @test iterator.algorithm === algorithm + @test iterator.state === state + + # Test iteration interface + @test !AI.is_finished!(iterator) + + # Step through iterator + state_out, _ = iterate(iterator) + @test state_out.iteration == 1 + @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! 
+ + state_out, _ = iterate(iterator) + @test state_out.iteration == 2 + + @test AI.is_finished!(iterator) + end + + @testset "with_algorithmlogger" begin + # Test with_algorithmlogger with functions + results = [] + function callback1(problem, algorithm, state) + push!(results, :callback1) + return nothing + end + function callback2(problem, algorithm, state) + push!(results, :callback2) + return nothing + end + + problem = TestProblem([1.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + + # Test with CallbackAction (wrapped functions) + state = AIE.with_algorithmlogger( + :TestProblem_TestAlgorithm_PreStep => callback1, + :TestProblem_TestAlgorithm_PostStep => callback2, + ) do + return AI.solve(problem, algorithm; iterate = [0.0]) + end + @test results == [:callback1, :callback2] + end + + @testset "DefaultNestedAlgorithm" begin + # Test creating nested algorithm with function + nested_alg = AIE.nested_algorithm(3) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + @test nested_alg isa AIE.DefaultNestedAlgorithm + @test length(nested_alg.algorithms) == 3 + @test AIE.max_iterations(nested_alg) == 3 + + # Test stepping through nested algorithm + problem = TestProblem([1.0, 2.0]) + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + initial_iterate = [0.0, 0.0] + AI.solve!( + problem, nested_alg, state; iterate = copy(initial_iterate) + ) + + @test state.iteration == 3 + # Each nested algorithm runs once with 2 steps, incrementing by 2 + # Total: 3 algorithms × 2 iterations × 2 increment = 12 + @test state.iterate ≈ [12.0, 12.0] + end + + @testset "NestedAlgorithm basic tests" begin + # Test basic nested algorithm functionality + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + problem 
= TestProblem([1.0, 2.0]) + + # Test state initialization + state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + @test state_nested isa AIE.DefaultState + @test state_nested.iteration == 0 + @test AIE.max_iterations(nested_alg) == 2 + end + + @testset "increment! for nested algorithms" begin + # Test increment! logic for nested algorithm state + problem = TestProblem([1.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test progression through iterations + @test state.iteration == 0 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 1 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 2 + end + + @testset "get_subproblem and set_substate!" begin + # Test get_subproblem + problem = TestProblem([1.0, 2.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [5.0, 10.0], + iteration = 1, + stopping_criterion_state, + ) + + subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) + @test subproblem === problem + @test subalgorithm === nested_alg.algorithms[1] + @test substate.iterate ≈ [5.0, 10.0] + + # Test set_substate! 
+ new_substate = AIE.DefaultState(; + iterate = [100.0, 200.0], + substate.stopping_criterion_state, + ) + AIE.set_substate!(problem, nested_alg, state, new_substate) + @test state.iterate ≈ [100.0, 200.0] + end + + @testset "basetypenameof and default_logging_context_prefix" begin + # Test basetypenameof utility + problem = TestProblem([1.0]) + algorithm = TestAlgorithm() + + prefix_problem = AIE.default_logging_context_prefix(problem) + prefix_algorithm = AIE.default_logging_context_prefix(algorithm) + prefix_combined = AIE.default_logging_context_prefix(problem, algorithm) + + @test prefix_problem isa Symbol + @test prefix_algorithm isa Symbol + @test prefix_combined isa Symbol + @test contains(String(prefix_combined), String(prefix_problem)) + end + + @testset "DefaultFlattenedAlgorithm" begin + # Create nested algorithms that support max_iterations + nested_algs = map(1:3) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each + ) + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 3 + + # Test state initialization + problem = TestProblem([1.0, 2.0]) + state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + @test state_flat isa AIE.DefaultFlattenedAlgorithmState + @test state_flat.iteration == 0 + @test state_flat.parent_iteration == 1 + @test state_flat.child_iteration == 0 + end + + @testset "DefaultFlattenedAlgorithmState increment!" 
begin + # Create nested algorithms for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4), + ) + + problem = TestProblem([1.0]) + stopping_criterion_state = AI.initialize_state( + problem, flattened_alg, flattened_alg.stopping_criterion + ) + state = AIE.DefaultFlattenedAlgorithmState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test initial state + @test state.iteration == 0 + @test state.parent_iteration == 1 + @test state.child_iteration == 0 + + # First increment - should increment child_iteration + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 1 + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + # Second increment - should increment child_iteration again + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 2 + @test state.parent_iteration == 2 # Should move to next parent + @test state.child_iteration == 1 + end + + @testset "FlattenedAlgorithm step!" begin + # Test individual step! calls for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4) + ) + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + # Manually step through to test step! 
functionality + AI.increment!(problem, flattened_alg, state) + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + AI.step!(problem, flattened_alg, state) + # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 + @test state.iterate ≈ [4.0, 4.0] + end + + @testset "flattened_algorithm helper" begin + # Test the flattened_algorithm helper function + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + # Using the helper function + flattened_alg = AIE.flattened_algorithm(2) do i + AIE.nested_algorithm(1) do j + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 2 + end + + @testset "AlgorithmIterator is_finished (without !)" begin + # Test is_finished without mutation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Before any iterations + @test !AI.is_finished(iterator) + + # Run the algorithm + AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) + + # After completion + @test AI.is_finished(iterator) + end + + @testset "AlgorithmIterator step!" begin + # Test step! 
method for iterator + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Step the iterator + AI.step!(iterator) + @test iterator.state.iterate ≈ [1.0, 1.0] + + AI.step!(iterator) + @test iterator.state.iterate ≈ [2.0, 2.0] + end + + @testset "NestedAlgorithm with different sub-algorithms" begin + # Test nested algorithm with varying sub-algorithms + nested_alg = AIE.DefaultNestedAlgorithm(; + algorithms = [ + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + ] + ) + + @test AIE.max_iterations(nested_alg) == 3 + @test length(nested_alg.algorithms) == 3 + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) + + # First algorithm: 1 iteration × 1 increment = 1 + # Second algorithm: 2 iterations × 2 increment = 4 + # Third algorithm: 1 iteration × 1 increment = 1 + # Total: 1 + 4 + 1 = 6 + @test state.iterate ≈ [6.0, 6.0] + @test state.iteration == 3 + end + + @testset "Edge cases" begin + # Test with single nested algorithm + nested_alg = AIE.nested_algorithm(1) do i + return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + end + + problem = TestProblem([1.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) + AI.solve!(problem, nested_alg, state; iterate = [0.0]) + + @test state.iterate ≈ [1.0] + @test state.iteration == 1 + end +end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 0afead5..a38563a 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code 
quality (Aqua.jl)" begin - Aqua.test_all(ITensorNetworksNext) + Aqua.test_all(ITensorNetworksNext; persistent_tasks = false) end diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl new file mode 100644 index 0000000..01f04ac --- /dev/null +++ b/test/test_dmrg.jl @@ -0,0 +1,34 @@ +import AlgorithmsInterface as AI +using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +@testset "select_algorithm(dmrg, ...)" begin + operator = "operator" + init = "init" + nsweeps = 3 + regions = ["region1", "region2"] + maxdim = [10, 20] + cutoff = 1.0e-7 + algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff) + @test algorithm isa AIE.NestedAlgorithm + @test length(algorithm.algorithms) == nsweeps + + maxdims = [10, 20, 20] + cutoffs = [1.0e-7, 1.0e-7, 1.0e-7] + algorithm′ = AIE.nested_algorithm(nsweeps) do i + return AIE.nested_algorithm(length(regions)) do j + return EigsolveRegion( + regions[j]; + maxdim = maxdims[i], + cutoff = cutoffs[i], + ) + end + end + for i in 1:nsweeps + for j in 1:length(regions) + @test algorithm.algorithms[i].algorithms[j] == + algorithm′.algorithms[i].algorithms[j] + end + end +end diff --git a/test/test_iterators.jl b/test/test_iterators.jl deleted file mode 100644 index a17c7be..0000000 --- a/test/test_iterators.jl +++ /dev/null @@ -1,221 +0,0 @@ -using Test: @test, @testset, @test_throws -import ITensorNetworksNext as ITensorNetworks -using .ITensorNetworks: RegionIterator, SweepIterator, IncrementOnly, compute!, increment!, islaststep, state, eachregion - -module TestIteratorUtils - - import ITensorNetworksNext as ITensorNetworks - using .ITensorNetworks - - struct TestProblem <: ITensorNetworks.AbstractProblem - data::Vector{Int} - end - ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)] - function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) - 
kwargs = ITensorNetworks.region_kwargs(iter) - push!(ITensorNetworks.problem(iter).data, kwargs.val) - return iter - end - - - mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator - state::Int - max::Int - output::Vector{Int} - end - - ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 - Base.length(TI::TestIterator) = TI.max - ITensorNetworks.state(TI::TestIterator) = TI.state - function ITensorNetworks.compute!(TI::TestIterator) - push!(TI.output, ITensorNetworks.state(TI)) - return TI - end - - mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator - parent::TestIterator - end - - Base.length(SA::SquareAdapter) = length(SA.parent) - ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) - ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) - function ITensorNetworks.compute!(SA::SquareAdapter) - ITensorNetworks.compute!(SA.parent) - return last(SA.parent.output)^2 - end - -end - -@testset "Iterators" begin - - import .TestIteratorUtils - - @testset "`AbstractNetworkIterator` Interface" begin - - @testset "Edge cases" begin - TI = TestIteratorUtils.TestIterator(1, 1, []) - cb = [] - @test islaststep(TI) - for _ in TI - @test islaststep(TI) - push!(cb, state(TI)) - end - @test length(cb) == 1 - @test length(TI.output) == 1 - @test only(cb) == 1 - - prob = TestIteratorUtils.TestProblem([]) - @test_throws ArgumentError SweepIterator(prob, 0) - @test_throws ArgumentError RegionIterator(prob, [], 1) - end - - TI = TestIteratorUtils.TestIterator(1, 4, []) - - @test !islaststep((TI)) - - # First iterator should compute only - rv, st = iterate(TI) - @test !islaststep((TI)) - @test !st - @test rv === TI - @test length(TI.output) == 1 - @test only(TI.output) == 1 - @test state(TI) == 1 - @test !st - - rv, st = iterate(TI, st) - @test !islaststep((TI)) - @test !st - @test length(TI.output) == 2 - @test state(TI) == 2 - @test TI.output == [1, 2] - - increment!(TI) - @test 
!islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 2 - @test TI.output == [1, 2] - - compute!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 3 - @test TI.output == [1, 2, 3] - - # Final Step - iterate(TI, false) - @test islaststep((TI)) - @test state(TI) == 4 - @test length(TI.output) == 4 - @test TI.output == [1, 2, 3, 4] - - @test iterate(TI, false) === nothing - - TI = TestIteratorUtils.TestIterator(1, 5, []) - - cb = [] - - for _ in TI - @test length(cb) == length(TI.output) - 1 - @test cb == (TI.output)[1:(end - 1)] - push!(cb, state(TI)) - @test cb == TI.output - end - - @test islaststep((TI)) - @test length(TI.output) == 5 - @test length(cb) == 5 - @test cb == TI.output - - - TI = TestIteratorUtils.TestIterator(1, 5, []) - end - - @testset "Adapters" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - @testset "Generic" begin - - i = 0 - for rv in SA - i += 1 - @test rv isa Int - @test rv == i^2 - @test state(SA) == i - end - - @test islaststep((SA)) - - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - SA_c = collect(SA) - - @test SA_c isa Vector - @test length(SA_c) == 5 - @test SA_c == [1, 4, 9, 16, 25] - - end - - @testset "IncrementOnly" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - NI = IncrementOnly(TI) - - NI_c = [] - - for _ in IncrementOnly(TI) - push!(NI_c, state(TI)) - end - - @test length(NI_c) == 5 - @test isempty(TI.output) - end - - @testset "EachRegion" begin - prob = TestIteratorUtils.TestProblem([]) - prob_region = TestIteratorUtils.TestProblem([]) - - SI = SweepIterator(prob, 5) - SI_region = SweepIterator(prob_region, 5) - - callback = [] - callback_region = [] - - let i = 1 - for _ in SI - push!(callback, i) - i += 1 - end - end - - @test length(callback) == 5 - - let i = 1 - for _ in eachregion(SI_region) - push!(callback_region, i) - i += 1 - end - end - - @test 
length(callback_region) == 10 - - @test prob.data == prob_region.data - - @test prob.data[1:2:end] == fill(1, 5) - @test prob.data[2:2:end] == fill(2, 5) - - - let i = 1, prob = TestIteratorUtils.TestProblem([]) - SI = SweepIterator(prob, 1) - cb = [] - for _ in eachregion(SI) - push!(cb, i) - i += 1 - end - @test length(cb) == 2 - end - - end - end -end diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl new file mode 100644 index 0000000..215a8b8 --- /dev/null +++ b/test/test_sweeping.jl @@ -0,0 +1,65 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +struct TestProblem <: AIE.Problem +end + +struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm + region::R + kwargs::Kwargs +end +TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) + +function AI.solve!(problem::TestProblem, algorithm::TestRegion, state::AIE.State; kwargs...) + new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) + state.iterate = [state.iterate; [new_iterate]] + return state +end + +@testset "Sweeping" begin + @testset "TestRegion" begin + algorithm = TestRegion("region"; foo = 1, bar = 2) + @test algorithm isa AIE.NonIterativeAlgorithm + @test algorithm isa AIE.Algorithm + @test algorithm isa AI.Algorithm + @test algorithm.region == "region" + @test algorithm.kwargs == (; foo = 1, bar = 2) + + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [(; region = "region", foo = 1, bar = 2)] + end + @testset "Sweep" begin + algorithm = AIE.nested_algorithm(3) do i + return TestRegion("region$i"; foo = i, bar = 2i) + end + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [ + (; region = "region1", foo = 1, bar = 2), + (; region = "region2", foo = 2, bar = 4), + (; region = "region3", foo = 3, bar = 6), + ] + end + @testset "Sweeping" 
begin + algorithm = AIE.nested_algorithm(2) do i + AIE.nested_algorithm(3) do j + return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j)) + end + end + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [ + (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), + (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), + (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), + (; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)), + (; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)), + (; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)), + ] + end +end From 032447a00de29e7a8fba27f76bb0ae6a8c193e26 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 23 Dec 2025 18:15:22 -0500 Subject: [PATCH 22/86] Upgrade to NamedDimsArrays.jl v0.11 (#38) --- Project.toml | 6 +++--- test/Project.toml | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index e6919fc..7b86558 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -33,7 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" -AlgorithmsInterface = "0.1.0" +AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" @@ -43,7 +43,7 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.8, 0.9" +NamedDimsArrays = "0.8, 0.9, 0.10, 0.11" NamedGraphs = "0.6.9, 0.7, 0.8" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/test/Project.toml b/test/Project.toml index e71e7a4..0e74eef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,13 +22,14 @@ ITensorNetworksNext = 
{path = ".."} [compat] AbstractTrees = "0.4.5" +AlgorithmsInterface = "0.1" Aqua = "0.8.14" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.3" +ITensorBase = "0.3, 0.4" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.8, 0.9" +NamedDimsArrays = "0.8, 0.9, 0.10, 0.11" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" SafeTestsets = "0.1" From b256d79f250cc5f06b83885381879b8f0fa41f10 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:34:38 -0500 Subject: [PATCH 23/86] [LazyNamedDimsArrays] New `symnameddims` method that pulls out indices from an array. --- src/LazyNamedDimsArrays/symbolicnameddimsarray.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index a215319..628baf3 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -5,6 +5,9 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(name, dims) return lazy(nameddims(SymbolicArray(name, dename.(dims)), dims)) end +function symnameddims(name, ndarray::AbstractNamedDimsArray) + return symnameddims(name, Tuple(inds(ndarray))) +end symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) From b2da9d80a35da7ea5a2b51fb791a1115342cd8ca Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:35:32 -0500 Subject: [PATCH 24/86] The function `region_scalar` should now return a scalar, rather than a order-0 array --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0cae3fa..3545b53 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ 
b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -62,7 +62,7 @@ function setfactor!(bpc::AbstractDataGraph, vertex, factor) end function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) - return message(bp_cache, edge) * message(bp_cache, reverse(edge)) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] end function region_scalar(bp_cache::AbstractGraph, vertex) @@ -70,7 +70,7 @@ function region_scalar(bp_cache::AbstractGraph, vertex) messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) - return reduce(*, messages) * reduce(*, state) + return (reduce(*, messages) * reduce(*, state))[] end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) From 8506e26a3d8814e3e51487a48469f27c9cd64a8f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:37:43 -0500 Subject: [PATCH 25/86] Fix double counting in `edge_scalars` function This was caused by the change to the `cache` being backed by a directed graph. 
--- src/beliefpropagation/abstractbeliefpropagationcache.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 3545b53..8e7185e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -81,7 +81,7 @@ function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache)))) return map(e -> region_scalar(bp_cache, e), edges) end @@ -120,7 +120,9 @@ adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end function free_energy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) end From 938180af0e35b3e091aa39bfa405a0dd5842d523 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:37:59 -0500 Subject: [PATCH 26/86] Minor code formatting --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 8e7185e..0efc95d 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -130,7 +130,10 @@ function free_energy(bp_cache::AbstractBeliefPropagationCache) denominator_terms = complex.(denominator_terms) end - any(iszero, denominator_terms) && return -Inf + if any(iszero, 
denominator_terms) + return -Inf + end + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) From 44619673fedaf47c59bd2557222086807f12a2ec Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:39:43 -0500 Subject: [PATCH 27/86] Expressed belief propagation in terms of AlgorithmsInterface --- .../beliefpropagationcache.jl | 13 + .../beliefpropagationproblem.jl | 279 +++++++++++++----- src/sweeping/utils.jl | 8 +- test/test_beliefpropagation.jl | 10 +- 4 files changed, 222 insertions(+), 88 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index c9793e6..27a580d 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -23,6 +23,7 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, is_path_graph, undirected_graph using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: AbstractBeliefPropagationCache{V, MT} @@ -125,3 +126,15 @@ function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) data = map(e -> bpc[QuotientEdge(e)], inds) return BeliefPropagationCache(QuotientView(network(bpc)), data) end + +function default_message(bpc::BeliefPropagationCache, edge) + return default_message(message_type(bpc), network(bpc), edge) +end +function default_message(T::Type, network, edge) + array = ones(Tuple(linkinds(network, edge))) + return convert(T, array) +end +function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) + message = default_message(parenttype(T), network, edge) + return convert(T, lazy(message)) +end diff --git 
a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 24b024d..0d997ee 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,82 +1,200 @@ -using Graphs: SimpleGraph, vertices, edges, has_edge +using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices using NamedDimsArrays: AbstractNamedDimsArray -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, parenttype, lazy +using DataGraphs: edge_data +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE -abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 = 0.0 +end -mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} - const alg::Alg - const cache::Cache - diff::Union{Nothing, Float64} +@kwdef mutable struct StopWhenConvergedState <: AI.StoppingCriterionState + delta::Float64 = Inf end -BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) +function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged) + return StopWhenConvergedState() +end -function default_algorithm( - ::Type{<:Algorithm"bp"}, - bpc; - verbose = false, - tol = nothing, - edge_sequence = forest_cover_edge_sequence(bpc), - message_update_alg = default_algorithm(Algorithm"contract"), - maxiter = is_tree(bpc) ? 
1 : nothing, +function AI.initialize_state!( + ::AIE.Problem, + ::AIE.Algorithm, + ::StopWhenConverged, + st::StopWhenConvergedState, ) - return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) + st.delta = Inf + return st end -function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) - edges = prob.alg.edge_sequence +function AI.is_finished!( + ::AIE.Problem, + ::AIE.Algorithm, + state::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState, + ) - plan = map(edges) do e - return e => (; sweep_kwargs...) + # maxdiff = 0.0 initially, so skip this the first time. + if state.iteration > 0 + st.delta = state.iterate.maxdiff end - return plan + return st.delta < c.tol +end + +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end + +@kwdef mutable struct BeliefPropagationState{ + Iterate <: BeliefPropagationCache, + Diffs, + } <: AIE.NonIterativeAlgorithmState + iterate::Iterate + diffs::Diffs = similar(edge_data(iterate), Float64) + maxdiff::Float64 = 0.0 +end + +function AI.initialize_state( + problem::BeliefPropagationProblem, + algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) - prob = iter.problem +# This gets called at the start of every sweep. 
+function AI.initialize_state!( + problem::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, + state::AIE.State, + ) + state.iterate.maxdiff = 0.0 + return state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::AIE.NestedAlgorithm, + state::AIE.State, + substate::BeliefPropagationState + ) + + state.iterate = substate + + return state +end - edge, _ = current_region_plan(iter) - new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) - setmessage!(prob.cache, edge, new_message) +abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end - return iter +struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate + edge::E + kwargs::Kwargs end -function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), network(bpc), edge) +function SimpleMessageUpdate( + edge; + normalize = false, + contraction_alg = "eager", + compute_diff = false, + kwargs... + ) + return SimpleMessageUpdate(edge, (; normalize, contraction_alg, compute_diff, kwargs...)) end -function default_message(T::Type, network, edge) - array = ones(Tuple(linkinds(network, edge))) - return convert(T, array) + +function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) + if name in (:edge, :kwargs) + return getfield(alg, name) + else + return getproperty(getfield(alg, :kwargs), name) + end end -function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) - message = default_message(parenttype(T), network, edge) - return convert(T, lazy(message)) + +struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem + messages::Messages + factors::Factors end -updated_message(alg, bpc, edge) = not_implemented() -function updated_message(alg::Algorithm"contract", bpc, edge) +function AI.solve!( + problem::BeliefPropagationProblem, + algorithm::AbstractMessageUpdate, + state::BeliefPropagationState; + logging_context_prefix = default_logging_context_prefix(problem, 
algorithm), + ) + + logger = AI.algorithm_logger() + + cache = state.iterate + edge = algorithm.edge + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + new_message = updated_message(algorithm, cache) + + if algorithm.compute_diff + diff = message_diff(new_message, cache[edge]) + + if diff > state.maxdiff + state.maxdiff = diff + end + + state.diffs[edge] = diff + end + + setmessage!(cache, edge, new_message) + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostUpdate) + ) + + return state +end + +message_diff(m1, m2) = LinearAlgebra.norm(m1 - m2) + +function updated_message(algorithm, cache) + edge = algorithm.edge + vertex = src(edge) + messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) + + update_problem = MessageUpdateProblem(messages, factors(cache, vertex)) + + message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) - incoming_ms = incoming_messages( - bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] + return message_state.iterate +end + +function AI.solve!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdate, + state::AIE.NonIterativeAlgorithmState; + logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), + kwargs... ) - updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) + # TODO: logging... 
- if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) + state.iterate = contract_messages(algorithm.contraction_alg, problem.factors, problem.messages) + + if algorithm.normalize + # TODO: use `sum` not `norm` + message_norm = LinearAlgebra.norm(state.iterate) if !iszero(message_norm) - updated_message /= message_norm + state.iterate /= message_norm end end - return updated_message + + return state end contract_messages(alg, factors, messages) = not_implemented() @@ -85,54 +203,51 @@ function contract_messages( factors::Vector{<:AbstractArray}, messages::Vector{<:AbstractArray}, ) - return contract_network(alg, vcat(factors, messages)) + return contract_network(vcat(factors, messages); alg) end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") - ) - return Algorithm("contract"; normalize, contraction_alg) -end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) -end +beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) +function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end + problem = BeliefPropagationProblem(network(cache)) -function update(bpc::AbstractBeliefPropagationCache; kwargs...) - return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) -end + algorithm = select_algorithm(beliefpropagation, cache; kwargs...) -function update(alg, bpc) - compute_error = !isnothing(alg.tol) + # The nested algorithms will wrap and manipulate this object. + base_state = BeliefPropagationState(; iterate = cache) - diff = compute_error ? 
0.0 : nothing + state = AI.solve(problem, algorithm; iterate = base_state) - prob = BeliefPropagationProblem(alg, bpc, diff) + return state.iterate.iterate +end - iter = SweepIterator(prob, alg.maxiter; compute_error) +function select_algorithm( + ::typeof(beliefpropagation), + cache; + edges = forest_cover_edge_sequence(network(cache)), + maxiter = is_tree(network(cache)) ? 1 : nothing, + tol = 0.0, + kwargs... + ) - for _ in iter - if compute_error && prob.diff <= alg.tol - break - end + if isnothing(maxiter) + throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end - if alg.verbose && compute_error - if prob.diff <= alg.tol - println("BP converged to desired precision after $(iter.which_sweep) iterations.") - else - println( - "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", - ) - end + stopping_criterion = AI.StopAfterIteration(maxiter) + compute_diff = false + + if tol > 0.0 + stopping_criterion = stopping_criterion | StopWhenConverged(tol) + compute_diff = true end - return bpc + extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, len = maxiter) + + return AIE.nested_algorithm(maxiter; stopping_criterion) do repnum + return AIE.nested_algorithm(length(edges)) do edgenum + return SimpleMessageUpdate(edges[edgenum]; edge_kwargs[repnum]...) + end + end end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl index 39e09e4..9a39c9d 100644 --- a/src/sweeping/utils.jl +++ b/src/sweeping/utils.jl @@ -7,6 +7,12 @@ function extend_columns(nt::NamedTuple, len::Int) return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) 
end rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) +function rows(nt::NamedTuple; len = nothing) + if isnothing(len) + if isempty(nt) + throw(ArgumentError("Got empty named tuple; keyword `len` must be specified in this case.")) + end + len = rowlength(nt) + end return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index a39e1a6..8c7829b 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -20,7 +20,7 @@ using Test: @test, @testset @testset "BeliefPropagation" begin #Chain of tensors - dims = (4, 1) + dims = (2, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) @@ -30,10 +30,10 @@ using Test: @test, @testset end bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1.0e-14 + @test z_bp ≈ z_exact atol = 1.0e-14 #Tree of tensors dims = (4, 3) @@ -46,8 +46,8 @@ using Test: @test, @testset end bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1.0e-14 + @test z_bp ≈ z_exact atol = 1.0e-12 end From d68860ae59092f2382fccfee87d03abe9a097b58 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:40:23 -0500 Subject: [PATCH 28/86] Fixes to TensorNetwork construction from tensor list --- src/abstracttensornetwork.jl | 4 ++-- src/tensornetwork.jl | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/abstracttensornetwork.jl 
b/src/abstracttensornetwork.jl index b820867..08f86a1 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,7 +1,7 @@ using Adapt: Adapt, adapt, adapt_structure using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data + underlying_graph_type, vertex_data, set_vertex_data! using Dictionaries: Dictionary using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices @@ -111,7 +111,7 @@ function sitenames(tn::AbstractGraph, edge::AbstractEdge) end function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) - set!(vertex_data(tn), vertex, value) + set_vertex_data!(tn, value, vertex) return tn end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 16c80e3..b811e2b 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -35,8 +35,13 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. 
+function TensorNetwork(graph::AbstractGraph, tensors) + return TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) - return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) + tn = _TensorNetwork(graph, tensors) + fix_links!(tn) + return tn end function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} @@ -80,11 +85,6 @@ tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end -function TensorNetwork(graph::AbstractGraph, tensors) - tn = _TensorNetwork(graph, tensors) - fix_links!(tn) - return tn -end # Insert trivial links for missing edges, and also check # the vertices and edges are consistent between the graph and tensors. @@ -172,6 +172,7 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end +# When getting data according the quotient vertices, take a lazy contraction. function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) return mapreduce(lazy, *, data) From 2f5c783f4760d813777e392321c97028f05b3f99 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:41:18 -0500 Subject: [PATCH 29/86] Minor simplifications to `contract_network` interface. 
--- src/contract_network.jl | 91 ++++++++++++++++------------------- test/test_contract_network.jl | 12 ++--- 2 files changed, 48 insertions(+), 55 deletions(-) diff --git a/src/contract_network.jl b/src/contract_network.jl index e89fa00..4511595 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,69 +1,62 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, +using NamedDimsArrays: inds +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, substitute, symnameddims -# This is related to `MatrixAlgebraKit.select_algorithm`. -# TODO: Define this in BackendSelection.jl. -backend_value(::Algorithm{alg}) where {alg} = alg -using BackendSelection: parameters -function merge_parameters(alg::Algorithm; kwargs...) - return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) + return contract_network(alg, tn) end -to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) -to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) -# `contract_network` -function contract_network(alg::Algorithm, tn) - return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) -end -function default_kwargs(::typeof(contract_network), tn) - return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) -end -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) 
- return contract_network(to_algorithm(alg; kwargs...), tn) +contract_network(alg::String, tn) = contract_network(Algorithm(alg), tn) + +default_kwargs(::typeof(contract_network), tn) = (; alg = "eager") + +function contract_network( + alg, + tensors, + ) + + order = contraction_expression(tensors; order = alg) + symbols_to_tensors = Dict( + symnameddims(i, tensors[i]) => lazy(tensors[i]) for i in keys(tensors) + ) + + return materialize(substitute(order, symbols_to_tensors)) end -# `contract_network(::Algorithm"exact", ...)` -function get_order(alg::Algorithm"exact", tn) - # Allow specifying either `order` or `order_alg`. - order = get(alg, :order, nothing) - order = if !isnothing(order) - order - else - default_order_alg = default_kwargs(contraction_order, tn).alg - order_alg = get(alg, :order_alg, default_order_alg) - # TODO: Capture other keyword arguments and pass them to `contraction_order`. - contraction_order(tn; alg = order_alg) - end +# `contraction_order` +function contraction_order end +default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") + +function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) + order = contraction_order(order, tensors) + # Contraction order may or may not have indices attached, canonicalize the format # by attaching indices. 
- subs = Dict(symnameddims(i) => symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)) + subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) + return substitute(order, subs) end -function contract_network(alg::Algorithm"exact", tn) - order = get_order(alg, tn) - syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in keys(tn)) - tn_expression = substitute(order, syms_to_ts) - return materialize(tn_expression) -end -# `contraction_order` -function contraction_order end -default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) -function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) - return contraction_order(to_algorithm(alg; kwargs...), tn) +contraction_order(order, tensors) = order +function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) + return contraction_order(Algorithm(order), tensors) end # Convert the tensor network to a flat symbolic multiplication expression. -function contraction_order(alg::Algorithm"flat", tn) +function contraction_order(::Algorithm"flat", tensors) # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. 
- syms = vec([symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)]) + syms = vec([symnameddims(i, Tuple(inds(tensors[i]))) for i in keys(tensors)]) return lazy(Mul(syms)) end -function contraction_order(alg::Algorithm"left_associative", tn) - return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn)) +function contraction_order(::Algorithm"left_associative", tensors) + return prod(i -> symnameddims(i, Tuple(inds(tensors[i]))), keys(tensors)) end -function contraction_order(alg::Algorithm, tn) - s = contraction_order(Algorithm"flat"(), tn) - return optimize_evaluation_order(s; alg) + +function contraction_order( + order_algorithm::Algorithm, + tensors, + ) + order = contraction_order(tensors; order = "flat") + return optimize_evaluation_order(order; alg = order_algorithm) end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index c9abfdd..b5ff72e 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -14,9 +14,9 @@ using Test: @test, @testset C = ITensor([5.0, 1.0], j) D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k) - ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal") + ABCD_1 = contract_network([A, B, C, D]; alg = "left_associative") + ABCD_2 = contract_network([A, B, C, D]; alg = "eager") + ABCD_3 = contract_network([A, B, C, D]; alg = "optimal") @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +31,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; order_alg = "left_associative")[] - z2 = contract_network(tn; order_alg = "eager")[] - z3 = contract_network(tn; order_alg = "optimal")[] + z1 = contract_network(tn; alg = "left_associative")[] + z2 = contract_network(tn; alg = "eager")[] + z3 = contract_network(tn; alg = "optimal")[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * 
eps(Float64) From 4eec9b65e4917c3feb11926ccf61207773833e2b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 10 Feb 2026 11:50:00 -0500 Subject: [PATCH 30/86] Upgrade DataGraphs and NamedGraphs dependencies --- src/abstracttensornetwork.jl | 20 +----- .../abstractbeliefpropagationcache.jl | 19 +++--- .../beliefpropagationcache.jl | 63 ++++++++++--------- src/tensornetwork.jl | 40 +++++++++--- test/Project.toml | 4 +- 5 files changed, 79 insertions(+), 67 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 08f86a1..671ba3a 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -16,7 +16,8 @@ using NamedGraphs.GraphsExtensions: incident_edges, rem_edges!, rename_vertices, - vertextype + vertextype, + similar_graph using SplitApplyCombine: flatten using NamedGraphs.SimilarType: similar_type @@ -25,7 +26,7 @@ abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} # Need to be careful about removing edges from tensor networks in case there is a bond Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() +DataGraphs.edge_data_type(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -235,18 +236,3 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) - -function Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V} - return tensornetwork_induced_subgraph(graph, subvertices) -end - -function tensornetwork_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) - subgraph = similar_type(graph)(underlying_subgraph) - for v in vertices(subgraph) - if isassigned(graph, v) - 
set!(vertex_data(subgraph), v, graph[v]) - end - end - return subgraph, vlist -end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0efc95d..b77fb4e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,15 +1,12 @@ using Graphs: AbstractGraph, AbstractEdge -using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_type using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] -function message(bp_cache::AbstractGraph, edge::AbstractEdge) - ms = messages(bp_cache) - return get!(ms, edge, default_message(bp_cache, edge)) -end +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() function deletemessage!(bp_cache::AbstractDataGraph, edge) @@ -52,7 +49,7 @@ factors(bpc::AbstractGraph) = vertex_data(bpc) factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) -factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] +factor(bpc::AbstractGraph, vertex) = bpc[vertex] setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() function setfactor!(bpc::AbstractDataGraph, vertex, factor) @@ -75,7 +72,7 @@ end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) -message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) function 
vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) @@ -117,7 +114,13 @@ end adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) -abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end +abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end + +factor_type(bpc::AbstractBeliefPropagationCache) = typeof(bpc) +factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD + +message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) +message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED function free_energy(bp_cache::AbstractBeliefPropagationCache) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 27a580d..10ab586 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -2,20 +2,19 @@ using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, - has_edge_data, get_vertex_data, get_edge_data, set_vertex_data!, set_edge_data!, - unset_vertex_data!, - unset_edge_data!, - vertex_data_eltype, - edge_data_eltype, + vertex_data_type, + edge_data_type, underlying_graph, - underlying_graph_type + underlying_graph_type, + is_vertex_assigned, + is_edge_assigned using Dictionaries: Dictionary, set!, delete! 
using Graphs: AbstractGraph, is_tree, connected_components, is_directed -using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices +using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices, Vertices using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, @@ -25,22 +24,23 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: - AbstractBeliefPropagationCache{V, MT} +struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGraph{V}, E} <: + AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. network::N - messages::Dictionary{ET, MT} + messages::Dictionary{E, ED} function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) V = vertextype(network) + VD = vertex_data_type(network) N = typeof(network) ET = keytype(messages) - MT = eltype(messages) + ED = eltype(messages) # Construct a directed graph version of the underlying graph of the tensor network. 
digraph = directed_graph(underlying_graph(network)) - bpc = new{V, typeof(digraph), N, ET, MT}(digraph, network, messages) + bpc = new{V, VD, ED, typeof(digraph), N, ET}(digraph, network, messages) for edge in edges(bpc) get!(() -> default_message(bpc, edge), messages, edge) @@ -53,8 +53,8 @@ network(bp_cache) = getfield(bp_cache, :network) DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) -DataGraphs.has_vertex_data(bpc::BeliefPropagationCache, vertex) = has_vertex_data(network(bpc), vertex) -DataGraphs.has_edge_data(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) +DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = is_vertex_assigned(network(bpc), vertex) +DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] @@ -62,20 +62,8 @@ DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc. 
DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) -DataGraphs.unset_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = unset_vertex_data!(network(bpc), val, vertex) -DataGraphs.unset_edge_data!(bpc::BeliefPropagationCache, val, edge) = unset!(bpc.messages, edge, val) - -function DataGraphs.vertex_data_eltype(T::Type{<:BeliefPropagationCache}) - return vertex_data_eltype(fieldtype(T, :network)) -end -function DataGraphs.edge_data_eltype(T::Type{<:BeliefPropagationCache}) - return eltype(fieldtype(T, :messages)) -end - -message_type(T::Type{<:BeliefPropagationCache}) = edge_data_eltype(T) - function BeliefPropagationCache(network::AbstractGraph) - MT = vertex_data_eltype(typeof(network)) + MT = vertex_data_type(typeof(network)) return BeliefPropagationCache(MT, network) end function BeliefPropagationCache(MT::Type, network::AbstractGraph) @@ -95,7 +83,7 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo forests = forest_cover(g) rv = edgetype(g)[] for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] for tree in trees tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) 
@@ -106,16 +94,19 @@ end function bpcache_induced_subgraph(graph, subvertices) underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) - subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + + edge_data = Dictionary{edgetype(underlying_subgraph), edge_data_type(typeof(graph))}() + + subgraph = BeliefPropagationCache(underlying_subgraph, edge_data) for e in edges(subgraph) if isassigned(graph, e) - set!(edge_data(subgraph), e, graph[e]) + subgraph[e] = graph[e] end end return subgraph, vlist end -function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::Vector{V}) where {V} +function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) return bpcache_induced_subgraph(graph, subvertices) end @@ -138,3 +129,13 @@ function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) message = default_message(parenttype(T), network, edge) return convert(T, lazy(message)) end + +NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex +# When getting data according the quotient vertices, take a lazy contraction. 
+function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) + return mapreduce(lazy, *, data) +end +function DataGraphs.is_graph_index_assigned(tn::BeliefPropagationCache, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index b811e2b..0d30970 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -18,7 +18,7 @@ using NamedGraphs.PartitionedGraphs: QuotientVertexVertices, quotientvertices using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data +using DataGraphs: vertex_data_type, vertex_data, edge_data, get_vertices_data using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -52,13 +52,12 @@ end DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph -DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v) -DataGraphs.has_edge_data(tn::TensorNetwork, e) = false +DataGraphs.is_vertex_assigned(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.is_edge_assigned(tn::TensorNetwork, e) = false DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) -DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) @@ -135,11 +134,30 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.similar(type::Type{<:TensorNetwork}) +function GraphsExtensions.similar_graph(type::Type{<:TensorNetwork}) DT = fieldtype(type, :tensors) empty_dict = DT() return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) end +function GraphsExtensions.similar_graph(tn::TensorNetwork, 
underlying_graph::AbstractGraph) + DT = fieldtype(typeof(tn), :tensors) + empty_dict = DT() + return _TensorNetwork(underlying_graph, empty_dict) +end + +function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + + subgraph = TensorNetwork(underlying_subgraph) do vertex + return graph[vertex] + end + + return subgraph, vlist +end ## PartitionedGraphs function PartitionedGraphs.quotient_graph(tn::TensorNetwork) @@ -154,7 +172,7 @@ end # DataGraphsPartitionedGraphsExt interface. function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) UG = quotient_graph_type(underlying_graph_type(type)) - VD = Vector{vertex_data_eltype(type)} + VD = Vector{vertex_data_type(type)} V = vertextype(UG) return TensorNetwork{V, VD, UG, Dictionary{V, VD}} end @@ -172,14 +190,18 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end +NamedGraphs.to_graph_index(::TensorNetwork, vertex::QuotientVertex) = vertex # When getting data according the quotient vertices, take a lazy contraction. 
-function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) - data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) +function DataGraphs.get_index_data(tn::TensorNetwork, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) return mapreduce(lazy, *, data) end +function DataGraphs.is_graph_index_assigned(tn::TensorNetwork, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end function PartitionedGraphs.quotientview(tn::TensorNetwork) qview = QuotientView(underlying_graph(tn)) - tensors = vertex_data(QuotientView(tn)) + tensors = map(qv -> vertex_data(tn)[Indices(qv)], Indices(quotientvertices(tn))) return TensorNetwork(qview, tensors) end diff --git a/test/Project.toml b/test/Project.toml index 564db3f..975c2c1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,8 +29,8 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.14" -NamedGraphs = "0.6.8, 0.7, 0.8" +NamedDimsArrays = "0.13" +NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" Suppressor = "0.2.8" From 202724ca021139bf7fa5d5cd561406dd497cacd4 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 10 Feb 2026 11:57:32 -0500 Subject: [PATCH 31/86] [AlgorithmsInterfaceExtensions] Allowing mapping over a generic iterable when constructing nested algorithms --- .../AlgorithmsInterfaceExtensions.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index a8c814e..3c887b7 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -152,8 +152,8 @@ end abstract type NestedAlgorithm <: Algorithm end -function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) 
- return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +function nested_algorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(f, iterable; kwargs...) end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) @@ -211,6 +211,9 @@ function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) end +function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) +end #============================ FlattenedAlgorithm ==========================================# # Flatten a nested algorithm. From 69542e32ba7d5ad1a4b616a40822dffcd1de4c9c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 11 Feb 2026 11:44:18 -0500 Subject: [PATCH 32/86] Upgrade serial BP to use own `<:Algorithm` structs. --- .../beliefpropagationproblem.jl | 136 +++++++++++------- 1 file changed, 87 insertions(+), 49 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 0d997ee..75023b3 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,8 +1,9 @@ using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge using NamedGraphs: AbstractNamedGraph, position_graph -using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices, subgraph, boundary_edges using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.PartitionedGraphs: quotientvertices using DataGraphs: edge_data import AlgorithmsInterface as AI @@ -41,55 +42,35 @@ function AI.is_finished!( # maxdiff = 0.0 initially, so skip this the first time. 
if state.iteration > 0 st.delta = state.iterate.maxdiff + @info "$(state.iteration): $(st.delta)" end return st.delta < c.tol end -struct BeliefPropagationProblem{Network} <: AIE.Problem - network::Network -end +# struct BeliefPropagationProblem{Network} <: AIE.Problem +# network::Network +# end + +struct BeliefPropagationProblem <: AIE.Problem end -@kwdef mutable struct BeliefPropagationState{ - Iterate <: BeliefPropagationCache, - Diffs, - } <: AIE.NonIterativeAlgorithmState +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState iterate::Iterate diffs::Diffs = similar(edge_data(iterate), Float64) maxdiff::Float64 = 0.0 end -function AI.initialize_state( - problem::BeliefPropagationProblem, - algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... - ) - - diffs = iterate.diffs - maxdiff = iterate.maxdiff - - return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) -end - -# This gets called at the start of every sweep. -function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::AIE.NestedAlgorithm, - state::AIE.State, - ) - state.iterate.maxdiff = 0.0 - return state +@kwdef struct BeliefPropagation{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, - state::AIE.State, - substate::BeliefPropagationState - ) - - state.iterate = substate - - return state +function BeliefPropagation(f::Function, niterations::Int; kwargs...) + return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) 
end abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end @@ -101,7 +82,7 @@ end function SimpleMessageUpdate( edge; - normalize = false, + normalize = true, contraction_alg = "eager", compute_diff = false, kwargs... @@ -117,6 +98,53 @@ function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) end end +struct BeliefPropagationSweep{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::AI.StopAfterIteration + function BeliefPropagationSweep(; algorithms) + stopping_criterion = AI.StopAfterIteration(length(algorithms)) + return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) + end +end + +BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) + +function AI.initialize_state( + problem::BeliefPropagationProblem, + update_algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) +end + +# This gets called at the start of every sweep. 
+function AI.initialize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagationSweep, + iteration_state::AIE.State, + ) + iteration_state.iterate.maxdiff = 0.0 + return iteration_state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + sweep_algorithm::BeliefPropagationSweep, + sweep_state::AIE.DefaultState, + noniterative_substate::BeliefPropagationState, + ) + + sweep_state.iterate = noniterative_substate + + return sweep_state +end + struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem messages::Messages factors::Factors @@ -124,7 +152,7 @@ end function AI.solve!( problem::BeliefPropagationProblem, - algorithm::AbstractMessageUpdate, + algorithm::SimpleMessageUpdate, state::BeliefPropagationState; logging_context_prefix = default_logging_context_prefix(problem, algorithm), ) @@ -177,7 +205,7 @@ end function AI.solve!( problem::MessageUpdateProblem, algorithm::SimpleMessageUpdate, - state::AIE.NonIterativeAlgorithmState; + state::AIE.DefaultNonIterativeAlgorithmState; logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), kwargs... ) @@ -209,24 +237,29 @@ end beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) - problem = BeliefPropagationProblem(network(cache)) + # problem = BeliefPropagationProblem(network(cache)) + problem = BeliefPropagationProblem() algorithm = select_algorithm(beliefpropagation, cache; kwargs...) # The nested algorithms will wrap and manipulate this object. 
+ base_state = BeliefPropagationState(; iterate = cache) - state = AI.solve(problem, algorithm; iterate = base_state) + state = AI.initialize_state(problem, algorithm; iterate = base_state) + + state = AI.solve!(problem, algorithm, state) return state.iterate.iterate end + function select_algorithm( ::typeof(beliefpropagation), - cache; + cache::AbstractBeliefPropagationCache; edges = forest_cover_edge_sequence(network(cache)), maxiter = is_tree(network(cache)) ? 1 : nothing, - tol = 0.0, + tol = -Inf, kwargs... ) @@ -237,7 +270,7 @@ function select_algorithm( stopping_criterion = AI.StopAfterIteration(maxiter) compute_diff = false - if tol > 0.0 + if tol > -Inf stopping_criterion = stopping_criterion | StopWhenConverged(tol) compute_diff = true end @@ -245,9 +278,14 @@ function select_algorithm( extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, len = maxiter) - return AIE.nested_algorithm(maxiter; stopping_criterion) do repnum - return AIE.nested_algorithm(length(edges)) do edgenum - return SimpleMessageUpdate(edges[edgenum]; edge_kwargs[repnum]...) - end + return BeliefPropagation(maxiter; stopping_criterion) do repnum + return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) + end +end + +# A single sweep across the given edges. +function beliefpropagation_sweep(cache::BeliefPropagationCache; edges, kwargs...) + return BeliefPropagationSweep(edges) do edge + return SimpleMessageUpdate(edge; kwargs...) 
end end From 992506900fd225d106a57e03346fd62e6f74bc80 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:19:04 -0500 Subject: [PATCH 33/86] Simplify BP cache to only store factors --- src/abstracttensornetwork.jl | 26 ++-- .../beliefpropagationcache.jl | 131 +++++++++--------- .../beliefpropagationproblem.jl | 81 ++++++----- 3 files changed, 115 insertions(+), 123 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 671ba3a..c4b6fcb 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,25 +1,17 @@ -using Adapt: Adapt, adapt, adapt_structure +using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data, set_vertex_data! +using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, set_vertex_data!, + underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, - bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices -using LinearAlgebra: LinearAlgebra, factorize +using Graphs: AbstractEdge, AbstractGraph, Graphs, add_edge!, add_vertex!, + dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree +using NamedGraphs: NamedGraph, NamedGraphs, not_implemented using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs.GraphsExtensions: - ⊔, - directed_graph, - incident_edges, - rem_edges!, - rename_vertices, - vertextype, - similar_graph -using SplitApplyCombine: flatten -using NamedGraphs.SimilarType: similar_type +using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, + similar_graph, vertextype 
abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 10ab586..2c253e6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,46 +1,29 @@ -using DataGraphs: - DataGraphs, - AbstractDataGraph, - DataGraph, - get_vertex_data, - get_edge_data, - set_vertex_data!, - set_edge_data!, - vertex_data_type, - edge_data_type, - underlying_graph, - underlying_graph_type, - is_vertex_assigned, - is_edge_assigned -using Dictionaries: Dictionary, set!, delete! -using Graphs: AbstractGraph, is_tree, connected_components, is_directed -using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices, Vertices -using NamedGraphs.GraphsExtensions: default_root_vertex, - forest_cover, - post_order_dfs_edges, - vertextype, - is_path_graph, - undirected_graph -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges +using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, + vertex_data_type +using Dictionaries: Dictionary, delete!, set!, getindices +using Graphs: AbstractGraph, connected_components, is_tree, is_directed using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph -struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGraph{V}, E} <: - AbstractBeliefPropagationCache{V, VD, ED} +using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices + +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} 
underlying_graph::G # we only use this for the edges. - network::N + factors::Dictionary{V, VD} messages::Dictionary{E, ED} - function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) + function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary, messages::Dictionary) + # Ensure the graph is directed, if not make it directed. + digraph = is_directed(graph) ? graph : directed_graph(graph) - V = vertextype(network) - VD = vertex_data_type(network) - N = typeof(network) - ET = keytype(messages) - ED = eltype(messages) + V = keytype(factors) + VD = eltype(factors) - # Construct a directed graph version of the underlying graph of the tensor network. - digraph = directed_graph(underlying_graph(network)) + E = keytype(messages) + ED = eltype(messages) - bpc = new{V, VD, ED, typeof(digraph), N, ET}(digraph, network, messages) + bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) for edge in edges(bpc) get!(() -> default_message(bpc, edge), messages, edge) @@ -49,30 +32,39 @@ struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGra end end -network(bp_cache) = getfield(bp_cache, :network) - -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph -DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = is_vertex_assigned(network(bpc), vertex) +DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = haskey(bpc.factors, vertex) DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) -DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] -DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, 
vertex) = set_vertex_data!(network(bpc), val, vertex) +DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set!(bpc.factors, vertex, val) DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) +# These two methods assume `network` behaves llike a tensor network +# (could be e.g. a QuotientView) otherwise how would one know what the factors should be. function BeliefPropagationCache(network::AbstractGraph) - MT = vertex_data_type(typeof(network)) - return BeliefPropagationCache(MT, network) + graph = underlying_graph(network) + return BeliefPropagationCache(graph, copy(vertex_data(network))) end function BeliefPropagationCache(MT::Type, network::AbstractGraph) - dict = Dictionary{edgetype(network), MT}() - return BeliefPropagationCache(network, dict) + graph = underlying_graph(network) + return BeliefPropagationCache(MT, graph, copy(vertex_data(network))) +end + +function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) + MT = vertex_data_type(typeof(graph)) + return BeliefPropagationCache(MT, graph, factors) +end +function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) + messages = Dictionary{edgetype(graph), MT}() + return BeliefPropagationCache(graph, factors, messages) end function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) + return BeliefPropagationCache(copy(bp_cache.underlying_graph), copy(bp_cache.factors), copy(bp_cache.messages)) end # TODO: This needs to go in GraphsExtensions @@ -92,41 +84,50 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo return rv end -function bpcache_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) +function induced_subgraph_bpcache(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), 
subvertices) - edge_data = Dictionary{edgetype(underlying_subgraph), edge_data_type(typeof(graph))}() + assigned = v -> isassigned(graph, v) + + assigned_subvertices = Iterators.filter(assigned, subvertices) + assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) + + factors = getindices(vertex_data(graph), Indices(assigned_subvertices)) + messages = getindices(edge_data(graph), Indices(assigned_subedges)) + + subgraph = BeliefPropagationCache(underlying_subgraph, factors, messages) - subgraph = BeliefPropagationCache(underlying_subgraph, edge_data) - for e in edges(subgraph) - if isassigned(graph, e) - subgraph[e] = graph[e] - end - end return subgraph, vlist end function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) - return bpcache_induced_subgraph(graph, subvertices) + return induced_subgraph_bpcache(graph, subvertices) end ## PartitionedGraphs +# Take a QuotientView of the underlying graph. function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - inds = Indices(parent_graph_indices(QuotientEdges(underlying_graph(bpc)))) - data = map(e -> bpc[QuotientEdge(e)], inds) - return BeliefPropagationCache(QuotientView(network(bpc)), data) + + graph = underlying_graph(bpc) + + quotient_view = QuotientView(graph) + + factors = map(v -> bpc[QuotientVertex(v)], Indices(vertices(quotient_view))) + messages = map(e -> bpc[QuotientEdge(e)], Indices(edges(quotient_view))) + + return BeliefPropagationCache(quotient_view, factors, messages) end function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), network(bpc), edge) + return default_message(message_type(bpc), bpc[src(edge)], bpc[dst(edge)]) end -function default_message(T::Type, network, edge) - array = ones(Tuple(linkinds(network, edge))) +function default_message(T::Type, src, dst) + array = ones(Tuple(inds(src) ∩ inds(dst))) return convert(T, array) end -function 
default_message(T::Type{<:LazyNamedDimsArray}, network, edge) - message = default_message(parenttype(T), network, edge) +function default_message(T::Type{<:LazyNamedDimsArray}, src, dst) + message = default_message(parenttype(T), src, dst) return convert(T, lazy(message)) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 75023b3..89c28df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,10 +1,9 @@ -using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge -using NamedGraphs: AbstractNamedGraph, position_graph -using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices, subgraph, boundary_edges -using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices +using Graphs: AbstractEdge, edges, has_edge, vertices +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.PartitionedGraphs: quotientvertices using DataGraphs: edge_data +using LinearAlgebra: norm, normalize import AlgorithmsInterface as AI import .AlgorithmsInterfaceExtensions as AIE @@ -42,17 +41,14 @@ function AI.is_finished!( # maxdiff = 0.0 initially, so skip this the first time. 
if state.iteration > 0 st.delta = state.iterate.maxdiff - @info "$(state.iteration): $(st.delta)" end return st.delta < c.tol end -# struct BeliefPropagationProblem{Network} <: AIE.Problem -# network::Network -# end - -struct BeliefPropagationProblem <: AIE.Problem end +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end @kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState iterate::Iterate @@ -113,8 +109,7 @@ end BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) function AI.initialize_state( - problem::BeliefPropagationProblem, - update_algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... ) diffs = iterate.diffs @@ -135,7 +130,7 @@ end function AIE.set_substate!( ::BeliefPropagationProblem, - sweep_algorithm::BeliefPropagationSweep, + ::BeliefPropagationSweep, sweep_state::AIE.DefaultState, noniterative_substate::BeliefPropagationState, ) @@ -145,16 +140,16 @@ function AIE.set_substate!( return sweep_state end -struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem +struct MessageUpdateProblem{Factor, Messages} <: AIE.Problem + factor::Factor messages::Messages - factors::Factors end function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, state::BeliefPropagationState; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), ) logger = AI.algorithm_logger() @@ -168,8 +163,8 @@ function AI.solve!( new_message = updated_message(algorithm, cache) - if algorithm.compute_diff - diff = message_diff(new_message, cache[edge]) + if !isnothing(algorithm.message_diff_function) + diff = algorithm.message_diff_function(new_message, cache[edge]) if diff > state.maxdiff state.maxdiff = diff @@ -187,7 +182,7 @@ function AI.solve!( return state end 
-message_diff(m1, m2) = LinearAlgebra.norm(m1 - m2) +default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) function updated_message(algorithm, cache) edge = algorithm.edge @@ -195,7 +190,7 @@ function updated_message(algorithm, cache) vertex = src(edge) messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) - update_problem = MessageUpdateProblem(messages, factors(cache, vertex)) + update_problem = MessageUpdateProblem(cache[vertex], messages) message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) @@ -206,13 +201,21 @@ function AI.solve!( problem::MessageUpdateProblem, algorithm::SimpleMessageUpdate, state::AIE.DefaultNonIterativeAlgorithmState; - logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), kwargs... ) - # TODO: logging... + logger = AI.algorithm_logger() + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + state.iterate = contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) - state.iterate = contract_messages(algorithm.contraction_alg, problem.factors, problem.messages) + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreNormalization) + ) if algorithm.normalize # TODO: use `sum` not `norm` @@ -222,28 +225,26 @@ function AI.solve!( end end + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostNormalization) + ) + return state end -contract_messages(alg, factors, messages) = not_implemented() -function contract_messages( - alg, - factors::Vector{<:AbstractArray}, - messages::Vector{<:AbstractArray}, - ) +function contract_messages(alg, factor::AbstractArray, messages::Vector{<:AbstractArray}) + factors = typeof(factor)[factor] return contract_network(vcat(factors, messages); alg) end 
-beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) -function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) +beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +function beliefpropagation(cache::AbstractBeliefPropagationCache, network = nothing; kwargs...) - # problem = BeliefPropagationProblem(network(cache)) - problem = BeliefPropagationProblem() + problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) # The nested algorithms will wrap and manipulate this object. - base_state = BeliefPropagationState(; iterate = cache) state = AI.initialize_state(problem, algorithm; iterate = base_state) @@ -253,13 +254,13 @@ function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) return state.iterate.iterate end - function select_algorithm( ::typeof(beliefpropagation), cache::AbstractBeliefPropagationCache; - edges = forest_cover_edge_sequence(network(cache)), - maxiter = is_tree(network(cache)) ? 1 : nothing, + edges = forest_cover_edge_sequence(cache), + maxiter = is_tree(cache) ? 1 : nothing, tol = -Inf, + message_diff_function = tol > -Inf ? (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) : nothing, kwargs... ) @@ -268,14 +269,12 @@ function select_algorithm( end stopping_criterion = AI.StopAfterIteration(maxiter) - compute_diff = false if tol > -Inf stopping_criterion = stopping_criterion | StopWhenConverged(tol) - compute_diff = true end - extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) + extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, len = maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum @@ -284,7 +283,7 @@ function select_algorithm( end # A single sweep across the given edges. 
-function beliefpropagation_sweep(cache::BeliefPropagationCache; edges, kwargs...) +function beliefpropagation_sweep(::BeliefPropagationCache; edges, kwargs...) return BeliefPropagationSweep(edges) do edge return SimpleMessageUpdate(edge; kwargs...) end From 292f2fa10be8626746f87148c95ea0fb0ba17ae8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:28:23 -0500 Subject: [PATCH 34/86] Upgrade to DataGraphs v0.3.1 and NamedGraphs v0.10 --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index efd1d3c..c7133ff 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.2.7" +DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" FunctionImplementations = "0.4" @@ -47,7 +47,7 @@ Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.14.2" -NamedGraphs = "0.6.9, 0.7, 0.8" +NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TensorOperations = "5.3.1" From 9d937aa366d7afb54ab3e918a7039606de148112 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:38:37 -0500 Subject: [PATCH 35/86] Fix compat --- Project.toml | 4 ++-- test/Project.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c7133ff..1da8abe 100644 --- a/Project.toml +++ b/Project.toml @@ -42,11 +42,11 @@ Combinatorics = "1" DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" -FunctionImplementations = "0.4" +FunctionImplementations = "0.4.1" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.14.2" +NamedDimsArrays = "0.14.3" NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/test/Project.toml b/test/Project.toml index 975c2c1..cf048b7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,7 @@ 
Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.13" +NamedDimsArrays = "0.14" NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" From 5432fe28bb172ff61bb8a191b5de4604da06ef53 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 18:08:12 -0500 Subject: [PATCH 36/86] Fix broken merge Fix broken merge --- .../beliefpropagationproblem.jl | 4 +- src/contract_network.jl | 54 +++++++++---------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 89c28df..c127655 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -79,7 +79,7 @@ end function SimpleMessageUpdate( edge; normalize = true, - contraction_alg = "eager", + contraction_alg = "exact", compute_diff = false, kwargs... ) @@ -275,7 +275,7 @@ function select_algorithm( end extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, len = maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) diff --git a/src/contract_network.jl b/src/contract_network.jl index 4fda3a7..a8c3fc7 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,11 +1,27 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using NamedDimsArrays: inds -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, +using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, substitute, symnameddims -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) - return contract_network(alg, tn) +# This is related to `MatrixAlgebraKit.select_algorithm`. 
+# TODO: Define this in BackendSelection.jl. +backend_value(::Algorithm{alg}) where {alg} = alg +using BackendSelection: parameters +function merge_parameters(alg::Algorithm; kwargs...) + return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) +end +to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) +to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) + +# `contract_network` +function contract_network(alg::Algorithm, tn) + return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) +end +function default_kwargs(::typeof(contract_network), tn) + return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) +end +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) + return contract_network(to_algorithm(alg; kwargs...), tn) end # `contract_network(::Algorithm"exact", ...)` @@ -34,24 +50,12 @@ end # `contraction_order` function contraction_order end -default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") - -function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) - order = contraction_order(order, tensors) - - # Contraction order may or may not have indices attached, canonicalize the format - # by attaching indices. - subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) - - return substitute(order, subs) -end - -contraction_order(order, tensors) = order -function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) - return contraction_order(Algorithm(order), tensors) +default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) +function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) + return contraction_order(to_algorithm(alg; kwargs...), tn) end # Convert the tensor network to a flat symbolic multiplication expression. 
-function contraction_order(::Algorithm"flat", tensors) +function contraction_order(alg::Algorithm"flat", tn) # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. syms = vec([symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn)]) return lazy(Mul(syms)) @@ -59,11 +63,7 @@ end function contraction_order(alg::Algorithm"left_associative", tn) return prod(i -> symnameddims(i, Tuple(axes(tn[i]))), keys(tn)) end - -function contraction_order( - order_algorithm::Algorithm, - tensors, - ) - order = contraction_order(tensors; order = "flat") - return optimize_evaluation_order(order; alg = order_algorithm) +function contraction_order(alg::Algorithm, tn) + s = contraction_order(Algorithm"flat"(), tn) + return optimize_evaluation_order(s; alg) end From c916c84c19502294b77aeca61165b778ddbd66c8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 19 Feb 2026 17:44:59 -0500 Subject: [PATCH 37/86] Bug fix; upgrade tests --- .../beliefpropagationproblem.jl | 2 +- test/Project.toml | 1 + test/test_contract_network.jl | 16 +++++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index c127655..0312843 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -232,7 +232,7 @@ function AI.solve!( return state end -function contract_messages(alg, factor::AbstractArray, messages::Vector{<:AbstractArray}) +function contract_messages(alg, factor::AbstractArray, messages) factors = typeof(factor)[factor] return contract_network(vcat(factors, messages); alg) end diff --git a/test/Project.toml b/test/Project.toml index cf048b7..8b1072a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BackendSelection = 
"680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index fc863f6..35b2275 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -5,8 +5,11 @@ using ITensorBase: Index using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset +using BackendSelection: @Algorithm_str, Algorithm @testset "contract_network" begin + orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) + @testset "Contract Vectors of ITensors" begin i, j, k = Index(2), Index(2), Index(5) A = [1.0 1.0; 0.5 1.0][i, j] @@ -14,10 +17,9 @@ using Test: @test, @testset C = [5.0, 1.0][j] D = [-2.0, 3.0, 4.0, 5.0, 1.0][k] - ABCD_1 = contract_network([A, B, C, D]; alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; alg = "optimal") - + ABCD_1 = contract_network([A, B, C, D]; alg = orderalg("left_associative")) + ABCD_2 = contract_network([A, B, C, D]; alg = orderalg("eager")) + ABCD_3 = contract_network([A, B, C, D]; alg = orderalg("optimal")) @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +33,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; alg = "left_associative")[] - z2 = contract_network(tn; alg = "eager")[] - z3 = contract_network(tn; alg = "optimal")[] + z1 = contract_network(tn; alg = orderalg("left_associative"))[] + z2 = contract_network(tn; alg = orderalg("eager"))[] + z3 = contract_network(tn; alg = orderalg("optimal"))[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) From 4a511a159d298ef466108b7af250b754c6d0dc35 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Feb 2026 09:41:03 
-0500 Subject: [PATCH 38/86] Add 2D TN test --- test/Project.toml | 1 + test/test_beliefpropagation.jl | 64 +++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 8b1072a..50a58c5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 8c7829b..8a817b2 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,22 +1,48 @@ -using Dictionaries: Dictionary -using ITensorBase: Index +using Dictionaries: Dictionary, set! +using ITensorBase: Index, ITensor, prime, noprime using ITensorNetworksNext: BeliefPropagationCache, ITensorNetworksNext, TensorNetwork, - adapt_messages, - default_message, - default_messages, - edge_scalars, - factors, - messages, - partitionfunction, - setmessages! 
-using Graphs: edges, vertices + partitionfunction +using DiagonalArrays: δ +using Graphs: src, dst, edges, vertices, AbstractGraph using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree -using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype using Test: @test, @testset +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: name, inds +function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) + links = Dictionary(edges(g), [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)]) + links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) + + # symmetric sqrt of Boltzmann matrix W = exp(β σσ') + sqrt_Ws = Dictionary() + for e in edges(g) + W = [ exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h)) ] + + F = LinearAlgebra.svd(W) + U, S, V = F.U, F.S, F.Vt + @assert U * LinearAlgebra.diagm(S) * V ≈ W + id = [1.0 0.0; 0.0 1.0] + set!(sqrt_Ws, e, id) + set!(sqrt_Ws, reverse(e), U * LinearAlgebra.diagm(S) * V) + end + ts = Dictionary{vertextype(g), ITensor}() + for v in vertices(g) + es = incident_edges(g, v; dir = :in) + #t = ITensor(1.0, physical_inds[v]...) 
* delta([links[e] for e in es]) + t = δ(Float64, Tuple([links[e] for e in es])) + for e in es + t_prime = ITensor(sqrt_Ws[e], (name(links[e]), name(prime(links[e])))) * t + newinds = noprime.(inds(t_prime)) + t = ITensor(parent(t_prime), name.(newinds)) + end + set!(ts, v, t) + end + return TensorNetwork(g, ts) +end @testset "BeliefPropagation" begin #Chain of tensors @@ -49,5 +75,17 @@ using Test: @test, @testset bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact atol = 1.0e-12 + @test z_bp ≈ z_exact atol = 1.0e-10 + + #Square lattice Ising model + dims = (3, 3) + g = named_grid(dims) + tn = ising_tensornetwork(g, 0.05, h = 0.5) + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 50, tol = 1.0e-10) + + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact rtol = 1.0e-4 + end From 5b97af3a6b5a219c09b6d7db9e40022ab398bb51 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Feb 2026 09:47:03 -0500 Subject: [PATCH 39/86] Formatting --- docs/make.jl | 9 +-- docs/make_index.jl | 4 +- docs/make_readme.jl | 4 +- .../ITensorNetworksNextTensorOperationsExt.jl | 4 +- .../AlgorithmsInterfaceExtensions.jl | 41 ++++-------- src/LazyNamedDimsArrays/symbolicarray.jl | 8 ++- src/TensorNetworkGenerators/delta_network.jl | 2 +- src/TensorNetworkGenerators/ising_network.jl | 2 +- src/abstracttensornetwork.jl | 16 ++--- .../abstractbeliefpropagationcache.jl | 13 ++-- .../beliefpropagationcache.jl | 58 +++++++++++----- .../beliefpropagationproblem.jl | 66 +++++++++++-------- src/contract_network.jl | 4 +- src/sweeping/eigenproblem.jl | 2 +- src/tensornetwork.jl | 47 ++++++------- test/runtests.jl | 15 +++-- test/test_algorithmsinterfaceextensions.jl | 14 ++-- test/test_aqua.jl | 2 +- test/test_basics.jl | 2 +- test/test_beliefpropagation.jl | 25 
++++--- test/test_contract_network.jl | 6 +- test/test_dmrg.jl | 4 +- test/test_lazynameddimsarrays.jl | 8 +-- test/test_tensornetworkgenerators.jl | 2 +- 24 files changed, 195 insertions(+), 163 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1b29518..c4f46f3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,5 @@ -using ITensorNetworksNext: ITensorNetworksNext using Documenter: Documenter, DocMeta, deploydocs, makedocs +using ITensorNetworksNext: ITensorNetworksNext DocMeta.setdocmeta!( ITensorNetworksNext, :DocTestSetup, :(using ITensorNetworksNext); recursive = true @@ -14,11 +14,12 @@ makedocs(; format = Documenter.HTML(; canonical = "https://itensor.github.io/ITensorNetworksNext.jl", edit_link = "main", - assets = ["assets/favicon.ico", "assets/extras.css"], + assets = ["assets/favicon.ico", "assets/extras.css"] ), - pages = ["Home" => "index.md", "Reference" => "reference.md"], + pages = ["Home" => "index.md", "Reference" => "reference.md"] ) deploydocs(; - repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", push_preview = true + repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", + push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 038bc87..af08861 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -1,5 +1,5 @@ -using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext +using Literate: Literate function ccq_logo(content) include_ccq_logo = """ @@ -17,5 +17,5 @@ Literate.markdown( joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); flavor = Literate.DocumenterFlavor(), name = "index", - postprocess = ccq_logo, + postprocess = ccq_logo ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 088dc58..52d0dbb 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -1,5 +1,5 @@ -using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext +using Literate: Literate function ccq_logo(content) include_ccq_logo = """ @@ -17,5 +17,5 
@@ Literate.markdown( joinpath(pkgdir(ITensorNetworksNext)); flavor = Literate.CommonMarkFlavor(), name = "README", - postprocess = ccq_logo, + postprocess = ccq_logo ) diff --git a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl index 4766ee6..972b11e 100644 --- a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl +++ b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl @@ -1,9 +1,9 @@ module ITensorNetworksNextTensorOperationsExt using BackendSelection: @Algorithm_str, Algorithm -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, ismul, symnameddims, - substitute using ITensorNetworksNext.LazyNamedDimsArrays.TermInterface: arguments +using ITensorNetworksNext.LazyNamedDimsArrays: + LazyNamedDimsArrays, ismul, substitute, symnameddims using NamedDimsArrays: inds using TensorOperations: TensorOperations, optimaltree diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 3c887b7..69a4a97 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -1,8 +1,6 @@ module AlgorithmsInterfaceExtensions -import AlgorithmsInterface as AI - -#========================== Patches for AlgorithmsInterface.jl ============================# +import AlgorithmsInterface as AI #========================== Patches for AlgorithmsInterface.jl ============================# abstract type Problem <: AI.Problem end abstract type Algorithm <: AI.Algorithm end @@ -28,9 +26,7 @@ function AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) return DefaultState(; stopping_criterion_state, kwargs...) 
-end - -#============================ DefaultState ================================================# +end #============================ DefaultState ================================================# @kwdef mutable struct DefaultState{ Iterate, StoppingCriterionState <: AI.StoppingCriterionState, @@ -38,16 +34,12 @@ end iterate::Iterate iteration::Int = 0 stopping_criterion_state::StoppingCriterionState -end - -#============================ increment! ==================================================# +end #============================ increment! ==================================================# # Custom version of `increment!` that also takes the problem and algorithm as arguments. function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) -end - -#============================ solve! ======================================================# +end #============================ solve! ======================================================# # Custom version of `solve!` that allows specifying the logger and also overloads # `increment!` on the problem and algorithm. @@ -58,13 +50,13 @@ default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) return Symbol( default_logging_context_prefix(problem), - default_logging_context_prefix(algorithm), + default_logging_context_prefix(algorithm) ) end function AI.solve!( problem::Problem, algorithm::Algorithm, state::State; logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs..., + kwargs... ) logger = AI.algorithm_logger() @@ -97,13 +89,11 @@ end function AI.solve( problem::Problem, algorithm::Algorithm; logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs..., + kwargs... ) state = AI.initialize_state(problem, algorithm; kwargs...) return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) 
-end - -#============================ AlgorithmIterator ===========================================# +end #============================ AlgorithmIterator ===========================================# abstract type AlgorithmIterator end @@ -136,9 +126,7 @@ struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator problem::Problem algorithm::Algorithm state::State -end - -#============================ with_algorithmlogger ========================================# +end #============================ with_algorithmlogger ========================================# # Allow passing functions, not just CallbackActions. @inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) @@ -146,9 +134,7 @@ end end @inline function with_algorithmlogger(f, args::Pair{Symbol}...) return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) -end - -#============================ NestedAlgorithm =============================================# +end #============================ NestedAlgorithm =============================================# abstract type NestedAlgorithm <: Algorithm end @@ -213,8 +199,7 @@ end function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) return DefaultNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) -end -#============================ FlattenedAlgorithm ==========================================# +end #============================ FlattenedAlgorithm ==========================================# # Flatten a nested algorithm. abstract type FlattenedAlgorithm <: Algorithm end @@ -284,9 +269,7 @@ end parent_iteration::Int = 1 child_iteration::Int = 0 stopping_criterion_state::StoppingCriterionState -end - -#============================ NonIterativeAlgorithm =======================================# +end #============================ NonIterativeAlgorithm =======================================# # Algorithm that only performs a single step. 
abstract type NonIterativeAlgorithm <: Algorithm end diff --git a/src/LazyNamedDimsArrays/symbolicarray.jl b/src/LazyNamedDimsArrays/symbolicarray.jl index a0922fd..e3ff4d4 100644 --- a/src/LazyNamedDimsArrays/symbolicarray.jl +++ b/src/LazyNamedDimsArrays/symbolicarray.jl @@ -1,8 +1,12 @@ # TODO: Allow dynamic/unknown number of dimensions by supporting vector axes. -struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: + AbstractArray{T, N} name::Name axes::Axes - function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + function SymbolicArray{T}( + name, + ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} + ) where {T} N = length(ax) return new{T, N, typeof(name), typeof(ax)}(name, ax) end diff --git a/src/TensorNetworkGenerators/delta_network.jl b/src/TensorNetworkGenerators/delta_network.jl index 8b28def..e6a453c 100644 --- a/src/TensorNetworkGenerators/delta_network.jl +++ b/src/TensorNetworkGenerators/delta_network.jl @@ -1,6 +1,6 @@ +using ..ITensorNetworksNext: TensorNetwork using DiagonalArrays: δ using Graphs: AbstractGraph -using ..ITensorNetworksNext: TensorNetwork using NamedGraphs.GraphsExtensions: incident_edges """ diff --git a/src/TensorNetworkGenerators/ising_network.jl b/src/TensorNetworkGenerators/ising_network.jl index 1f2fa31..e37551c 100644 --- a/src/TensorNetworkGenerators/ising_network.jl +++ b/src/TensorNetworkGenerators/ising_network.jl @@ -1,6 +1,6 @@ +using ..ITensorNetworksNext: @preserve_graph using DiagonalArrays: DiagonalArray using Graphs: degree, dst, edges, src -using ..ITensorNetworksNext: @preserve_graph using LinearAlgebra: Diagonal, eigen using NamedDimsArrays: apply, denamed, name, operator, randname using NamedGraphs.GraphsExtensions: vertextype diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index c4b6fcb..7fca799 100644 --- 
a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,17 +1,17 @@ using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, set_vertex_data!, +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: AbstractEdge, AbstractGraph, Graphs, add_edge!, add_vertex!, - dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using Graphs: Graphs, AbstractEdge, AbstractGraph, add_edge!, add_vertex!, dst, edges, + edgetype, ne, neighbors, nv, rem_edge!, src, vertices using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs: NamedGraph, NamedGraphs, not_implemented +using NamedGraphs.GraphsExtensions: + directed_graph, incident_edges, rem_edges!, similar_graph, vertextype using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, - similar_graph, vertextype +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end @@ -125,7 +125,7 @@ is_assignment_expr(expr) = false macro preserve_graph(expr) if !is_setindex!_expr(expr) error( - "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] 
= value)" ) end @capture(expr, array_[indices__] = value_) @@ -207,7 +207,7 @@ Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() function Base.setindex!( tn::AbstractTensorNetwork, value, - edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, + edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger} ) return not_implemented() end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index b77fb4e..33f185b 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,7 +1,7 @@ -using Graphs: AbstractGraph, AbstractEdge -using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_type +using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data +using Graphs: AbstractEdge, AbstractGraph using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] @@ -63,7 +63,6 @@ function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) end function region_scalar(bp_cache::AbstractGraph, vertex) - messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) @@ -78,7 +77,10 @@ function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache)))) +function edge_scalars( + bp_cache::AbstractGraph, + edges = edges(undirected_graph(underlying_graph(bp_cache))) + ) return map(e -> region_scalar(bp_cache, e), edges) end @@ -123,7 +125,6 @@ message_type(bpc::AbstractBeliefPropagationCache) = 
message_type(typeof(bpc)) message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED function free_energy(bp_cache::AbstractBeliefPropagationCache) - numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 2c253e6..5d1a31c 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,19 +1,23 @@ -using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, edge_data_type, - set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, - vertex_data_type -using Dictionaries: Dictionary, delete!, set!, getindices -using Graphs: AbstractGraph, connected_components, is_tree, is_directed +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type +using Dictionaries: Dictionary, delete!, getindices, set! +using Graphs: AbstractGraph, connected_components, is_directed, is_tree using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.GraphsExtensions: + default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph - using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices -struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: + AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. 
factors::Dictionary{V, VD} messages::Dictionary{E, ED} - function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary, messages::Dictionary) + function BeliefPropagationCache( + graph::AbstractGraph, + factors::Dictionary, + messages::Dictionary + ) # Ensure the graph is directed, if not make it directed. digraph = is_directed(graph) ? graph : directed_graph(graph) @@ -34,14 +38,22 @@ end DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph -DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = haskey(bpc.factors, vertex) +function DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) + return haskey(bpc.factors, vertex) +end DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] -DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] +function DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) + return bpc.messages[edge] +end -DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set!(bpc.factors, vertex, val) -DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) +function DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) + return set!(bpc.factors, vertex, val) +end +function DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) + return set!(bpc.messages, edge, val) +end # These two methods assume `network` behaves llike a tensor network # (could be e.g. a QuotientView) otherwise how would one know what the factors should be. 
@@ -64,7 +76,11 @@ function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Diction end function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(bp_cache.underlying_graph), copy(bp_cache.factors), copy(bp_cache.messages)) + return BeliefPropagationCache( + copy(bp_cache.underlying_graph), + copy(bp_cache.factors), + copy(bp_cache.messages) + ) end # TODO: This needs to go in GraphsExtensions @@ -85,7 +101,8 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo end function induced_subgraph_bpcache(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(underlying_graph(graph), subvertices) assigned = v -> isassigned(graph, v) @@ -100,7 +117,10 @@ function induced_subgraph_bpcache(graph, subvertices) return subgraph, vlist end -function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) +function NamedGraphs.induced_subgraph_from_vertices( + graph::BeliefPropagationCache, + subvertices + ) return induced_subgraph_bpcache(graph, subvertices) end @@ -108,7 +128,6 @@ end # Take a QuotientView of the underlying graph. 
function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - graph = underlying_graph(bpc) quotient_view = QuotientView(graph) @@ -137,6 +156,9 @@ function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientV data = collect(map(v -> tn[v], vertices(tn, vertex))) return mapreduce(lazy, *, data) end -function DataGraphs.is_graph_index_assigned(tn::BeliefPropagationCache, vertex::QuotientVertex) +function DataGraphs.is_graph_index_assigned( + tn::BeliefPropagationCache, + vertex::QuotientVertex + ) return isassigned(tn, Vertices(vertices(tn, vertex))) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 0312843..1a62792 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,12 +1,11 @@ +import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using DataGraphs: edge_data using Graphs: AbstractEdge, edges, has_edge, vertices -using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph +using LinearAlgebra: norm, normalize using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -using DataGraphs: edge_data -using LinearAlgebra: norm, normalize - -import AlgorithmsInterface as AI -import .AlgorithmsInterfaceExtensions as AIE @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 = 0.0 @@ -24,7 +23,7 @@ function AI.initialize_state!( ::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged, - st::StopWhenConvergedState, + st::StopWhenConvergedState ) st.delta = Inf return st @@ -35,7 +34,7 @@ function AI.is_finished!( ::AIE.Algorithm, state::AIE.State, c::StopWhenConverged, - st::StopWhenConvergedState, + st::StopWhenConvergedState ) # maxdiff = 0.0 initially, so skip this the first time. 
@@ -50,7 +49,8 @@ struct BeliefPropagationProblem{Network} <: AIE.Problem network::Network end -@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: + AIE.NonIterativeAlgorithmState iterate::Iterate diffs::Diffs = similar(edge_data(iterate), Float64) maxdiff::Float64 = 0.0 @@ -83,7 +83,10 @@ function SimpleMessageUpdate( compute_diff = false, kwargs... ) - return SimpleMessageUpdate(edge, (; normalize, contraction_alg, compute_diff, kwargs...)) + return SimpleMessageUpdate( + edge, + (; normalize, contraction_alg, compute_diff, kwargs...) + ) end function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) @@ -106,12 +109,13 @@ struct BeliefPropagationSweep{ end end -BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) +function BeliefPropagationSweep(f::Function, edges) + return BeliefPropagationSweep(; algorithms = f.(edges)) +end function AI.initialize_state( ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... 
) - diffs = iterate.diffs maxdiff = iterate.maxdiff @@ -122,7 +126,7 @@ end function AI.initialize_state!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, - iteration_state::AIE.State, + iteration_state::AIE.State ) iteration_state.iterate.maxdiff = 0.0 return iteration_state @@ -132,9 +136,8 @@ function AIE.set_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, sweep_state::AIE.DefaultState, - noniterative_substate::BeliefPropagationState, + noniterative_substate::BeliefPropagationState ) - sweep_state.iterate = noniterative_substate return sweep_state @@ -149,9 +152,8 @@ function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, state::BeliefPropagationState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) ) - logger = AI.algorithm_logger() cache = state.iterate @@ -204,17 +206,20 @@ function AI.solve!( logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), kwargs... 
) - logger = AI.algorithm_logger() AI.emit_message( logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) ) - state.iterate = contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) + state.iterate = + contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreNormalization) + logger, problem, algorithm, state, Symbol( + logging_context_prefix, + :PreNormalization + ) ) if algorithm.normalize @@ -226,7 +231,8 @@ function AI.solve!( end AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostNormalization) + logger, problem, algorithm, state, + Symbol(logging_context_prefix, :PostNormalization) ) return state @@ -237,9 +243,14 @@ function contract_messages(alg, factor::AbstractArray, messages) return contract_network(vcat(factors, messages); alg) end -beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network), network; kwargs...) -function beliefpropagation(cache::AbstractBeliefPropagationCache, network = nothing; kwargs...) - +function beliefpropagation(network; kwargs...) + return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +end +function beliefpropagation( + cache::AbstractBeliefPropagationCache, + network = nothing; + kwargs... + ) problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) @@ -260,10 +271,13 @@ function select_algorithm( edges = forest_cover_edge_sequence(cache), maxiter = is_tree(cache) ? 1 : nothing, tol = -Inf, - message_diff_function = tol > -Inf ? (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) : nothing, + message_diff_function = if tol > -Inf + (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) + else + nothing + end, kwargs... 
) - if isnothing(maxiter) throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end diff --git a/src/contract_network.jl b/src/contract_network.jl index a8c3fc7..9db4c32 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,7 +1,7 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, - substitute, symnameddims +using ITensorNetworksNext.LazyNamedDimsArrays: + Mul, lazy, optimize_evaluation_order, substitute, symnameddims # This is related to `MatrixAlgebraKit.select_algorithm`. # TODO: Define this in BackendSelection.jl. diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl index 36978b2..8fefbd0 100644 --- a/src/sweeping/eigenproblem.jl +++ b/src/sweeping/eigenproblem.jl @@ -1,5 +1,5 @@ -import AlgorithmsInterface as AI import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI function dmrg(operator, algorithm, state) problem = EigenProblem(operator) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 0d30970..a371373 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,25 +1,19 @@ +using .LazyNamedDimsArrays: Mul, lazy using Combinatorics: combinations -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using DataGraphs.DataGraphsPartitionedGraphsExt +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, edge_data, get_vertices_data, + vertex_data, vertex_data_type using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! -using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! +using Graphs: AbstractSimpleGraph, rem_edge!, rem_vertex! 
using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype -using NamedGraphs.PartitionedGraphs: - AbstractPartitionedGraph, - PartitionedGraphs, - departition, - partitioned_vertices, - partitionedgraph, - quotient_graph, - quotient_graph_type, - QuotientVertex, - QuotientVertices, - QuotientVertexVertices, +using NamedGraphs.GraphsExtensions: + GraphsExtensions, arrange_edge, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, + QuotientVertex, QuotientVertexVertices, QuotientVertices, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, quotientvertices -using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_type, vertex_data, edge_data, get_vertices_data -using DataGraphs.DataGraphsPartitionedGraphsExt +using NamedGraphs: + NamedGraphs, NamedEdge, NamedGraph, Vertices, parent_graph_indices, vertextype function _TensorNetwork end @@ -44,7 +38,9 @@ function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) return tn end -function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} +function TensorNetwork{V, VD, UG, Tensors}( + graph::UG + ) where {V, VD, UG <: AbstractGraph{V}, Tensors} return _TensorNetwork(graph, Tensors()) end @@ -121,14 +117,20 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) -Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) +function Graphs.connected_components(tn::TensorNetwork) + return Graphs.connected_components(underlying_graph(tn)) +end function Graphs.rem_edge!(tn::TensorNetwork, e) if !has_edge(underlying_graph(tn), 
e) return false end if !isempty(linkinds(tn, e)) - throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + throw( + ArgumentError( + "cannot remove edge $e due to tensor indices existing on this edge." + ) + ) end rem_edge!(underlying_graph(tn), e) return true @@ -150,7 +152,8 @@ function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subver end function tensornetwork_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(underlying_graph(graph), subvertices) subgraph = TensorNetwork(underlying_subgraph) do vertex return graph[vertex] diff --git a/test/runtests.jl b/test/runtests.jl index 0008050..16689fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,14 +10,19 @@ const GROUP = uppercase( get(ENV, "GROUP", "ALL") else only(match(pat, ARGS[arg_id]).captures) - end, + end ) -"match files of the form `test_*.jl`, but exclude `*setup*.jl`" +""" +match files of the form `test_*.jl`, but exclude `*setup*.jl` +""" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && + !contains(fn, "setup") end -"match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" +""" +match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl` +""" function isexamplefile(fn) return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @@ -57,7 +62,7 @@ end :macrocall, GlobalRef(Suppressor, Symbol("@suppress")), LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), + :(include($filename)) ) ) end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 8e0665c..44e6a09 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ 
b/test/test_algorithmsinterfaceextensions.jl @@ -164,7 +164,7 @@ end # Test with CallbackAction (wrapped functions) state = AIE.with_algorithmlogger( :TestProblem_TestAlgorithm_PreStep => callback1, - :TestProblem_TestAlgorithm_PostStep => callback2, + :TestProblem_TestAlgorithm_PostStep => callback2 ) do return AI.solve(problem, algorithm; iterate = [0.0]) end @@ -227,7 +227,7 @@ end ) state = AIE.DefaultState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state, + stopping_criterion_state = stopping_criterion_state ) # Test progression through iterations @@ -253,7 +253,7 @@ end state = AIE.DefaultState(; iterate = [5.0, 10.0], iteration = 1, - stopping_criterion_state, + stopping_criterion_state ) subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) @@ -264,7 +264,7 @@ end # Test set_substate! new_substate = AIE.DefaultState(; iterate = [100.0, 200.0], - substate.stopping_criterion_state, + substate.stopping_criterion_state ) AIE.set_substate!(problem, nested_alg, state, new_substate) @test state.iterate ≈ [100.0, 200.0] @@ -321,7 +321,7 @@ end flattened_alg = AIE.DefaultFlattenedAlgorithm(; algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4), + stopping_criterion = AI.StopAfterIteration(4) ) problem = TestProblem([1.0]) @@ -330,7 +330,7 @@ end ) state = AIE.DefaultFlattenedAlgorithmState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state, + stopping_criterion_state = stopping_criterion_state ) # Test initial state @@ -388,7 +388,7 @@ end # Using the helper function flattened_alg = AIE.flattened_algorithm(2) do i AIE.nested_algorithm(1) do j - TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index a38563a..8eb4612 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,5 +1,5 @@ -using ITensorNetworksNext: 
ITensorNetworksNext using Aqua: Aqua +using ITensorNetworksNext: ITensorNetworksNext using Test: @testset @testset "Code quality (Aqua.jl)" begin diff --git a/test/test_basics.jl b/test/test_basics.jl index 0c9d803..9f80b25 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,7 @@ using Dictionaries: Indices using Graphs: dst, edges, has_edge, ne, nv, src, vertices -using ITensorNetworksNext: TensorNetwork, linkinds, siteinds using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork, linkinds, siteinds using NamedDimsArrays: dimnames using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 8a817b2..d1cca76 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,26 +1,26 @@ +using DiagonalArrays: δ using Dictionaries: Dictionary, set! -using ITensorBase: Index, ITensor, prime, noprime +using Graphs: AbstractGraph, dst, edges, src, vertices +using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: - BeliefPropagationCache, - ITensorNetworksNext, - TensorNetwork, - partitionfunction -using DiagonalArrays: δ -using Graphs: src, dst, edges, vertices, AbstractGraph -using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree + ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, partitionfunction +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype +using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid using Test: @test, @testset -using LinearAlgebra: LinearAlgebra -using NamedDimsArrays: name, inds function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) - links = Dictionary(edges(g), [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)]) + links = Dictionary( + edges(g), + [Index(2; tags = "edge" => 
"e$(src(e))_$(dst(e))") for e in edges(g)] + ) links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) # symmetric sqrt of Boltzmann matrix W = exp(β σσ') sqrt_Ws = Dictionary() for e in edges(g) - W = [ exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h)) ] + W = [exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h))] F = LinearAlgebra.svd(W) U, S, V = F.U, F.S, F.Vt @@ -87,5 +87,4 @@ end z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact rtol = 1.0e-4 - end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index 35b2275..b453e76 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -1,11 +1,11 @@ +using BackendSelection: @Algorithm_str, Algorithm using Graphs: edges +using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork, contract_network, linkinds, siteinds using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid -using ITensorBase: Index -using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset -using BackendSelection: @Algorithm_str, Algorithm @testset "contract_network" begin orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl index 01f04ac..dba2570 100644 --- a/test/test_dmrg.jl +++ b/test/test_dmrg.jl @@ -1,6 +1,6 @@ import AlgorithmsInterface as AI -using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm using Test: @test, @testset @testset "select_algorithm(dmrg, ...)" begin @@ -21,7 +21,7 @@ using Test: @test, @testset return EigsolveRegion( regions[j]; maxdim = maxdims[i], - cutoff = cutoffs[i], + cutoff = cutoffs[i] ) end end diff --git 
a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index d067c24..751b469 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,9 +1,9 @@ using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, LazyNamedDimsArray, - Mul, SymbolicArray, ismul, lazy, substitute, symnameddims -using NamedDimsArrays: NamedDimsArray, @names, denamed, dimnames, inds, nameddims, - namedoneto +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, LazyNamedDimsArrays, Mul, + SymbolicArray, ismul, lazy, substitute, symnameddims +using NamedDimsArrays: + @names, NamedDimsArray, denamed, dimnames, inds, nameddims, namedoneto using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments, sorted_children using Test: @test, @test_throws, @testset diff --git a/test/test_tensornetworkgenerators.jl b/test/test_tensornetworkgenerators.jl index 2d092c3..f29a900 100644 --- a/test/test_tensornetworkgenerators.jl +++ b/test/test_tensornetworkgenerators.jl @@ -1,8 +1,8 @@ using DiagonalArrays: δ using Graphs: edges, ne, nv, vertices using ITensorBase: Index -using ITensorNetworksNext: contract_network using ITensorNetworksNext.TensorNetworkGenerators: delta_network, ising_network +using ITensorNetworksNext: contract_network using NamedDimsArrays: inds using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid From 951cee6195de502a41e82e5c1139904803f8febd Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 12 Mar 2026 19:24:40 -0400 Subject: [PATCH 40/86] Simplify BP code --- .../AlgorithmsInterfaceExtensions.jl | 6 +- .../abstractbeliefpropagationcache.jl | 15 +- .../beliefpropagationproblem.jl | 187 ++++++------------ test/test_beliefpropagation.jl | 9 +- 4 files changed, 79 insertions(+), 138 deletions(-) diff --git 
a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index f44cbeb..9f63691 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -22,12 +22,12 @@ function AI.initialize_state!( end function AI.initialize_state( - problem::Problem, algorithm::Algorithm; kwargs... + problem::Problem, algorithm::Algorithm; iterate, kwargs... ) stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion + problem, algorithm, algorithm.stopping_criterion; iterate ) - return DefaultState(; stopping_criterion_state, kwargs...) + return DefaultState(; iterate, stopping_criterion_state, kwargs...) end # ============================ DefaultState ================================================ diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 33f185b..9ac3d59 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -58,15 +58,18 @@ function setfactor!(bpc::AbstractDataGraph, vertex, factor) return bpc end -function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) - return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge; alg = "exact") + # Make generic to deal with the possibilty of multiple messages. 
+ m1s = messages(bp_cache, [edge]) + m2s = messages(bp_cache, [reverse(edge)]) + return contract_network(vcat(m1s, m2s); alg)[] end -function region_scalar(bp_cache::AbstractGraph, vertex) +function region_scalar(bp_cache::AbstractGraph, vertex; alg = "exact") messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) - return (reduce(*, messages) * reduce(*, state))[] + return contract_network(vcat(messages, state); alg)[] end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) @@ -124,7 +127,7 @@ factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED -function free_energy(bp_cache::AbstractBeliefPropagationCache) +function logscalar(bp_cache::AbstractBeliefPropagationCache) numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) @@ -140,4 +143,4 @@ function free_energy(bp_cache::AbstractBeliefPropagationCache) return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end -partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) +scalar(bp_cache::AbstractBeliefPropagationCache) = exp(logscalar(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 1a62792..1d96ed0 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -7,16 +7,19 @@ using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 = 0.0 +@kwdef struct StopWhenConverged{Tol <: Real} <: AI.StoppingCriterion + tol::Tol = NaN end -@kwdef mutable struct 
StopWhenConvergedState <: AI.StoppingCriterionState - delta::Float64 = Inf +@kwdef mutable struct StopWhenConvergedState{Iterate, Delta <: Real} <: + AI.StoppingCriterionState + delta::Delta = NaN + at_iteration::Int = -1 + previous_iterate::Iterate end -function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged) - return StopWhenConvergedState() +function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) end function AI.initialize_state!( @@ -25,23 +28,45 @@ function AI.initialize_state!( ::StopWhenConverged, st::StopWhenConvergedState ) - st.delta = Inf + st.delta = NaN return st end function AI.is_finished!( - ::AIE.Problem, - ::AIE.Algorithm, + problem::AIE.Problem, + algorithm::AIE.Algorithm, state::AIE.State, c::StopWhenConverged, st::StopWhenConvergedState ) # maxdiff = 0.0 initially, so skip this the first time. - if state.iteration > 0 - st.delta = state.iterate.maxdiff + iterate = state.iterate + previous_iterate = st.previous_iterate + + delta = iterate_diff(iterate, previous_iterate) + + st.previous_iterate = copy(iterate) + + state.iteration == 0 && return false + + st.delta = delta + + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true end + return false +end + +function AI.is_finished( + ::AIE.Problem, + ::AIE.Algorithm, + ::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) return st.delta < c.tol end @@ -49,11 +74,12 @@ struct BeliefPropagationProblem{Network} <: AIE.Problem network::Network end -@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: - AIE.NonIterativeAlgorithmState - iterate::Iterate - diffs::Diffs = similar(edge_data(iterate), Float64) - maxdiff::Float64 = 0.0 +function iterate_diff(cache1, cache2) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), 
normalize(m2))) + end end @kwdef struct BeliefPropagation{ @@ -69,7 +95,7 @@ function BeliefPropagation(f::Function, niterations::Int; kwargs...) return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) end -abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end +abstract type AbstractMessageUpdate end struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate edge::E @@ -80,12 +106,11 @@ function SimpleMessageUpdate( edge; normalize = true, contraction_alg = "exact", - compute_diff = false, kwargs... ) return SimpleMessageUpdate( edge, - (; normalize, contraction_alg, compute_diff, kwargs...) + (; normalize, contraction_alg, kwargs...) ) end @@ -97,9 +122,10 @@ function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) end end +AI.initialize_state(::BeliefPropagationProblem, ::SimpleMessageUpdate; iterate) = iterate + struct BeliefPropagationSweep{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, + ChildAlgorithm, Algorithms <: AbstractVector{ChildAlgorithm}, } <: AIE.NestedAlgorithm algorithms::Algorithms stopping_criterion::AI.StopAfterIteration @@ -113,75 +139,29 @@ function BeliefPropagationSweep(f::Function, edges) return BeliefPropagationSweep(; algorithms = f.(edges)) end -function AI.initialize_state( - ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... - ) - diffs = iterate.diffs - maxdiff = iterate.maxdiff - - return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) -end - -# This gets called at the start of every sweep. 
-function AI.initialize_state!( - ::BeliefPropagationProblem, - ::BeliefPropagationSweep, - iteration_state::AIE.State - ) - iteration_state.iterate.maxdiff = 0.0 - return iteration_state -end - function AIE.set_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, - sweep_state::AIE.DefaultState, - noniterative_substate::BeliefPropagationState + state::AIE.DefaultState, + cache::AbstractBeliefPropagationCache ) - sweep_state.iterate = noniterative_substate + state.iterate = cache - return sweep_state -end - -struct MessageUpdateProblem{Factor, Messages} <: AIE.Problem - factor::Factor - messages::Messages + return state end function AI.solve!( - problem::BeliefPropagationProblem, + ::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, - state::BeliefPropagationState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) + cache::AbstractBeliefPropagationCache; kwargs... ) - logger = AI.algorithm_logger() - - cache = state.iterate edge = algorithm.edge - AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) - ) - new_message = updated_message(algorithm, cache) - if !isnothing(algorithm.message_diff_function) - diff = algorithm.message_diff_function(new_message, cache[edge]) - - if diff > state.maxdiff - state.maxdiff = diff - end - - state.diffs[edge] = diff - end - setmessage!(cache, edge, new_message) - AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostUpdate) - ) - - return state + return cache end default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) @@ -192,50 +172,16 @@ function updated_message(algorithm, cache) vertex = src(edge) messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) - update_problem = MessageUpdateProblem(cache[vertex], messages) - - message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) - - return message_state.iterate -end - 
-function AI.solve!( - problem::MessageUpdateProblem, - algorithm::SimpleMessageUpdate, - state::AIE.DefaultNonIterativeAlgorithmState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), - kwargs... - ) - logger = AI.algorithm_logger() - - AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) - ) - - state.iterate = - contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) - - AI.emit_message( - logger, problem, algorithm, state, Symbol( - logging_context_prefix, - :PreNormalization - ) - ) + new_message = contract_messages(algorithm.contraction_alg, cache[vertex], messages) if algorithm.normalize - # TODO: use `sum` not `norm` - message_norm = LinearAlgebra.norm(state.iterate) + message_norm = sum(new_message) if !iszero(message_norm) - state.iterate /= message_norm + new_message /= message_norm end end - AI.emit_message( - logger, problem, algorithm, state, - Symbol(logging_context_prefix, :PostNormalization) - ) - - return state + return new_message end function contract_messages(alg, factor::AbstractArray, messages) @@ -246,6 +192,7 @@ end function beliefpropagation(network; kwargs...) return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) end + function beliefpropagation( cache::AbstractBeliefPropagationCache, network = nothing; @@ -255,14 +202,9 @@ function beliefpropagation( algorithm = select_algorithm(beliefpropagation, cache; kwargs...) - # The nested algorithms will wrap and manipulate this object. - base_state = BeliefPropagationState(; iterate = cache) - - state = AI.initialize_state(problem, algorithm; iterate = base_state) - - state = AI.solve!(problem, algorithm, state) + state = AI.solve(problem, algorithm; iterate = cache) - return state.iterate.iterate + return state.iterate end function select_algorithm( @@ -271,11 +213,6 @@ function select_algorithm( edges = forest_cover_edge_sequence(cache), maxiter = is_tree(cache) ? 
1 : nothing, tol = -Inf, - message_diff_function = if tol > -Inf - (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) - else - nothing - end, kwargs... ) if isnothing(maxiter) @@ -288,7 +225,7 @@ function select_algorithm( stopping_criterion = stopping_criterion | StopWhenConverged(tol) end - extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) + extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index d1cca76..59affe9 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,7 +3,7 @@ using Dictionaries: Dictionary, set! using Graphs: AbstractGraph, dst, edges, src, vertices using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: - ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, partitionfunction + ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, scalar using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype @@ -43,6 +43,7 @@ function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) end return TensorNetwork(g, ts) end + @testset "BeliefPropagation" begin #Chain of tensors @@ -57,7 +58,7 @@ end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = partitionfunction(bpc) + z_bp = scalar(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact atol = 1.0e-14 @@ -73,7 +74,7 @@ end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = partitionfunction(bpc) + z_bp = scalar(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact atol = 1.0e-10 @@ -84,7 +85,7 @@ end bpc = ITensorNetworksNext.BeliefPropagationCache(tn) bpc = 
ITensorNetworksNext.beliefpropagation(bpc; maxiter = 50, tol = 1.0e-10) - z_bp = partitionfunction(bpc) + z_bp = scalar(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact rtol = 1.0e-4 end From 1f1920c8f1e46dc5400e679be0e842aa8e3534df Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 12 Mar 2026 20:22:48 -0400 Subject: [PATCH 41/86] Add spin ice test --- .../beliefpropagationproblem.jl | 8 +-- test/test_beliefpropagation.jl | 51 ++++++++----------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 1d96ed0..3b1f1ef 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -78,7 +78,9 @@ function iterate_diff(cache1, cache2) return maximum(edges(cache1)) do edge m1 = cache1[edge] m2 = cache2[edge] - return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + #FIXME: `abs2` not defined for `ITensor` + m1m2 = LinearAlgebra.dot(normalize(m1), normalize(m2)) + return 1 - abs(m1m2)^2 end end @@ -212,7 +214,7 @@ function select_algorithm( cache::AbstractBeliefPropagationCache; edges = forest_cover_edge_sequence(cache), maxiter = is_tree(cache) ? 1 : nothing, - tol = -Inf, + tol = NaN, kwargs... 
) if isnothing(maxiter) @@ -221,7 +223,7 @@ function select_algorithm( stopping_criterion = AI.StopAfterIteration(maxiter) - if tol > -Inf + if !isnan(tol) stopping_criterion = stopping_criterion | StopWhenConverged(tol) end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 59affe9..5c75900 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -10,35 +10,25 @@ using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid using Test: @test, @testset -function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) +function spin_ice_tensornetwork(g) links = Dictionary( edges(g), - [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)] + [Index(2) for e in edges(g)] + # [Index(2; tags = "edge " => "e$(src(e))_$(dst(e))") for e in edges(g)] ) links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) - # symmetric sqrt of Boltzmann matrix W = exp(β σσ') - sqrt_Ws = Dictionary() - for e in edges(g) - W = [exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h))] - - F = LinearAlgebra.svd(W) - U, S, V = F.U, F.S, F.Vt - @assert U * LinearAlgebra.diagm(S) * V ≈ W - id = [1.0 0.0; 0.0 1.0] - set!(sqrt_Ws, e, id) - set!(sqrt_Ws, reverse(e), U * LinearAlgebra.diagm(S) * V) - end ts = Dictionary{vertextype(g), ITensor}() for v in vertices(g) es = incident_edges(g, v; dir = :in) - #t = ITensor(1.0, physical_inds[v]...) * delta([links[e] for e in es]) - t = δ(Float64, Tuple([links[e] for e in es])) - for e in es - t_prime = ITensor(sqrt_Ws[e], (name(links[e]), name(prime(links[e])))) * t - newinds = noprime.(inds(t_prime)) - t = ITensor(parent(t_prime), name.(newinds)) + t_data = zeros(Int, 2, 2, 2, 2) + for (i, j, k, l) in Iterators.product(0:1, 0:1, 0:1, 0:1) + if i + j + k + l == 2 + t_data[i + 1, j + 1, k + 1, l + 1] = 1 + end end + linkinds = [links[e] for e in es] + t = t_data[linkinds...] 
set!(ts, v, t) end return TensorNetwork(g, ts) @@ -78,14 +68,17 @@ end z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact atol = 1.0e-10 - #Square lattice Ising model - dims = (3, 3) - g = named_grid(dims) - tn = ising_tensornetwork(g, 0.05, h = 0.5) - bpc = ITensorNetworksNext.BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 50, tol = 1.0e-10) + #Spin Ice Model + for n in (3, 4, 5) + dims = (n, n) + g = named_grid(dims; periodic = true) + tn = spin_ice_tensornetwork(g) - z_bp = scalar(bpc) - z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact rtol = 1.0e-4 + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + + z_bp = scalar(bpc) + + @test z_bp ≈ 1.5^(n^2) + end end From 5f3be9835b043aec34e9840c5148f87912c5fb33 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 12 Mar 2026 20:26:18 -0400 Subject: [PATCH 42/86] Version Bump --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index b5fd828..551b18e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" -version = "0.3.24" +version = "0.4.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index fbde468..b4f33d5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,5 +10,5 @@ path = ".." [compat] Documenter = "1" ITensorFormatter = "0.2.27" -ITensorNetworksNext = "0.3" +ITensorNetworksNext = "0.4" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index 780f959..4108720 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" path = ".." 
[compat] -ITensorNetworksNext = "0.3" +ITensorNetworksNext = "0.4" diff --git a/test/Project.toml b/test/Project.toml index 317a90c..6725634 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" -ITensorNetworksNext = "0.3" +ITensorNetworksNext = "0.4" ITensorPkgSkeleton = "0.3.42" NamedDimsArrays = "0.14" NamedGraphs = "0.10" From 487683a9ba798aae0e465fd0490aa0eb1950a29c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Mar 2026 10:22:17 -0400 Subject: [PATCH 43/86] Use `abs2` in message diff function. --- src/beliefpropagation/beliefpropagationproblem.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 3b1f1ef..12d503e 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -74,13 +74,14 @@ struct BeliefPropagationProblem{Network} <: AIE.Problem network::Network end -function iterate_diff(cache1, cache2) +function iterate_diff( + cache1::AbstractBeliefPropagationCache, + cache2::AbstractBeliefPropagationCache + ) return maximum(edges(cache1)) do edge m1 = cache1[edge] m2 = cache2[edge] - #FIXME: `abs2` not defined for `ITensor` - m1m2 = LinearAlgebra.dot(normalize(m1), normalize(m2)) - return 1 - abs(m1m2)^2 + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) end end From aa242432f121a39cf4a1b693f18f282ee72be96f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 16 Mar 2026 14:57:52 -0600 Subject: [PATCH 44/86] Add method for setting intitial messages; improve spin ice tests. 
--- src/beliefpropagation/beliefpropagationcache.jl | 10 ++++++++-- test/test_beliefpropagation.jl | 11 +++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 5d1a31c..e7ccac6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -61,20 +61,26 @@ function BeliefPropagationCache(network::AbstractGraph) graph = underlying_graph(network) return BeliefPropagationCache(graph, copy(vertex_data(network))) end -function BeliefPropagationCache(MT::Type, network::AbstractGraph) +function BeliefPropagationCache(callable::Base.Callable, network::AbstractGraph) graph = underlying_graph(network) - return BeliefPropagationCache(MT, graph, copy(vertex_data(network))) + return BeliefPropagationCache(callable, graph, copy(vertex_data(network))) end function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) MT = vertex_data_type(typeof(graph)) return BeliefPropagationCache(MT, graph, factors) end + function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) messages = Dictionary{edgetype(graph), MT}() return BeliefPropagationCache(graph, factors, messages) end +function BeliefPropagationCache(f::Function, graph::AbstractGraph, factors::Dictionary) + messages = map(f, Indices(edges(graph))) + return BeliefPropagationCache(graph, factors, messages) +end + function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache( copy(bp_cache.underlying_graph), diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 5c75900..aaa2031 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,7 +3,7 @@ using Dictionaries: Dictionary, set! 
using Graphs: AbstractGraph, dst, edges, src, vertices using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: - ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, scalar + ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, linkinds, scalar using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype @@ -68,14 +68,17 @@ end z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact atol = 1.0e-10 - #Spin Ice Model + #Spin Ice Model (has analytical bp solution given by 1.5^(n^2)) for n in (3, 4, 5) dims = (n, n) g = named_grid(dims; periodic = true) tn = spin_ice_tensornetwork(g) - bpc = ITensorNetworksNext.BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) do edge + # Use `rand` so messages have positive elements. + return rand(Tuple(linkinds(tn, edge))) + end + bpc = ITensorNetworksNext.beliefpropagation(bpc; tol = 1.0e-10, maxiter = 10) z_bp = scalar(bpc) From 9248686e749e788beadfbfead12ca1d822687e01 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 16 Mar 2026 14:58:54 -0600 Subject: [PATCH 45/86] Remove redundant `default_message_diff_function` function. 
--- src/beliefpropagation/beliefpropagationproblem.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 12d503e..a30bf08 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -167,8 +167,6 @@ function AI.solve!( return cache end -default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) - function updated_message(algorithm, cache) edge = algorithm.edge From 9d7abeac4cfd9ad871cb8e1f7fa78d6f87a3b784 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Mar 2026 12:53:49 -0400 Subject: [PATCH 46/86] Upgrade to DataGraphs and NamedGraphs to 0.4 and 0.11 --- Project.toml | 4 ++-- src/abstracttensornetwork.jl | 5 ++--- src/tensornetwork.jl | 13 +++++++------ test/Project.toml | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 551b18e..48e06e4 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.3.1" +DataGraphs = "0.4.0" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" FunctionImplementations = "0.4.1" @@ -47,7 +47,7 @@ Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.14.3" -NamedGraphs = "0.10" +NamedGraphs = "0.11" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TensorOperations = "5.3.1" diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 7fca799..bed2ac7 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -8,10 +8,9 @@ using Graphs: Graphs, AbstractEdge, AbstractGraph, add_edge!, add_vertex!, dst, using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs.GraphsExtensions: - directed_graph, incident_edges, rem_edges!, similar_graph, vertextype +using 
NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, vertextype using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, similar_graph abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index a371373..ebb9f11 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -136,16 +136,17 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.similar_graph(type::Type{<:TensorNetwork}) +function NamedGraphs.similar_graph( + type::Type{<:TensorNetwork}, + underlying_graph::AbstractGraph + ) DT = fieldtype(type, :tensors) empty_dict = DT() - return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) -end -function GraphsExtensions.similar_graph(tn::TensorNetwork, underlying_graph::AbstractGraph) - DT = fieldtype(typeof(tn), :tensors) - empty_dict = DT() return _TensorNetwork(underlying_graph, empty_dict) end +function NamedGraphs.similar_graph(tn::TensorNetwork, underlying_graph::AbstractGraph) + return similar_graph(typeof(tn), underlying_graph) +end function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices) return tensornetwork_induced_subgraph(graph, subvertices) diff --git a/test/Project.toml b/test/Project.toml index 6725634..ee4dbd0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -34,7 +34,7 @@ ITensorBase = "0.5" ITensorNetworksNext = "0.4" ITensorPkgSkeleton = "0.3.42" NamedDimsArrays = "0.14" -NamedGraphs = "0.10" +NamedGraphs = "0.11" QuadGK = "2.11.2" SafeTestsets = "0.1" Suppressor = "0.2.8" From 0b65bfb138f848be99cc1d41f2ce92ba4eada233 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Apr 2026 12:01:20 -0400 Subject: [PATCH 47/86] Formatting --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/Project.toml b/Project.toml index 48e06e4..5455ec2 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.4.0" +DataGraphs = "0.4" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" FunctionImplementations = "0.4.1" From f08f022769a8cd3d801c8c61bd8679a6d6a3e575 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 17 Apr 2026 14:21:25 -0400 Subject: [PATCH 48/86] Upgrade to simplified `similar_graph` --- src/tensornetwork.jl | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index ebb9f11..a7cbb0e 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -138,21 +138,36 @@ end function NamedGraphs.similar_graph( type::Type{<:TensorNetwork}, - underlying_graph::AbstractGraph + vertices, + edges ) DT = fieldtype(type, :tensors) empty_dict = DT() + + underlying_graph = similar_graph(underlying_graph_type(type), vertices, edges) + return _TensorNetwork(underlying_graph, empty_dict) end -function NamedGraphs.similar_graph(tn::TensorNetwork, underlying_graph::AbstractGraph) - return similar_graph(typeof(tn), underlying_graph) +function NamedGraphs.similar_graph( + graph::TensorNetwork, + VD::Type, + ::Type{<:Nothing}, + vertices, + edges + ) + V = eltype(vertices) + empty_dict = Dictionary{V, VD}() + + underlying_graph = similar_graph(underlying_graph(graph), vertices, edges) + + return _TensorNetwork(underlying_graph, empty_dict) end function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices) - return tensornetwork_induced_subgraph(graph, subvertices) + return induced_subgraph_tensornetwork(graph, subvertices) end -function tensornetwork_induced_subgraph(graph, subvertices) +function induced_subgraph_tensornetwork(graph, subvertices) underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) From 
3aec51640f8c86f0632ca4c7e75f38d2788f5339 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 20 Apr 2026 14:06:37 +0100 Subject: [PATCH 49/86] Remove edge arg in `similar_graph`. --- src/tensornetwork.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index a7cbb0e..34bec2c 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -138,13 +138,12 @@ end function NamedGraphs.similar_graph( type::Type{<:TensorNetwork}, - vertices, - edges + vertices ) DT = fieldtype(type, :tensors) empty_dict = DT() - underlying_graph = similar_graph(underlying_graph_type(type), vertices, edges) + underlying_graph = similar_graph(underlying_graph_type(type), vertices) return _TensorNetwork(underlying_graph, empty_dict) end @@ -152,15 +151,14 @@ function NamedGraphs.similar_graph( graph::TensorNetwork, VD::Type, ::Type{<:Nothing}, - vertices, - edges + vertices ) V = eltype(vertices) empty_dict = Dictionary{V, VD}() - underlying_graph = similar_graph(underlying_graph(graph), vertices, edges) + new_underlying_graph = similar_graph(underlying_graph(graph), vertices) - return _TensorNetwork(underlying_graph, empty_dict) + return _TensorNetwork(new_underlying_graph, empty_dict) end function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices) From 4d4bc5a0aea14a7c37270a2ee8774d0c79dd17d6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 20 Apr 2026 14:07:36 +0100 Subject: [PATCH 50/86] Inline message computation into `solve!`; use type instead of alg string. 
--- .../beliefpropagationproblem.jl | 30 ++++++------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a30bf08..a3443e8 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,5 +1,6 @@ import .AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI +using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: edge_data using Graphs: AbstractEdge, edges, has_edge, vertices using LinearAlgebra: norm, normalize @@ -98,9 +99,7 @@ function BeliefPropagation(f::Function, niterations::Int; kwargs...) return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) end -abstract type AbstractMessageUpdate end - -struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate +struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} edge::E kwargs::Kwargs end @@ -108,7 +107,7 @@ end function SimpleMessageUpdate( edge; normalize = true, - contraction_alg = "exact", + contraction_alg = Algorithm"exact", kwargs... 
) return SimpleMessageUpdate( @@ -160,20 +159,12 @@ function AI.solve!( ) edge = algorithm.edge - new_message = updated_message(algorithm, cache) - - setmessage!(cache, edge, new_message) - - return cache -end - -function updated_message(algorithm, cache) - edge = algorithm.edge - vertex = src(edge) - messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) + messages = incoming_messages(cache, vertex; ignore_edges = [reverse(edge)]) + + tensors = vcat([factor(cache, vertex)], messages) - new_message = contract_messages(algorithm.contraction_alg, cache[vertex], messages) + new_message = contract_network(tensors; algorithm.contraction_alg) if algorithm.normalize message_norm = sum(new_message) @@ -182,12 +173,9 @@ function updated_message(algorithm, cache) end end - return new_message -end + setmessage!(cache, edge, new_message) -function contract_messages(alg, factor::AbstractArray, messages) - factors = typeof(factor)[factor] - return contract_network(vcat(factors, messages); alg) + return cache end function beliefpropagation(network; kwargs...) From 1f23ab8234aa4321f6c96917fa29e3e85572c858 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 20 Apr 2026 14:25:42 +0100 Subject: [PATCH 51/86] Add in `PartitionedGraphs` interface methods for `TensorNetwork` and `BeliefPropagationCache`. --- src/beliefpropagation/beliefpropagationcache.jl | 4 ++++ src/tensornetwork.jl | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index e7ccac6..c9a6991 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -132,6 +132,10 @@ end ## PartitionedGraphs +function PartitionedGraphs.partitioned_vertices(bpc::BeliefPropagationCache) + return partitioned_vertices(bpc.underlying_graph) +end + # Take a QuotientView of the underlying graph. 
function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) graph = underlying_graph(bpc) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 34bec2c..6d55a6a 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -177,6 +177,10 @@ function induced_subgraph_tensornetwork(graph, subvertices) end ## PartitionedGraphs +function PartitionedGraphs.partitioned_vertices(tn::TensorNetwork) + return partitioned_vertices(tn.underlying_graph) +end + function PartitionedGraphs.quotient_graph(tn::TensorNetwork) ug = quotient_graph(underlying_graph(tn)) From 9ca72759ca76252220b72c954bb931b3b4b6204b Mon Sep 17 00:00:00 2001 From: Jack Dunham <72548217+jack-dunham@users.noreply.github.com> Date: Wed, 22 Apr 2026 15:39:54 +0100 Subject: [PATCH 52/86] Use `map` instead of comprehension when returning messages. Co-authored-by: Matt Fishman --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 9ac3d59..7d07c1e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -4,7 +4,7 @@ using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) -messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] +messages(bp_cache::AbstractGraph, edges) = map(e -> message(bp_cache, e), edges) message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] From 005ccf00e7aa34eb7be5fcc7b5ed7b381aba17a0 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 21 Apr 2026 10:56:07 +0100 Subject: [PATCH 53/86] Test BP with differing precisions; remove `atol` test criteria. 
--- test/test_beliefpropagation.jl | 83 ++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index aaa2031..57c8ffc 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -35,53 +35,56 @@ function spin_ice_tensornetwork(g) end @testset "BeliefPropagation" begin + @testset "$T" for T in (Float32, Float64, ComplexF64, BigFloat) + #Chain of tensors + dims = (2, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - #Chain of tensors - dims = (2, 1) - g = named_grid(dims) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(T, Tuple(is)) + end - bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = scalar(bpc) - z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact atol = 1.0e-14 + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + z_bp = scalar(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact - #Tree of tensors - dims = (4, 3) - g = named_comb_tree(dims) - l = Dict(e => Index(3) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(T, Tuple(is)) + end - bpc = BeliefPropagationCache(tn) - bpc = 
ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = scalar(bpc) - z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact atol = 1.0e-10 + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + z_bp = scalar(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact - #Spin Ice Model (has analytical bp solution given by 1.5^(n^2)) - for n in (3, 4, 5) - dims = (n, n) - g = named_grid(dims; periodic = true) - tn = spin_ice_tensornetwork(g) + #Spin Ice Model (has analytical bp solution given by 1.5^(n^2)) + for n in (3, 4, 5) + dims = (n, n) + g = named_grid(dims; periodic = true) + tn = spin_ice_tensornetwork(g) - bpc = ITensorNetworksNext.BeliefPropagationCache(tn) do edge - # Use `rand` so messages have positive elements. - return rand(Tuple(linkinds(tn, edge))) - end - bpc = ITensorNetworksNext.beliefpropagation(bpc; tol = 1.0e-10, maxiter = 10) + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) do edge + # Use `rand` so messages have positive elements. + return rand(T, Tuple(linkinds(tn, edge))) + end + bpc = + ITensorNetworksNext.beliefpropagation(bpc; tol = 1.0e-10, maxiter = 10) - z_bp = scalar(bpc) + z_bp = scalar(bpc) - @test z_bp ≈ 1.5^(n^2) + @test z_bp ≈ 1.5^(n^2) + end end end From 19d1256a0a105ae77f119df097e8264f2f0ab52c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 22 Apr 2026 15:51:06 +0100 Subject: [PATCH 54/86] Fix `nested_algorithm` methods on iterables. 
--- .../AlgorithmsInterfaceExtensions.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 9f63691..fe749b4 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -151,9 +151,8 @@ end abstract type NestedAlgorithm <: Algorithm end -function nested_algorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(f, iterable; kwargs...) -end +nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...) +nested_algorithm(f::Function, iterable; kwargs...) = DefaultNestedAlgorithm(f, iterable; kwargs...) max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) @@ -206,8 +205,8 @@ from a list of stored algorithms. algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) end # ============================ FlattenedAlgorithm ========================================== From 0674767c102101ad8e67784d9a6debb0315671db Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 22 Apr 2026 17:49:03 +0100 Subject: [PATCH 55/86] Cleanup `AbstractBeliefPropagationCache` interface. 
--- .../abstractbeliefpropagationcache.jl | 79 +++++++++---------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 7d07c1e..662749c 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -2,41 +2,32 @@ using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data using Graphs: AbstractEdge, AbstractGraph using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent +using NamedGraphs: AbstractEdges, AbstractVertices, to_graph_index -messages(bp_cache::AbstractGraph) = edge_data(bp_cache) -messages(bp_cache::AbstractGraph, edges) = map(e -> message(bp_cache, e), edges) +messages(bpc::AbstractDataGraph) = edge_data(bpc) +messages(bpc::AbstractGraph, edges) = map(e -> message(bpc, e), edges) -message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] +message(bpc::AbstractGraph, edge) = messages(bpc)[edge] -deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() -function deletemessage!(bp_cache::AbstractDataGraph, edge) - ms = messages(bp_cache) - delete!(ms, edge) - return bp_cache -end +deletemessage!(bpc::AbstractGraph, edge) = not_implemented() -function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) +function deletemessages!(bpc::AbstractGraph, edges = edges(bpc)) for e in edges - deletemessage!(bp_cache, e) + deletemessage!(bpc, e) end - return bp_cache + return bpc end -setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() -function setmessage!(bp_cache::AbstractDataGraph, edge, message) - setindex!(bp_cache, message, edge) - return bp_cache -end -function setmessage!(bp_cache::QuotientView, edge, message) - setmessages!(parent(bp_cache), QuotientEdge(edge), message) - return bp_cache +# Fallback; assume 
`setindex!` is implemented. +function setmessage!(bpc::AbstractGraph, edge, message) + bpc[edge] = message + return bpc end - -function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) - for e in edges(bp_cache, edge) - setmessage!(parent(bp_cache), e, message[e]) +function setmessages!(bpc::AbstractGraph, messages) + for (key, val) in messages + setmessage!(bpc, key, val) end - return bp_cache + return bpc end function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) for e in edges @@ -45,31 +36,32 @@ function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) return bpc_dst end -factors(bpc::AbstractGraph) = vertex_data(bpc) -factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] -factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) +factors(bpc::AbstractDataGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices) = map(v -> factor(bpc, v), vertices) factor(bpc::AbstractGraph, vertex) = bpc[vertex] -setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() -function setfactor!(bpc::AbstractDataGraph, vertex, factor) - fs = factors(bpc) - setindex!(fs, vertex, factor) +function setfactor!(bpc::AbstractGraph, vertex, factor) + bpc[vertex] = factor return bpc end -function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge; alg = "exact") +# Internal convenience only +_graph_index_scalar(bpc::AbstractGraph, vertex) = vertex_scalar(bpc, vertex) +_graph_index_scalar(bpc::AbstractGraph, edge::AbstractEdge) = edge_scalar(bpc, edge) + +function edge_scalar(bp_cache::AbstractGraph, edge; kwargs...) # Make generic to deal with the possibilty of multiple messages. 
m1s = messages(bp_cache, [edge]) m2s = messages(bp_cache, [reverse(edge)]) - return contract_network(vcat(m1s, m2s); alg)[] + return contract_network(vcat(m1s, m2s); kwargs...)[] end -function region_scalar(bp_cache::AbstractGraph, vertex; alg = "exact") +function vertex_scalar(bp_cache::AbstractGraph, vertex; kwargs...) messages = incoming_messages(bp_cache, vertex) - state = factors(bp_cache, vertex) + state = factors(bp_cache, [vertex]) - return contract_network(vcat(messages, state); alg)[] + return contract_network(vcat(messages, state); kwargs...)[] end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) @@ -77,18 +69,18 @@ message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) - return map(v -> region_scalar(bp_cache, v), vertices) + return map(v -> vertex_scalar(bp_cache, v), vertices) end function edge_scalars( bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache))) ) - return map(e -> region_scalar(bp_cache, e), edges) + return map(e -> edge_scalar(bp_cache, e), edges) end -function scalar_factors_quotient(bp_cache::AbstractGraph) - return vertex_scalars(bp_cache), edge_scalars(bp_cache) +function region_scalar(bpc::AbstractGraph, region) + return mapreduce(ind -> _graph_index_scalar(bpc, ind), *, region) end function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) @@ -127,8 +119,9 @@ factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED -function logscalar(bp_cache::AbstractBeliefPropagationCache) - numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) +function logscalar(bpc::AbstractBeliefPropagationCache) + 
numerator_terms = vertex_scalars(bpc) + denominator_terms = edge_scalars(bpc) if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) From 3720391325a940d613aa32041eee89e5a31d9060 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 27 Apr 2026 13:42:24 -0400 Subject: [PATCH 56/86] Remove `Graphs.connected_components` method for `TensorNetwork` This method was just forwarding to the underlying graph. --- src/tensornetwork.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 6d55a6a..5a0eef6 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -117,10 +117,6 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) -function Graphs.connected_components(tn::TensorNetwork) - return Graphs.connected_components(underlying_graph(tn)) -end - function Graphs.rem_edge!(tn::TensorNetwork, e) if !has_edge(underlying_graph(tn), e) return false From ccdcb743463935e518c5be543ebf7aa20f1cb59f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 27 Apr 2026 13:49:20 -0400 Subject: [PATCH 57/86] Remove unnecessary `symnameddims` method.
--- src/LazyNamedDimsArrays/symbolicnameddimsarray.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index 44bae0a..172ec08 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -5,9 +5,6 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(symname, dims) return lazy(nameddims(SymbolicArray(symname, denamed.(dims)), name.(dims))) end -function symnameddims(name, ndarray::AbstractNamedDimsArray) - return symnameddims(name, Tuple(inds(ndarray))) -end symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) From d59c3e16e6bc4b4d631c9d3d925e75230e668af5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 27 Apr 2026 13:49:26 -0400 Subject: [PATCH 58/86] Remove confusing code comment. --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 662749c..0a3e28c 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -51,7 +51,6 @@ _graph_index_scalar(bpc::AbstractGraph, vertex) = vertex_scalar(bpc, vertex) _graph_index_scalar(bpc::AbstractGraph, edge::AbstractEdge) = edge_scalar(bpc, edge) function edge_scalar(bp_cache::AbstractGraph, edge; kwargs...) - # Make generic to deal with the possibilty of multiple messages. 
m1s = messages(bp_cache, [edge]) m2s = messages(bp_cache, [reverse(edge)]) return contract_network(vcat(m1s, m2s); kwargs...)[] From 9a2a88ea0704a0aa33dbaef171e947505bd4e52f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 27 Apr 2026 13:53:16 -0400 Subject: [PATCH 59/86] Remove `beliefpropagation_sweep` in favour of constructor call. --- src/beliefpropagation/beliefpropagationproblem.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a3443e8..ca5e1e7 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -218,13 +218,8 @@ function select_algorithm( edge_kwargs = rows(extended_kwargs, maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum - return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) - end -end - -# A single sweep across the given edges. -function beliefpropagation_sweep(::BeliefPropagationCache; edges, kwargs...) - return BeliefPropagationSweep(edges) do edge - return SimpleMessageUpdate(edge; kwargs...) + return BeliefPropagationSweep(edges) do edge + return SimpleMessageUpdate(edge; edge_kwargs[repnum]...) + end end end From 2ae7100775f0c7a0e1994655b55c44b7f8b9e043 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 09:17:29 -0400 Subject: [PATCH 60/86] Fix message type initialization failing when only factors are provided. 
--- src/beliefpropagation/beliefpropagationcache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index c9a6991..0971303 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -67,7 +67,7 @@ function BeliefPropagationCache(callable::Base.Callable, network::AbstractGraph) end function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) - MT = vertex_data_type(typeof(graph)) + MT = eltype(factors) return BeliefPropagationCache(MT, graph, factors) end From bf4d0fe0dde1a999c9f3b04d969cea2a2a596c54 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 10:06:06 -0400 Subject: [PATCH 61/86] Formatting. --- .../AlgorithmsInterfaceExtensions.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index fe749b4..f042dc0 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -152,7 +152,9 @@ end abstract type NestedAlgorithm <: Algorithm end nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...) -nested_algorithm(f::Function, iterable; kwargs...) = DefaultNestedAlgorithm(f, iterable; kwargs...) +function nested_algorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(f, iterable; kwargs...) 
+end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) From d33a58bd3ccd7d7d58ac7ec6b6a168e681d901c1 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 10:24:34 -0400 Subject: [PATCH 62/86] Remove `edge_data_type` method for `AbstractTensorNetwork` An `AbstractTensorNetwork` has edge type `Nothing`, which can be obtained from the `AbstractDataGraph` method. --- src/abstracttensornetwork.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index bed2ac7..0cb997f 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -17,8 +17,6 @@ abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} # Need to be careful about removing edges from tensor networks in case there is a bond Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -DataGraphs.edge_data_type(::Type{<:AbstractTensorNetwork}) = not_implemented() - # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) V = vertextype(graph) From e5619be92af30ca4736e30d0ea3fef991024b1fc Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 10:24:44 -0400 Subject: [PATCH 63/86] Add some tests for `TensorNetwork` type. 
--- test/Project.toml | 2 ++ test/test_tensornetwork.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 test/test_tensornetwork.jl diff --git a/test/Project.toml b/test/Project.toml index ee4dbd0..4bcd159 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" +DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -27,6 +28,7 @@ path = ".." AbstractTrees = "0.4.5" AlgorithmsInterface = "0.1" Aqua = "0.8.14" +DataGraphs = "0.4" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl new file mode 100644 index 0000000..7c5b818 --- /dev/null +++ b/test/test_tensornetwork.jl @@ -0,0 +1,31 @@ +using DataGraphs: assigned_edge_data, assigned_vertex_data +using Graphs: dst, edges, has_edge, ne, nv, src, vertices +using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs: similar_graph +using Test: @test, @testset + +@testset "`TensorNetwork`" begin + @testset "DataGraphs/NamedGraphs interface" begin + dims = (3, 3) + g = named_grid(dims) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + stn = similar_graph(tn) + @test stn isa TensorNetwork + @test vertices(stn) == vertices(tn) + @test edges(stn) == edges(tn) + @test isempty(assigned_vertex_data(stn)) + @test isempty(assigned_edge_data(stn)) + + stn = similar_graph(tn, vertices(tn)) + @test vertices(stn) == vertices(tn) + @test ne(stn) == 0 + @test isempty(assigned_vertex_data(stn)) 
+ @test isempty(assigned_edge_data(stn)) + end +end From 19588cee5b43070f916748f95ac2423f9a2cb510 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 10:47:38 -0400 Subject: [PATCH 64/86] Bug fixes; more tests --- src/tensornetwork.jl | 4 ++-- test/test_tensornetwork.jl | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 5a0eef6..80f81a0 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -110,7 +110,7 @@ TensorNetwork(tn::TensorNetwork) = copy(tn) TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn) function TensorNetwork{V}(tn::TensorNetwork) where {V} g = convert_vertextype(V, underlying_graph(tn)) - d = dictionary(V(k) => tn[k] for k in keys(d)) + d = dictionary(V(k) => tn[k] for k in vertices(tn)) return TensorNetwork(g, d) end @@ -134,7 +134,7 @@ end function NamedGraphs.similar_graph( type::Type{<:TensorNetwork}, - vertices + vertices = vertextype(type)[] ) DT = fieldtype(type, :tensors) empty_dict = DT() diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl index 7c5b818..f2d0666 100644 --- a/test/test_tensornetwork.jl +++ b/test/test_tensornetwork.jl @@ -1,9 +1,10 @@ -using DataGraphs: assigned_edge_data, assigned_vertex_data +using DataGraphs: assigned_edge_data, assigned_vertex_data, vertex_data using Graphs: dst, edges, has_edge, ne, nv, src, vertices using ITensorBase: Index using ITensorNetworksNext: TensorNetwork +using NamedGraphs.GraphsExtensions: vertextype using NamedGraphs.NamedGraphGenerators: named_grid -using NamedGraphs: similar_graph +using NamedGraphs: convert_vertextype, similar_graph using Test: @test, @testset @testset "`TensorNetwork`" begin @@ -27,5 +28,19 @@ using Test: @test, @testset @test ne(stn) == 0 @test isempty(assigned_vertex_data(stn)) @test isempty(assigned_edge_data(stn)) + + stn = similar_graph(typeof(tn)) + @test nv(stn) == 0 + @test stn isa typeof(tn) + + stn = similar_graph(typeof(tn), 
vertices(tn)) + @test nv(stn) == nv(tn) + @test ne(stn) == 0 + @test stn isa typeof(tn) + + ctn = convert_vertextype(Tuple{Float64, Float64}, tn) + @test ctn isa TensorNetwork + @test vertextype(ctn) == Tuple{Float64, Float64} + @test collect(vertex_data(ctn)) == collect(vertex_data(tn)) end end From b520afd7660a39b54a92788465fc30a0dc121bfb Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 16:52:38 -0400 Subject: [PATCH 65/86] Using `Inf` instead of `NaN` for delta initialization in `StopWhenConvergedState`. --- src/beliefpropagation/beliefpropagationproblem.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index ca5e1e7..e4a1a00 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -9,12 +9,12 @@ using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices @kwdef struct StopWhenConverged{Tol <: Real} <: AI.StoppingCriterion - tol::Tol = NaN + tol::Tol = 0.0 end @kwdef mutable struct StopWhenConvergedState{Iterate, Delta <: Real} <: AI.StoppingCriterionState - delta::Delta = NaN + delta::Delta = Inf at_iteration::Int = -1 previous_iterate::Iterate end @@ -29,7 +29,7 @@ function AI.initialize_state!( ::StopWhenConverged, st::StopWhenConvergedState ) - st.delta = NaN + st.delta = Inf return st end @@ -201,7 +201,7 @@ function select_algorithm( cache::AbstractBeliefPropagationCache; edges = forest_cover_edge_sequence(cache), maxiter = is_tree(cache) ? 1 : nothing, - tol = NaN, + tol = nothing, kwargs... 
) if isnothing(maxiter) @@ -210,7 +210,7 @@ function select_algorithm( stopping_criterion = AI.StopAfterIteration(maxiter) - if !isnan(tol) + if !isnothing(tol) stopping_criterion = stopping_criterion | StopWhenConverged(tol) end From 397733a84f3f0246a253be5d5d0e3470c990e2c5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Apr 2026 16:53:27 -0400 Subject: [PATCH 66/86] Add some basic tests for `PartitionedGraphs` interactions with `TensorNetwork`. --- test/test_tensornetwork.jl | 109 ++++++++++++++++++++++++++++++++++++- 1 file changed, 107 insertions(+), 2 deletions(-) diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl index f2d0666..aaf96e8 100644 --- a/test/test_tensornetwork.jl +++ b/test/test_tensornetwork.jl @@ -1,11 +1,15 @@ -using DataGraphs: assigned_edge_data, assigned_vertex_data, vertex_data +using DataGraphs: assigned_edge_data, assigned_vertex_data, underlying_graph, vertex_data using Graphs: dst, edges, has_edge, ne, nv, src, vertices using ITensorBase: Index +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray using ITensorNetworksNext: TensorNetwork using NamedGraphs.GraphsExtensions: vertextype using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, QuotientVertex, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, + quotientvertices using NamedGraphs: convert_vertextype, similar_graph -using Test: @test, @testset +using Test: @test, @test_throws, @testset @testset "`TensorNetwork`" begin @testset "DataGraphs/NamedGraphs interface" begin @@ -43,4 +47,105 @@ using Test: @test, @testset @test vertextype(ctn) == Tuple{Float64, Float64} @test collect(vertex_data(ctn)) == collect(vertex_data(tn)) end + + @testset "`PartitionedGraphs`" begin + dims = (3, 3) + g = named_grid(dims) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + # Row partition: each partition 
is one row of the grid. + row_parts = [[(i, j) for i in 1:dims[1]] for j in 1:dims[2]] + + @testset "default `partitioned_vertices`" begin + # By default the entire underlying graph is one partition. + pvs = partitioned_vertices(tn) + @test length(pvs) == 1 + @test issetequal(only(pvs), vertices(tn)) + end + + @testset "default `quotientvertices`" begin + qvs = collect(quotientvertices(tn)) + @test length(qvs) == 1 + @test only(qvs) isa QuotientVertex + end + + @testset "`tn[QuotientVertex(...)]` (default)" begin + qv = only(collect(quotientvertices(tn))) + data = tn[qv] + @test data isa LazyNamedDimsArray + end + + @testset "`quotient_graph` (default partitioning)" begin + qtn = quotient_graph(tn) + @test qtn isa TensorNetwork + @test nv(qtn) == 1 + @test ne(qtn) == 0 + v = only(collect(vertices(qtn))) + @test qtn[v] isa LazyNamedDimsArray + end + + @testset "`quotient_graph_type`" begin + QT = quotient_graph_type(typeof(tn)) + @test QT <: TensorNetwork + qtn = quotient_graph(tn) + @test vertextype(qtn) === vertextype(QT) + end + + @testset "`partitionedgraph(tn, parts)`" begin + ptn = partitionedgraph(tn, row_parts) + @test ptn isa TensorNetwork + # The set of underlying vertices/edges is preserved. + @test issetequal(vertices(ptn), vertices(tn)) + @test issetequal(edges(ptn), edges(tn)) + @test nv(ptn) == nv(tn) + @test ne(ptn) == ne(tn) + # Vertex data is copied, not aliased. 
+ @test collect(vertex_data(ptn)) == collect(vertex_data(tn)) + @test vertex_data(ptn) !== vertex_data(tn) + end + + @testset "`partitioned_vertices` of partitioned tn" begin + ptn = partitionedgraph(tn, row_parts) + pvs = partitioned_vertices(ptn) + @test length(pvs) == dims[2] + for part in pvs + @test length(part) == dims[1] + end + @test issetequal(reduce(vcat, pvs), vertices(tn)) + end + + @testset "`tn[QuotientVertex(...)]` (partitioned)" begin + ptn = partitionedgraph(tn, row_parts) + for qv in quotientvertices(ptn) + @test ptn[qv] isa LazyNamedDimsArray + end + end + + @testset "`quotient_graph` of partitioned tn" begin + ptn = partitionedgraph(tn, row_parts) + qtn = quotient_graph(ptn) + @test qtn isa TensorNetwork + @test nv(qtn) == dims[2] + # The row-partitioned grid quotients to a path graph of length `dims[2]`. + @test ne(qtn) == dims[2] - 1 + for v in vertices(qtn) + @test qtn[v] isa LazyNamedDimsArray + end + end + + @testset "`departition`" begin + # `departition` on a non-partitioned tn returns itself. + @test departition(tn) === tn + + # `departition` on a partitioned tn unwraps one layer of partitioning. + ptn = partitionedgraph(tn, row_parts) + dtn = departition(ptn) + @test dtn isa TensorNetwork + @test issetequal(vertices(dtn), vertices(tn)) + @test issetequal(edges(dtn), edges(tn)) + end + end end From 44f063af2d7dafc826dd6817a39f5a76b17e0e7c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 10:13:04 -0400 Subject: [PATCH 67/86] Add tests via Claude. 
--- test/test_new.jl | 742 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 742 insertions(+) create mode 100644 test/test_new.jl diff --git a/test/test_new.jl b/test/test_new.jl new file mode 100644 index 0000000..3acc422 --- /dev/null +++ b/test/test_new.jl @@ -0,0 +1,742 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using AbstractTrees: AbstractTrees +using BackendSelection: @Algorithm_str, Algorithm +using DataGraphs: vertex_data +using Dictionaries: Dictionary +using Graphs: Graphs, AbstractEdge, dst, edges, has_edge, ne, nv, src, vertices +using ITensorBase: ITensor, Index +using ITensorNetworksNext: BeliefPropagationCache, EigsolveRegion, ITensorNetworksNext, + TensorNetwork, contract_network, dmrg, factor, factor_type, factors, linkinds, message, + message_type, messages, scalar +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, LazyNamedDimsArrays, Mul, + SymbolicArray, ismul, lazy, parenttype, substitute, symnameddims +using ITensorNetworksNext.TensorNetworkGenerators: ising_network +using NamedDimsArrays: AbstractNamedDimsArray, NamedDimsArray, denamed, dimnames, inds, + nameddims, namedoneto +using NamedGraphs: NamedGraphs +using NamedGraphs.GraphsExtensions: GraphsExtensions, incident_edges +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree +using TermInterface: arguments, head, iscall, isexpr, operation +using Test: @test, @test_throws, @testset +using WrappedUnions: unwrap + +# Type definitions used by some tests below; must be at file scope. 
+struct _DummyNonIter <: AIE.NonIterativeAlgorithm end +struct _DummyProblem <: AIE.Problem end + +@testset "test_new.jl" begin + # --------------------------------------------------------------------------- + # AbstractTensorNetwork: iteration / keys / eltype / is_directed / show + # --------------------------------------------------------------------------- + @testset "AbstractTensorNetwork interface" begin + g = named_grid((2, 2)) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + # `iterate` works (delegates to `vertex_data`). + @test !isempty(collect(tn)) + # `keys` returns vertices. + @test issetequal(keys(tn), vertices(tn)) + # `eltype` matches the eltype of the vertex data. + @test eltype(tn) === eltype(vertex_data(tn)) + # `is_directed` is `false` for AbstractTensorNetwork. + @test !Graphs.is_directed(typeof(tn)) + + # `show` MIME and default both succeed and mention vertices/edges. + s_plain = sprint(show, MIME"text/plain"(), tn) + @test occursin("vertices", s_plain) + @test occursin("edge", s_plain) + s_default = sprint(show, tn) + @test occursin("vertices", s_default) + + # `setindex!` for edges is unimplemented. 
+ e = first(edges(tn)) + @test_throws ErrorException tn[e] = randn(2, 2) + @test_throws ErrorException tn[src(e) => dst(e)] = randn(2, 2) + end + + # --------------------------------------------------------------------------- + # `linkaxes` / `linknames` on a TensorNetwork + # --------------------------------------------------------------------------- + @testset "linkaxes / linknames" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + e = first(edges(tn)) + p = src(e) => dst(e) + + li = linkinds(tn, e) + la_e = ITensorNetworksNext.linkaxes(tn, e) + la_p = ITensorNetworksNext.linkaxes(tn, p) + @test la_e == la_p + @test length(la_e) == length(li) + + ln_e = ITensorNetworksNext.linknames(tn, e) + ln_p = ITensorNetworksNext.linknames(tn, p) + @test ln_e == ln_p + @test length(ln_e) == length(li) + end + + # --------------------------------------------------------------------------- + # expression-shape predicates + # --------------------------------------------------------------------------- + @testset "is_setindex!_expr / is_assignment_expr / is_getindex_expr" begin + @test ITensorNetworksNext.is_setindex!_expr(:(a[1] = 2)) + @test !ITensorNetworksNext.is_setindex!_expr(:(a[1])) + @test !ITensorNetworksNext.is_setindex!_expr(:(a + b)) + @test !ITensorNetworksNext.is_setindex!_expr(42) + + @test ITensorNetworksNext.is_assignment_expr(:(x = 1)) + @test !ITensorNetworksNext.is_assignment_expr(:(x + 1)) + @test !ITensorNetworksNext.is_assignment_expr(42) + + @test ITensorNetworksNext.is_getindex_expr(:(a[1])) + @test !ITensorNetworksNext.is_getindex_expr(:(a + 1)) + @test !ITensorNetworksNext.is_getindex_expr(42) + end + + # --------------------------------------------------------------------------- + # `add_missing_edges!`: no-op on a well-formed network. 
+ # --------------------------------------------------------------------------- + @testset "add_missing_edges!" begin + g = named_grid((2, 2)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + es_before = collect(edges(tn)) + ITensorNetworksNext.add_missing_edges!(tn) + @test issetequal(edges(tn), es_before) + + v = first(vertices(tn)) + ITensorNetworksNext.add_missing_edges!(tn, v) + @test issetequal(edges(tn), es_before) + end + + # --------------------------------------------------------------------------- + # `TensorNetwork` constructor / copy / convert variants and `rem_edge!` + # --------------------------------------------------------------------------- + @testset "TensorNetwork copy / convert / rem_edge!" begin + g = named_grid((3,)) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + # `TensorNetwork(tensors)` infers the graph from shared indices. + link = Index(2) + A = randn(s[(1,)], link) + B = randn(s[(2,)], link) + tensors = Dictionary([(1,), (2,)], [A, B]) + tn_inferred = TensorNetwork(tensors) + @test tn_inferred isa TensorNetwork + @test issetequal(vertices(tn_inferred), [(1,), (2,)]) + @test ne(tn_inferred) == 1 + + # `copy` produces an independent TensorNetwork. + tn2 = copy(tn) + @test tn2 isa TensorNetwork + @test issetequal(vertices(tn2), vertices(tn)) + @test issetequal(edges(tn2), edges(tn)) + @test vertex_data(tn2) !== vertex_data(tn) + + # `TensorNetwork(tn)` and `TensorNetwork{V}(tn)` (same V) call `copy`. + tn3 = TensorNetwork(tn) + @test tn3 isa TensorNetwork + @test issetequal(vertices(tn3), vertices(tn)) + + V = GraphsExtensions.vertextype(tn) + tn4 = TensorNetwork{V}(tn) + @test tn4 isa TensorNetwork + @test issetequal(vertices(tn4), vertices(tn)) + + # `TensorNetwork{V}(tn)` with a different V re-keys vertices. 
+ tn5 = TensorNetwork{Tuple{Float64}}(tn) + @test tn5 isa TensorNetwork + @test all(v -> v isa Tuple{Float64}, vertices(tn5)) + + # `rem_edge!` returns false for an absent edge. + bad_edge = (1,) => (3,) + @test !Graphs.rem_edge!(tn, bad_edge) + + # `rem_edge!` on an edge with shared inds throws. + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn_link = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + e = first(edges(tn_link)) + @test_throws ArgumentError Graphs.rem_edge!(tn_link, e) + end + + # --------------------------------------------------------------------------- + # `induced_subgraph_from_vertices` for TensorNetwork + # --------------------------------------------------------------------------- + @testset "TensorNetwork induced_subgraph_from_vertices" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + sub_vs = [(1,), (2,)] + subtn, _ = NamedGraphs.induced_subgraph_from_vertices(tn, sub_vs) + @test subtn isa TensorNetwork + @test issetequal(vertices(subtn), sub_vs) + end + + # --------------------------------------------------------------------------- + # `BeliefPropagationCache` constructor variants and message/factor mutators + # --------------------------------------------------------------------------- + @testset "BeliefPropagationCache constructors and mutators" begin + g = named_grid((2, 2)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + # `BeliefPropagationCache(network)` (no callable; cache constructed). 
+ bpc1 = BeliefPropagationCache(tn) + @test bpc1 isa BeliefPropagationCache + @test length(factors(bpc1)) == nv(tn) + + # `BeliefPropagationCache(callable, network)` + bpc2 = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + @test length(messages(bpc2)) == 2 * length(edges(g)) + + # `copy` is independent of the source. + bpc_copy = copy(bpc2) + @test bpc_copy isa BeliefPropagationCache + @test length(messages(bpc_copy)) == length(messages(bpc2)) + + # `setmessage!` and `setfactor!` write through the cache. + e = first(edges(bpc2)) + new_msg = ones(Tuple(linkinds(tn, e))) .* 2.0 + ITensorNetworksNext.setmessage!(bpc2, e, new_msg) + @test message(bpc2, e) == new_msg + + v = first(vertices(bpc2)) + old_factor = factor(bpc2, v) + new_factor = old_factor .* 2 + ITensorNetworksNext.setfactor!(bpc2, v, new_factor) + @test factor(bpc2, v) == new_factor + + # `setmessages!` accepts a mapping and updates entries. + e2 = first(edges(bpc2)) + msg2 = ones(Tuple(linkinds(tn, e2))) .* 3.0 + ITensorNetworksNext.setmessages!(bpc2, Dict(e2 => msg2)) + @test message(bpc2, e2) == msg2 + + # `setmessages!(dst, src, edges)` copies messages between caches. + bpc_dst = BeliefPropagationCache(tn) do edge + return zeros(Tuple(linkinds(tn, edge))) + end + e3 = first(edges(bpc2)) + ITensorNetworksNext.setmessages!(bpc_dst, bpc2, [e3]) + @test message(bpc_dst, e3) == message(bpc2, e3) + end + + # --------------------------------------------------------------------------- + # AbstractBeliefPropagationCache helpers: vertex/edge/region scalars, + # incoming_messages, map_messages/map_factors, factor_type, message_type. 
+ # --------------------------------------------------------------------------- + @testset "BeliefPropagationCache scalars / mappers" begin + g = named_grid((2,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + + v = first(vertices(bpc)) + e = first(edges(bpc)) + + # Vertex/edge/region scalars. + vs = ITensorNetworksNext.vertex_scalar(bpc, v) + es = ITensorNetworksNext.edge_scalar(bpc, e) + @test vs isa Number + @test es isa Number + + rs = ITensorNetworksNext.region_scalar(bpc, [v, e]) + @test rs ≈ vs * es + + # `incoming_messages` excludes specified edges. + in_msgs = ITensorNetworksNext.incoming_messages(bpc, v) + in_msgs_filtered = ITensorNetworksNext.incoming_messages( + bpc, v; ignore_edges = [reverse(e)] + ) + @test length(in_msgs_filtered) <= length(in_msgs) + + # `factor_type` / `message_type` resolve to concrete types. + @test factor_type(bpc) isa Type + @test message_type(bpc) isa Type + + # `map_messages` and `map_factors` produce independent caches. + bpc_doubled = ITensorNetworksNext.map_messages(m -> 2 .* m, bpc) + @test message(bpc_doubled, e) ≈ 2 .* message(bpc, e) + + bpc_scaled = ITensorNetworksNext.map_factors(f -> f .* 2, bpc) + for vv in vertices(bpc_scaled) + @test factor(bpc_scaled, vv) ≈ factor(bpc, vv) .* 2 + end + + # `adapt_factors` and `adapt_messages` should at least be callable. + @test ITensorNetworksNext.adapt_factors(identity, bpc) isa BeliefPropagationCache + @test ITensorNetworksNext.adapt_messages(identity, bpc) isa BeliefPropagationCache + end + + # --------------------------------------------------------------------------- + # `logscalar` branches: complex-promotion path and zero denominator. 
+ # --------------------------------------------------------------------------- + @testset "logscalar special branches" begin + g = named_grid((2,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + + # Negate one factor so the numerator product becomes negative, + # forcing a complex promotion in `logscalar`. + v = first(vertices(bpc)) + ITensorNetworksNext.setfactor!(bpc, v, -1 .* factor(bpc, v)) + @test ITensorNetworksNext.logscalar(bpc) isa Number + + # Zero out a message so a denominator term becomes zero -> -Inf. + bpc_zero = BeliefPropagationCache(tn) do edge + return zeros(Tuple(linkinds(tn, edge))) + end + @test ITensorNetworksNext.logscalar(bpc_zero) == -Inf + end + + # --------------------------------------------------------------------------- + # `induced_subgraph_bpcache` / induced_subgraph_from_vertices on a BPCache. + # --------------------------------------------------------------------------- + @testset "BeliefPropagationCache induced_subgraph" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + + sub_vs = [(1,), (2,)] + subbpc = subgraph(bpc, sub_vs) + @test subbpc isa BeliefPropagationCache + @test issetequal(vertices(subbpc), sub_vs) + @test has_edge(subbpc, (1,) => (2,)) + end + + # --------------------------------------------------------------------------- + # `forest_cover_edge_sequence` returns a sequence covering a tree. 
+ # --------------------------------------------------------------------------- + @testset "forest_cover_edge_sequence" begin + g = named_comb_tree((3, 2)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + + seq = ITensorNetworksNext.forest_cover_edge_sequence(bpc) + @test eltype(seq) <: AbstractEdge + @test !isempty(seq) + end + + # --------------------------------------------------------------------------- + # Belief propagation: `select_algorithm` errors when `maxiter` is required. + # --------------------------------------------------------------------------- + @testset "beliefpropagation select_algorithm error" begin + # 2x2 grid: not a tree, so `maxiter` cannot be defaulted. + g = named_grid((2, 2)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + @test_throws ArgumentError ITensorNetworksNext.select_algorithm( + ITensorNetworksNext.beliefpropagation, bpc; maxiter = nothing + ) + end + + # --------------------------------------------------------------------------- + # `iterate_diff` and `SimpleMessageUpdate.getproperty(:kwargs)` path. 
+ # --------------------------------------------------------------------------- + @testset "iterate_diff and SimpleMessageUpdate kwargs" begin + g = named_grid((2,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc1 = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + bpc2 = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + + # Identical caches: diff should be ~0. + @test ITensorNetworksNext.iterate_diff(bpc1, bpc2) ≈ 0 atol = 1.0e-10 + + # `SimpleMessageUpdate.getproperty(:kwargs)` returns the NamedTuple. + edge = first(edges(bpc1)) + upd = ITensorNetworksNext.SimpleMessageUpdate(edge; normalize = false) + @test upd.kwargs isa NamedTuple + # Forwarded properties still work (`getfield(:kwargs)` then property). + @test upd.normalize == false + end + + # --------------------------------------------------------------------------- + # `contract_network`: unknown algorithm error and `left_associative` order. + # --------------------------------------------------------------------------- + @testset "contract_network error / left_associative" begin + g = named_grid((2,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + @test_throws ArgumentError contract_network(tn; alg = Algorithm"unknown_alg"()) + + # `contraction_order` for `left_associative` algorithm. + order = ITensorNetworksNext.contraction_order(tn; alg = Algorithm"left_associative"()) + @test order isa LazyNamedDimsArray + end + + # --------------------------------------------------------------------------- + # `dmrg`: thin wrappers and unimplemented `EigsolveRegion` step. 
+ # --------------------------------------------------------------------------- + @testset "dmrg wrappers" begin + operator = "operator" + init = "init" + nsweeps = 2 + regions = ["region1"] + algorithm = ITensorNetworksNext.select_algorithm( + dmrg, operator, init; nsweeps, regions, maxdim = 10 + ) + + # `dmrg(operator, algorithm, state)` errors deep in `EigsolveRegion`'s solve!. + @test_throws Exception dmrg(operator, algorithm, init) + + # `dmrg(operator, state; ...)` builds the algorithm internally; same expected error. + @test_throws Exception dmrg(operator, init; nsweeps, regions, maxdim = 10) + + # The `EigsolveRegion`-specific `solve!` errors directly. + region = EigsolveRegion("region"; maxdim = 10) + problem = ITensorNetworksNext.EigenProblem(operator) + state = AI.initialize_state(problem, region; iterate = init) + @test_throws ErrorException AI.solve!(problem, region, state) + end + + # --------------------------------------------------------------------------- + # `LazyNamedDimsArrays`: error paths in lazy interface. + # --------------------------------------------------------------------------- + @testset "LazyNamedDimsArrays interface error paths" begin + a = nameddims(randn(2, 2), (:i, :j)) + la = lazy(a) + + # `getindex_lazy` errors on expressions, but works on a leaf. + @test la[1, 1] == a[1, 1] + expr = la * lazy(nameddims(randn(2, 2), (:j, :k))) + @test_throws ErrorException expr[1, 1] + + # `denamed` works on a leaf, errors on non-leaf. + @test denamed(la) == denamed(a) + @test_throws ErrorException LazyNamedDimsArrays.denamed_lazy(expr) + + # `dimnames` and `inds` on a `Mul`. + @test issetequal(dimnames(expr), [:i, :k]) + @test length(inds(expr)) == 2 + + # Equality and hash branches. + la2 = lazy(a) + @test la == la2 + @test isequal(la, la2) + @test !(la == expr) # leaf vs expression + @test !isequal(la, expr) # leaf vs expression + @test hash(la) == hash(la2) + + # `mul_lazy(a)` on a leaf wraps it in a `Mul`. 
+ wrapped = *(la) + @test ismul(wrapped) + @test arguments(wrapped) == [la] + + # `mul_lazy(a)` on a Mul returns it unchanged. + @test *(expr) == expr + + # `mul_lazy(a, b; flatten=true)` flattens the arguments. + expr3 = lazy(nameddims(randn(2, 2), (:k, :l))) + flat = LazyNamedDimsArrays.mul_lazy(expr, expr3; flatten = true) + @test ismul(flat) + @test length(arguments(flat)) == 3 + + # Number * Number short-circuit. + @test LazyNamedDimsArrays.mul_lazy(2, 3) == 6 + + # Unsupported ops error. + @test_throws ErrorException la + la2 + @test_throws ErrorException la - la2 + @test_throws ErrorException -la + @test_throws ErrorException la / 2 + @test_throws ErrorException 2 * la + @test_throws ErrorException la * 2 + + # `maketerm` for non-`*` head errors. + @test_throws ErrorException LazyNamedDimsArrays.maketerm_lazy( + LazyNamedDimsArray, +, [la, la2], nothing + ) + + # `parenttype` resolution. + @test parenttype(LazyNamedDimsArray) === AbstractNamedDimsArray + @test parenttype(LazyNamedDimsArray{Float64}) === AbstractNamedDimsArray{Float64} + @test parenttype(typeof(la)) === typeof(a) + end + + # --------------------------------------------------------------------------- + # Lazy broadcasting (linear ops only; arbitrary ops error). + # --------------------------------------------------------------------------- + @testset "Lazy broadcasting" begin + a = nameddims(randn(2, 2), (:i, :j)) + la, la2 = lazy(a), lazy(a) + style = LazyNamedDimsArrays.LazyNamedDimsArrayStyle() + + # Broadcasted linear ops route through `+, -, *, /, unary -`, + # all of which themselves error in the lazy framework. 
+ @test_throws ErrorException Base.Broadcast.broadcasted(style, +, la, la2) + @test_throws ErrorException Base.Broadcast.broadcasted(style, -, la, la2) + @test_throws ErrorException Base.Broadcast.broadcasted(style, *, 2.0, la) + @test_throws ErrorException Base.Broadcast.broadcasted(style, *, la, 2.0) + @test Base.Broadcast.broadcasted(style, *, 2.0, 3.0) == 6.0 + @test_throws ErrorException Base.Broadcast.broadcasted(style, /, la, 2.0) + @test_throws ErrorException Base.Broadcast.broadcasted(style, -, la) + + # Arbitrary functions error explicitly. + @test_throws ErrorException Base.Broadcast.broadcasted(style, sin, la) + end + + # --------------------------------------------------------------------------- + # `SymbolicArray`: getindex/setindex! errors, permutedims, show, printnode. + # --------------------------------------------------------------------------- + @testset "SymbolicArray operations" begin + sa = SymbolicArray(:x, (Base.OneTo(2), Base.OneTo(3))) + @test size(sa) == (2, 3) + + # Indexing errors. + @test_throws ErrorException sa[1, 1] + @test_throws ErrorException (sa[1, 1] = 0) + + # `permutedims`. + pa = permutedims(sa, (2, 1)) + @test size(pa) == (3, 2) + + # `show` writes the symbolic name. + s_plain = sprint(show, MIME"text/plain"(), sa) + @test occursin("x", s_plain) + s_default = sprint(show, sa) + @test occursin("SymbolicArray", s_default) + + # `printnode` writes the symbolic name. + s_node = sprint(AbstractTrees.printnode, sa) + @test occursin("x", s_node) + end + + # --------------------------------------------------------------------------- + # `SymbolicNamedDimsArray`: equality and printnode with non-zero ndims. 
+ # --------------------------------------------------------------------------- + @testset "SymbolicNamedDimsArray equality / printnode" begin + i, j = namedoneto.(2, (:i, :j)) + sa = symnameddims(:a, (i, j)) + sa2 = symnameddims(:a, (i, j)) + sa_perm = symnameddims(:a, (j, i)) + sa_other = symnameddims(:b, (i, j)) + + # Equality: same name + same dimnames (any order) -> equal. + @test unwrap(sa) == unwrap(sa2) + @test unwrap(sa) == unwrap(sa_perm) + @test unwrap(sa) != unwrap(sa_other) + + # `printnode` on a non-scalar prints both name and dims. + s_node = sprint(AbstractTrees.printnode, unwrap(sa)) + @test occursin("a", s_node) + @test occursin("[", s_node) + end + + # --------------------------------------------------------------------------- + # `evaluation_time_complexity` / `flatten_expression` / `optimize_evaluation_order` + # --------------------------------------------------------------------------- + @testset "LazyNamedDimsArrays evaluation_order" begin + a = nameddims(randn(3, 3), (:i, :j)) + b = nameddims(randn(3, 3), (:j, :k)) + la, lb = lazy.((a, b)) + expr = la * lb + + # Time complexity for a known mul. + @test LazyNamedDimsArrays.evaluation_time_complexity(expr) > 0 + + # Flatten of a `Mul` of `Mul`s. + c = nameddims(randn(3, 3), (:k, :i)) + lc = lazy(c) + nested = (la * lb) * lc + flat = LazyNamedDimsArrays.flatten_expression(nested) + @test ismul(flat) + @test length(arguments(flat)) == 3 + + # `flatten_expression` is identity on leaves. + @test LazyNamedDimsArrays.flatten_expression(la) === la + + # `optimize_evaluation_order` on a leaf is identity. + @test LazyNamedDimsArrays.optimize_evaluation_order(la) === la + + # `optimize_contraction_order` with eager picks an ordering. + eager = Algorithm"eager"() + flat_expr = LazyNamedDimsArrays.flatten_expression((la * lb) * lc) + @test LazyNamedDimsArrays.optimize_evaluation_order(eager, flat_expr) isa + LazyNamedDimsArray + + # Time-complexity for scalar*tensor and tensor*scalar. 
+ n = nameddims(randn(3, 3), (:i, :j)) + @test LazyNamedDimsArrays.time_complexity(*, 2.0, n) > 0 + @test LazyNamedDimsArrays.time_complexity(*, n, 2.0) > 0 + + # Time complexity for elementwise +. + n2 = nameddims(randn(3, 3), (:i, :j)) + @test LazyNamedDimsArrays.time_complexity(+, n, n2) > 0 + end + + # --------------------------------------------------------------------------- + # `nameddimsarraysextensions._hash` fallback for non-NamedDimsArray. + # --------------------------------------------------------------------------- + @testset "_hash fallback" begin + @test LazyNamedDimsArrays._hash(42, UInt64(0)) == hash(42, UInt64(0)) + @test LazyNamedDimsArrays._hash("x", UInt64(0)) == hash("x", UInt64(0)) + end + + # --------------------------------------------------------------------------- + # `generic_map` for arrays / dicts / sets. + # --------------------------------------------------------------------------- + @testset "generic_map" begin + @test LazyNamedDimsArrays.generic_map(x -> x + 1, [1, 2, 3]) == [2, 3, 4] + + d = Dict(:a => 1, :b => 2) + md = LazyNamedDimsArrays.generic_map(x -> x * 10, d) + @test md isa Dict + @test md[:a] == 10 + @test md[:b] == 20 + + ms = LazyNamedDimsArrays.generic_map(x -> x * 2, Set([1, 2, 3])) + @test ms == Set([2, 4, 6]) + end + + # --------------------------------------------------------------------------- + # `Mul` core hooks. + # --------------------------------------------------------------------------- + @testset "Mul / Applied basics" begin + a = lazy(nameddims(randn(2, 2), (:i, :j))) + b = lazy(nameddims(randn(2, 2), (:j, :i))) + m = Mul([a, b]) + + @test arguments(m) == [a, b] + @test operation(m) ≡ * + @test iscall(m) + @test isexpr(m) + @test head(m) ≡ * + + # `show` for an `Applied` writes parens-joined arguments. + @test occursin("*", sprint(show, m)) + + # Hashing of equal `Mul`s. 
+ m2 = Mul([a, b]) + @test hash(m) == hash(m2) + end + + # --------------------------------------------------------------------------- + # AlgorithmsInterfaceExtensions: `NonIterativeAlgorithm` fallback `solve!`. + # --------------------------------------------------------------------------- + @testset "NonIterativeAlgorithm fallback solve!" begin + problem = _DummyProblem() + algorithm = _DummyNonIter() + state = AI.initialize_state(problem, algorithm; iterate = [0.0]) + @test_throws Exception AI.solve!(problem, algorithm, state) + end + + # --------------------------------------------------------------------------- + # Latent-bug catchers — these tests are currently expected to FAIL. + # They exercise code paths whose source references variables that aren't + # defined in the function body. They exist to surface those bugs the next + # time someone runs the suite, not to lock in current (buggy) behavior. + # --------------------------------------------------------------------------- + @testset "siteaxes / sitenames (latent UndefVarError)" begin + # Build a TN where each tensor has both link indices (one per neighbor) + # and a "site" index that no neighbor shares. + g = named_grid((3,)) + site_idx = Dict(v => Index(2) for v in vertices(g)) + link = Dict(e => Index(2) for e in edges(g)) + link = merge(link, Dict(reverse(e) => link[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = (site_idx[v], (link[e] for e in incident_edges(g, v))...) + return randn(is) + end + + e = first(edges(tn)) + + # Both functions reference `v` inside their `for v′ in neighbors(tn, v)` + # loop, but `v` is never defined in either body — only `edge` is in + # scope. Calling them currently throws `UndefVarError(:v)`. + # The expected (post-fix) behavior is to return a non-empty collection + # of the site axes / site names at the edge endpoints, so we assert that + # the call succeeds and returns something sensible. 
+ sax = ITensorNetworksNext.siteaxes(tn, e) + @test sax isa AbstractVector || sax isa AbstractSet || sax isa Tuple + @test !isempty(sax) + + snm = ITensorNetworksNext.sitenames(tn, e) + @test snm isa AbstractVector || snm isa AbstractSet || snm isa Tuple + @test !isempty(snm) + end +end From 7197bcf9543ce697a137ac347f654254d33968fc Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 13:25:42 -0400 Subject: [PATCH 68/86] Refine and redistribute generated tests --- test/test_beliefpropagation.jl | 238 +++++++++-- test/test_new.jl | 742 --------------------------------- test/test_tensornetwork.jl | 76 +++- 3 files changed, 270 insertions(+), 786 deletions(-) delete mode 100644 test/test_new.jl diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 57c8ffc..d9112b9 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -2,8 +2,10 @@ using DiagonalArrays: δ using Dictionaries: Dictionary, set! using Graphs: AbstractGraph, dst, edges, src, vertices using ITensorBase: ITensor, Index, noprime, prime -using ITensorNetworksNext: - ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, linkinds, scalar +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, + edge_scalar, factor, factor_type, factors, incoming_messages, linkinds, message, + message_type, messages, region_scalar, scalar, setfactor!, setmessage!, setmessages!, + vertex_scalar, vertex_scalars using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype @@ -35,56 +37,208 @@ function spin_ice_tensornetwork(g) end @testset "BeliefPropagation" begin - @testset "$T" for T in (Float32, Float64, ComplexF64, BigFloat) - #Chain of tensors - dims = (2, 1) - g = named_grid(dims) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - - tn = TensorNetwork(g) do v - is = map(e -> l[e], 
incident_edges(g, v)) - return randn(T, Tuple(is)) + @testset "`BeliefPropagationCache`" begin + @testset "Basics" begin + dims = (3, 3) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) do edge + return "$(src(edge)) => $(dst(edge))" + end + + @test factor_type(bpc) <: ITensor + @test message_type(bpc) <: String + @test length(factors(bpc)) == 9 + @test length(messages(bpc)) == 2 * length(edges(g)) + @test bpc[(2, 2)] == tn[(2, 2)] + @test factor(bpc, (1, 1)) == tn[(1, 1)] + @test bpc[(1, 1) => (1, 2)] == "(1, 1) => (1, 2)" + @test message(bpc, (2, 1) => (1, 1)) == "(2, 1) => (1, 1)" + + # set factor + f = factor(bpc, (1, 1)) + setfactor!(bpc, (1, 1), 2 * f) + @test factor(bpc, (1, 1)) == 2 * f + + # set message + setmessage!(bpc, (1, 1) => (1, 2), "new message") + @test message(bpc, (1, 1) => (1, 2)) == "new message" + + setmessages!(bpc, Dict(((1, 2) => (2, 2)) => "m1", ((2, 2) => (2, 3)) => "m2")) + @test message(bpc, (1, 1) => (1, 2)) == "new message" + @test message(bpc, (1, 2) => (2, 2)) == "m1" + @test message(bpc, (2, 2) => (2, 3)) == "m2" + + bpc_dst = BeliefPropagationCache(tn) do edge + return "" + end + setmessages!(bpc_dst, bpc, [(1, 2) => (2, 2), (2, 2) => (2, 3)]) + @test message(bpc_dst, (1, 1) => (1, 2)) == "" + @test message(bpc, (1, 2) => (2, 2)) == "m1" + @test message(bpc, (2, 2) => (2, 3)) == "m2" end + @testset "Vertex/region scalars" begin + g = named_path_graph(3) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(ComplexF32, Tuple(is)) + end - bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = scalar(bpc) - z_exact = reduce(*, 
[tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact - - #Tree of tensors - dims = (4, 3) - g = named_comb_tree(dims) - l = Dict(e => Index(3) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(T, Tuple(is)) + bpc = BeliefPropagationCache(tn) do edge + return ones(Float64, Tuple(linkinds(tn, edge))) + end + + # Vertex/edge/region scalars. + @test vertex_scalar(bpc, 2) isa ComplexF64 + @test edge_scalar(bpc, 1 => 2) isa Float64 + + @test region_scalar(bpc, [1]) == vertex_scalar(bpc, 1) + @test region_scalar(bpc, [1 => 2]) == edge_scalar(bpc, 1 => 2) + @test region_scalar(bpc, [2 => 1]) == edge_scalar(bpc, 1 => 2) + @test region_scalar(bpc, [1, 2, 3]) == prod(vertex_scalars(bpc)) + + # `incoming_messages` excludes specified edges. + in_msgs = incoming_messages(bpc, 2) + in_msgs_filtered = incoming_messages( + bpc, 2; ignore_edges = [1 => 2] + ) + @test length(in_msgs) == 2 + @test length(in_msgs_filtered) == 1 + @test only(in_msgs_filtered) == bpc[3 => 2] + + # `factor_type` / `message_type` resolve to concrete types. + @test factor_type(bpc) <: ITensor + @test message_type(bpc) <: ITensor + + # `map_messages` and `map_factors` produce independent caches. + bpc_doubled = ITensorNetworksNext.map_messages(m -> 2 .* m, bpc) + @test !(bpc_doubled === bpc) + @test message(bpc_doubled, 1 => 2) ≈ 2 .* message(bpc, 1 => 2) + @test message(bpc_doubled, 2 => 3) ≈ 2 .* message(bpc, 2 => 3) + + bpc_scaled = ITensorNetworksNext.map_factors(f -> f .* 2, bpc) + @test !(bpc_scaled === bpc) + for vv in vertices(bpc_scaled) + @test factor(bpc_scaled, vv) ≈ factor(bpc, vv) .* 2 + end + + # `adapt_factors` and `adapt_messages` should at least be callable. 
+ @test ITensorNetworksNext.adapt_factors(identity, bpc) isa + BeliefPropagationCache + @test ITensorNetworksNext.adapt_messages(identity, bpc) isa + BeliefPropagationCache end - bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = scalar(bpc) - z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact + @testset "subgraph" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + + sub_vs = [(1,), (2,)] + subbpc = subgraph(bpc, sub_vs) + @test subbpc isa BeliefPropagationCache + @test issetequal(vertices(subbpc), sub_vs) + @test has_edge(subbpc, (1,) => (2,)) + end + @testset "diff" begin + g = named_grid((2,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc1 = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + bpc2 = BeliefPropagationCache(tn) do edge + return ones(Tuple(linkinds(tn, edge))) + end + + # Identical caches: diff should be ~0. 
+ @test ITensorNetworksNext.iterate_diff(bpc1, bpc2) ≈ 0.0 atol = 10 * eps() + end + end - #Spin Ice Model (has analytical bp solution given by 1.5^(n^2)) - for n in (3, 4, 5) - dims = (n, n) - g = named_grid(dims; periodic = true) - tn = spin_ice_tensornetwork(g) + @testset "Algorithm" begin + @testset "$T" for T in (Float32, Float64, ComplexF64, BigFloat) + #Chain of tensors + dims = (2, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - bpc = ITensorNetworksNext.BeliefPropagationCache(tn) do edge - # Use `rand` so messages have positive elements. - return rand(T, Tuple(linkinds(tn, edge))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(T, Tuple(is)) end - bpc = - ITensorNetworksNext.beliefpropagation(bpc; tol = 1.0e-10, maxiter = 10) + bpc = BeliefPropagationCache(tn) do edge + return ones(T, Tuple(linkinds(tn, edge))) + end + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + z_bp = scalar(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(T, Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) do edge + return ones(T, Tuple(linkinds(tn, edge))) + end + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = scalar(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact + + #Spin Ice Model (has analytical bp solution given by 1.5^(n^2)) + @testset "Spin Ice Model (analytical)" begin + for n in (3, 4, 5) + dims = (n, n) + g = named_grid(dims; periodic = true) + tn = spin_ice_tensornetwork(g) + + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) do edge + # Use `rand` so messages have positive elements. 
+ return rand(T, Tuple(linkinds(tn, edge))) + end + bpc = + ITensorNetworksNext.beliefpropagation( + bpc; + tol = 1.0e-10, + maxiter = 10 + ) - @test z_bp ≈ 1.5^(n^2) + z_bp = scalar(bpc) + + @test z_bp ≈ 1.5^(n^2) + end + end end end end diff --git a/test/test_new.jl b/test/test_new.jl deleted file mode 100644 index 3acc422..0000000 --- a/test/test_new.jl +++ /dev/null @@ -1,742 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using AbstractTrees: AbstractTrees -using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: vertex_data -using Dictionaries: Dictionary -using Graphs: Graphs, AbstractEdge, dst, edges, has_edge, ne, nv, src, vertices -using ITensorBase: ITensor, Index -using ITensorNetworksNext: BeliefPropagationCache, EigsolveRegion, ITensorNetworksNext, - TensorNetwork, contract_network, dmrg, factor, factor_type, factors, linkinds, message, - message_type, messages, scalar -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, LazyNamedDimsArrays, Mul, - SymbolicArray, ismul, lazy, parenttype, substitute, symnameddims -using ITensorNetworksNext.TensorNetworkGenerators: ising_network -using NamedDimsArrays: AbstractNamedDimsArray, NamedDimsArray, denamed, dimnames, inds, - nameddims, namedoneto -using NamedGraphs: NamedGraphs -using NamedGraphs.GraphsExtensions: GraphsExtensions, incident_edges -using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree -using TermInterface: arguments, head, iscall, isexpr, operation -using Test: @test, @test_throws, @testset -using WrappedUnions: unwrap - -# Type definitions used by some tests below; must be at file scope. 
-struct _DummyNonIter <: AIE.NonIterativeAlgorithm end -struct _DummyProblem <: AIE.Problem end - -@testset "test_new.jl" begin - # --------------------------------------------------------------------------- - # AbstractTensorNetwork: iteration / keys / eltype / is_directed / show - # --------------------------------------------------------------------------- - @testset "AbstractTensorNetwork interface" begin - g = named_grid((2, 2)) - s = Dict(v => Index(2) for v in vertices(g)) - tn = TensorNetwork(g) do v - return randn(s[v]) - end - - # `iterate` works (delegates to `vertex_data`). - @test !isempty(collect(tn)) - # `keys` returns vertices. - @test issetequal(keys(tn), vertices(tn)) - # `eltype` matches the eltype of the vertex data. - @test eltype(tn) === eltype(vertex_data(tn)) - # `is_directed` is `false` for AbstractTensorNetwork. - @test !Graphs.is_directed(typeof(tn)) - - # `show` MIME and default both succeed and mention vertices/edges. - s_plain = sprint(show, MIME"text/plain"(), tn) - @test occursin("vertices", s_plain) - @test occursin("edge", s_plain) - s_default = sprint(show, tn) - @test occursin("vertices", s_default) - - # `setindex!` for edges is unimplemented. 
- e = first(edges(tn)) - @test_throws ErrorException tn[e] = randn(2, 2) - @test_throws ErrorException tn[src(e) => dst(e)] = randn(2, 2) - end - - # --------------------------------------------------------------------------- - # `linkaxes` / `linknames` on a TensorNetwork - # --------------------------------------------------------------------------- - @testset "linkaxes / linknames" begin - g = named_grid((3,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - e = first(edges(tn)) - p = src(e) => dst(e) - - li = linkinds(tn, e) - la_e = ITensorNetworksNext.linkaxes(tn, e) - la_p = ITensorNetworksNext.linkaxes(tn, p) - @test la_e == la_p - @test length(la_e) == length(li) - - ln_e = ITensorNetworksNext.linknames(tn, e) - ln_p = ITensorNetworksNext.linknames(tn, p) - @test ln_e == ln_p - @test length(ln_e) == length(li) - end - - # --------------------------------------------------------------------------- - # expression-shape predicates - # --------------------------------------------------------------------------- - @testset "is_setindex!_expr / is_assignment_expr / is_getindex_expr" begin - @test ITensorNetworksNext.is_setindex!_expr(:(a[1] = 2)) - @test !ITensorNetworksNext.is_setindex!_expr(:(a[1])) - @test !ITensorNetworksNext.is_setindex!_expr(:(a + b)) - @test !ITensorNetworksNext.is_setindex!_expr(42) - - @test ITensorNetworksNext.is_assignment_expr(:(x = 1)) - @test !ITensorNetworksNext.is_assignment_expr(:(x + 1)) - @test !ITensorNetworksNext.is_assignment_expr(42) - - @test ITensorNetworksNext.is_getindex_expr(:(a[1])) - @test !ITensorNetworksNext.is_getindex_expr(:(a + 1)) - @test !ITensorNetworksNext.is_getindex_expr(42) - end - - # --------------------------------------------------------------------------- - # `add_missing_edges!`: no-op on a well-formed network. 
- # --------------------------------------------------------------------------- - @testset "add_missing_edges!" begin - g = named_grid((2, 2)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - es_before = collect(edges(tn)) - ITensorNetworksNext.add_missing_edges!(tn) - @test issetequal(edges(tn), es_before) - - v = first(vertices(tn)) - ITensorNetworksNext.add_missing_edges!(tn, v) - @test issetequal(edges(tn), es_before) - end - - # --------------------------------------------------------------------------- - # `TensorNetwork` constructor / copy / convert variants and `rem_edge!` - # --------------------------------------------------------------------------- - @testset "TensorNetwork copy / convert / rem_edge!" begin - g = named_grid((3,)) - s = Dict(v => Index(2) for v in vertices(g)) - tn = TensorNetwork(g) do v - return randn(s[v]) - end - - # `TensorNetwork(tensors)` infers the graph from shared indices. - link = Index(2) - A = randn(s[(1,)], link) - B = randn(s[(2,)], link) - tensors = Dictionary([(1,), (2,)], [A, B]) - tn_inferred = TensorNetwork(tensors) - @test tn_inferred isa TensorNetwork - @test issetequal(vertices(tn_inferred), [(1,), (2,)]) - @test ne(tn_inferred) == 1 - - # `copy` produces an independent TensorNetwork. - tn2 = copy(tn) - @test tn2 isa TensorNetwork - @test issetequal(vertices(tn2), vertices(tn)) - @test issetequal(edges(tn2), edges(tn)) - @test vertex_data(tn2) !== vertex_data(tn) - - # `TensorNetwork(tn)` and `TensorNetwork{V}(tn)` (same V) call `copy`. - tn3 = TensorNetwork(tn) - @test tn3 isa TensorNetwork - @test issetequal(vertices(tn3), vertices(tn)) - - V = GraphsExtensions.vertextype(tn) - tn4 = TensorNetwork{V}(tn) - @test tn4 isa TensorNetwork - @test issetequal(vertices(tn4), vertices(tn)) - - # `TensorNetwork{V}(tn)` with a different V re-keys vertices. 
- tn5 = TensorNetwork{Tuple{Float64}}(tn) - @test tn5 isa TensorNetwork - @test all(v -> v isa Tuple{Float64}, vertices(tn5)) - - # `rem_edge!` returns false for an absent edge. - bad_edge = (1,) => (3,) - @test !Graphs.rem_edge!(tn, bad_edge) - - # `rem_edge!` on an edge with shared inds throws. - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn_link = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - e = first(edges(tn_link)) - @test_throws ArgumentError Graphs.rem_edge!(tn_link, e) - end - - # --------------------------------------------------------------------------- - # `induced_subgraph_from_vertices` for TensorNetwork - # --------------------------------------------------------------------------- - @testset "TensorNetwork induced_subgraph_from_vertices" begin - g = named_grid((3,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - sub_vs = [(1,), (2,)] - subtn, _ = NamedGraphs.induced_subgraph_from_vertices(tn, sub_vs) - @test subtn isa TensorNetwork - @test issetequal(vertices(subtn), sub_vs) - end - - # --------------------------------------------------------------------------- - # `BeliefPropagationCache` constructor variants and message/factor mutators - # --------------------------------------------------------------------------- - @testset "BeliefPropagationCache constructors and mutators" begin - g = named_grid((2, 2)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - # `BeliefPropagationCache(network)` (no callable; cache constructed). 
- bpc1 = BeliefPropagationCache(tn) - @test bpc1 isa BeliefPropagationCache - @test length(factors(bpc1)) == nv(tn) - - # `BeliefPropagationCache(callable, network)` - bpc2 = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - @test length(messages(bpc2)) == 2 * length(edges(g)) - - # `copy` is independent of the source. - bpc_copy = copy(bpc2) - @test bpc_copy isa BeliefPropagationCache - @test length(messages(bpc_copy)) == length(messages(bpc2)) - - # `setmessage!` and `setfactor!` write through the cache. - e = first(edges(bpc2)) - new_msg = ones(Tuple(linkinds(tn, e))) .* 2.0 - ITensorNetworksNext.setmessage!(bpc2, e, new_msg) - @test message(bpc2, e) == new_msg - - v = first(vertices(bpc2)) - old_factor = factor(bpc2, v) - new_factor = old_factor .* 2 - ITensorNetworksNext.setfactor!(bpc2, v, new_factor) - @test factor(bpc2, v) == new_factor - - # `setmessages!` accepts a mapping and updates entries. - e2 = first(edges(bpc2)) - msg2 = ones(Tuple(linkinds(tn, e2))) .* 3.0 - ITensorNetworksNext.setmessages!(bpc2, Dict(e2 => msg2)) - @test message(bpc2, e2) == msg2 - - # `setmessages!(dst, src, edges)` copies messages between caches. - bpc_dst = BeliefPropagationCache(tn) do edge - return zeros(Tuple(linkinds(tn, edge))) - end - e3 = first(edges(bpc2)) - ITensorNetworksNext.setmessages!(bpc_dst, bpc2, [e3]) - @test message(bpc_dst, e3) == message(bpc2, e3) - end - - # --------------------------------------------------------------------------- - # AbstractBeliefPropagationCache helpers: vertex/edge/region scalars, - # incoming_messages, map_messages/map_factors, factor_type, message_type. 
- # --------------------------------------------------------------------------- - @testset "BeliefPropagationCache scalars / mappers" begin - g = named_grid((2,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - bpc = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - - v = first(vertices(bpc)) - e = first(edges(bpc)) - - # Vertex/edge/region scalars. - vs = ITensorNetworksNext.vertex_scalar(bpc, v) - es = ITensorNetworksNext.edge_scalar(bpc, e) - @test vs isa Number - @test es isa Number - - rs = ITensorNetworksNext.region_scalar(bpc, [v, e]) - @test rs ≈ vs * es - - # `incoming_messages` excludes specified edges. - in_msgs = ITensorNetworksNext.incoming_messages(bpc, v) - in_msgs_filtered = ITensorNetworksNext.incoming_messages( - bpc, v; ignore_edges = [reverse(e)] - ) - @test length(in_msgs_filtered) <= length(in_msgs) - - # `factor_type` / `message_type` resolve to concrete types. - @test factor_type(bpc) isa Type - @test message_type(bpc) isa Type - - # `map_messages` and `map_factors` produce independent caches. - bpc_doubled = ITensorNetworksNext.map_messages(m -> 2 .* m, bpc) - @test message(bpc_doubled, e) ≈ 2 .* message(bpc, e) - - bpc_scaled = ITensorNetworksNext.map_factors(f -> f .* 2, bpc) - for vv in vertices(bpc_scaled) - @test factor(bpc_scaled, vv) ≈ factor(bpc, vv) .* 2 - end - - # `adapt_factors` and `adapt_messages` should at least be callable. - @test ITensorNetworksNext.adapt_factors(identity, bpc) isa BeliefPropagationCache - @test ITensorNetworksNext.adapt_messages(identity, bpc) isa BeliefPropagationCache - end - - # --------------------------------------------------------------------------- - # `logscalar` branches: complex-promotion path and zero denominator. 
- # --------------------------------------------------------------------------- - @testset "logscalar special branches" begin - g = named_grid((2,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - bpc = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - - # Negate one factor so the numerator product becomes negative, - # forcing a complex promotion in `logscalar`. - v = first(vertices(bpc)) - ITensorNetworksNext.setfactor!(bpc, v, -1 .* factor(bpc, v)) - @test ITensorNetworksNext.logscalar(bpc) isa Number - - # Zero out a message so a denominator term becomes zero -> -Inf. - bpc_zero = BeliefPropagationCache(tn) do edge - return zeros(Tuple(linkinds(tn, edge))) - end - @test ITensorNetworksNext.logscalar(bpc_zero) == -Inf - end - - # --------------------------------------------------------------------------- - # `induced_subgraph_bpcache` / induced_subgraph_from_vertices on a BPCache. - # --------------------------------------------------------------------------- - @testset "BeliefPropagationCache induced_subgraph" begin - g = named_grid((3,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - bpc = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - - sub_vs = [(1,), (2,)] - subbpc = subgraph(bpc, sub_vs) - @test subbpc isa BeliefPropagationCache - @test issetequal(vertices(subbpc), sub_vs) - @test has_edge(subbpc, (1,) => (2,)) - end - - # --------------------------------------------------------------------------- - # `forest_cover_edge_sequence` returns a sequence covering a tree. 
- # --------------------------------------------------------------------------- - @testset "forest_cover_edge_sequence" begin - g = named_comb_tree((3, 2)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - bpc = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - - seq = ITensorNetworksNext.forest_cover_edge_sequence(bpc) - @test eltype(seq) <: AbstractEdge - @test !isempty(seq) - end - - # --------------------------------------------------------------------------- - # Belief propagation: `select_algorithm` errors when `maxiter` is required. - # --------------------------------------------------------------------------- - @testset "beliefpropagation select_algorithm error" begin - # 2x2 grid: not a tree, so `maxiter` cannot be defaulted. - g = named_grid((2, 2)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - bpc = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - @test_throws ArgumentError ITensorNetworksNext.select_algorithm( - ITensorNetworksNext.beliefpropagation, bpc; maxiter = nothing - ) - end - - # --------------------------------------------------------------------------- - # `iterate_diff` and `SimpleMessageUpdate.getproperty(:kwargs)` path. 
- # --------------------------------------------------------------------------- - @testset "iterate_diff and SimpleMessageUpdate kwargs" begin - g = named_grid((2,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - bpc1 = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - bpc2 = BeliefPropagationCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - - # Identical caches: diff should be ~0. - @test ITensorNetworksNext.iterate_diff(bpc1, bpc2) ≈ 0 atol = 1.0e-10 - - # `SimpleMessageUpdate.getproperty(:kwargs)` returns the NamedTuple. - edge = first(edges(bpc1)) - upd = ITensorNetworksNext.SimpleMessageUpdate(edge; normalize = false) - @test upd.kwargs isa NamedTuple - # Forwarded properties still work (`getfield(:kwargs)` then property). - @test upd.normalize == false - end - - # --------------------------------------------------------------------------- - # `contract_network`: unknown algorithm error and `left_associative` order. - # --------------------------------------------------------------------------- - @testset "contract_network error / left_associative" begin - g = named_grid((2,)) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - - @test_throws ArgumentError contract_network(tn; alg = Algorithm"unknown_alg"()) - - # `contraction_order` for `left_associative` algorithm. - order = ITensorNetworksNext.contraction_order(tn; alg = Algorithm"left_associative"()) - @test order isa LazyNamedDimsArray - end - - # --------------------------------------------------------------------------- - # `dmrg`: thin wrappers and unimplemented `EigsolveRegion` step. 
- # --------------------------------------------------------------------------- - @testset "dmrg wrappers" begin - operator = "operator" - init = "init" - nsweeps = 2 - regions = ["region1"] - algorithm = ITensorNetworksNext.select_algorithm( - dmrg, operator, init; nsweeps, regions, maxdim = 10 - ) - - # `dmrg(operator, algorithm, state)` errors deep in `EigsolveRegion`'s solve!. - @test_throws Exception dmrg(operator, algorithm, init) - - # `dmrg(operator, state; ...)` builds the algorithm internally; same expected error. - @test_throws Exception dmrg(operator, init; nsweeps, regions, maxdim = 10) - - # The `EigsolveRegion`-specific `solve!` errors directly. - region = EigsolveRegion("region"; maxdim = 10) - problem = ITensorNetworksNext.EigenProblem(operator) - state = AI.initialize_state(problem, region; iterate = init) - @test_throws ErrorException AI.solve!(problem, region, state) - end - - # --------------------------------------------------------------------------- - # `LazyNamedDimsArrays`: error paths in lazy interface. - # --------------------------------------------------------------------------- - @testset "LazyNamedDimsArrays interface error paths" begin - a = nameddims(randn(2, 2), (:i, :j)) - la = lazy(a) - - # `getindex_lazy` errors on expressions, but works on a leaf. - @test la[1, 1] == a[1, 1] - expr = la * lazy(nameddims(randn(2, 2), (:j, :k))) - @test_throws ErrorException expr[1, 1] - - # `denamed` works on a leaf, errors on non-leaf. - @test denamed(la) == denamed(a) - @test_throws ErrorException LazyNamedDimsArrays.denamed_lazy(expr) - - # `dimnames` and `inds` on a `Mul`. - @test issetequal(dimnames(expr), [:i, :k]) - @test length(inds(expr)) == 2 - - # Equality and hash branches. - la2 = lazy(a) - @test la == la2 - @test isequal(la, la2) - @test !(la == expr) # leaf vs expression - @test !isequal(la, expr) # leaf vs expression - @test hash(la) == hash(la2) - - # `mul_lazy(a)` on a leaf wraps it in a `Mul`. 
- wrapped = *(la) - @test ismul(wrapped) - @test arguments(wrapped) == [la] - - # `mul_lazy(a)` on a Mul returns it unchanged. - @test *(expr) == expr - - # `mul_lazy(a, b; flatten=true)` flattens the arguments. - expr3 = lazy(nameddims(randn(2, 2), (:k, :l))) - flat = LazyNamedDimsArrays.mul_lazy(expr, expr3; flatten = true) - @test ismul(flat) - @test length(arguments(flat)) == 3 - - # Number * Number short-circuit. - @test LazyNamedDimsArrays.mul_lazy(2, 3) == 6 - - # Unsupported ops error. - @test_throws ErrorException la + la2 - @test_throws ErrorException la - la2 - @test_throws ErrorException -la - @test_throws ErrorException la / 2 - @test_throws ErrorException 2 * la - @test_throws ErrorException la * 2 - - # `maketerm` for non-`*` head errors. - @test_throws ErrorException LazyNamedDimsArrays.maketerm_lazy( - LazyNamedDimsArray, +, [la, la2], nothing - ) - - # `parenttype` resolution. - @test parenttype(LazyNamedDimsArray) === AbstractNamedDimsArray - @test parenttype(LazyNamedDimsArray{Float64}) === AbstractNamedDimsArray{Float64} - @test parenttype(typeof(la)) === typeof(a) - end - - # --------------------------------------------------------------------------- - # Lazy broadcasting (linear ops only; arbitrary ops error). - # --------------------------------------------------------------------------- - @testset "Lazy broadcasting" begin - a = nameddims(randn(2, 2), (:i, :j)) - la, la2 = lazy(a), lazy(a) - style = LazyNamedDimsArrays.LazyNamedDimsArrayStyle() - - # Broadcasted linear ops route through `+, -, *, /, unary -`, - # all of which themselves error in the lazy framework. 
- @test_throws ErrorException Base.Broadcast.broadcasted(style, +, la, la2) - @test_throws ErrorException Base.Broadcast.broadcasted(style, -, la, la2) - @test_throws ErrorException Base.Broadcast.broadcasted(style, *, 2.0, la) - @test_throws ErrorException Base.Broadcast.broadcasted(style, *, la, 2.0) - @test Base.Broadcast.broadcasted(style, *, 2.0, 3.0) == 6.0 - @test_throws ErrorException Base.Broadcast.broadcasted(style, /, la, 2.0) - @test_throws ErrorException Base.Broadcast.broadcasted(style, -, la) - - # Arbitrary functions error explicitly. - @test_throws ErrorException Base.Broadcast.broadcasted(style, sin, la) - end - - # --------------------------------------------------------------------------- - # `SymbolicArray`: getindex/setindex! errors, permutedims, show, printnode. - # --------------------------------------------------------------------------- - @testset "SymbolicArray operations" begin - sa = SymbolicArray(:x, (Base.OneTo(2), Base.OneTo(3))) - @test size(sa) == (2, 3) - - # Indexing errors. - @test_throws ErrorException sa[1, 1] - @test_throws ErrorException (sa[1, 1] = 0) - - # `permutedims`. - pa = permutedims(sa, (2, 1)) - @test size(pa) == (3, 2) - - # `show` writes the symbolic name. - s_plain = sprint(show, MIME"text/plain"(), sa) - @test occursin("x", s_plain) - s_default = sprint(show, sa) - @test occursin("SymbolicArray", s_default) - - # `printnode` writes the symbolic name. - s_node = sprint(AbstractTrees.printnode, sa) - @test occursin("x", s_node) - end - - # --------------------------------------------------------------------------- - # `SymbolicNamedDimsArray`: equality and printnode with non-zero ndims. 
- # --------------------------------------------------------------------------- - @testset "SymbolicNamedDimsArray equality / printnode" begin - i, j = namedoneto.(2, (:i, :j)) - sa = symnameddims(:a, (i, j)) - sa2 = symnameddims(:a, (i, j)) - sa_perm = symnameddims(:a, (j, i)) - sa_other = symnameddims(:b, (i, j)) - - # Equality: same name + same dimnames (any order) -> equal. - @test unwrap(sa) == unwrap(sa2) - @test unwrap(sa) == unwrap(sa_perm) - @test unwrap(sa) != unwrap(sa_other) - - # `printnode` on a non-scalar prints both name and dims. - s_node = sprint(AbstractTrees.printnode, unwrap(sa)) - @test occursin("a", s_node) - @test occursin("[", s_node) - end - - # --------------------------------------------------------------------------- - # `evaluation_time_complexity` / `flatten_expression` / `optimize_evaluation_order` - # --------------------------------------------------------------------------- - @testset "LazyNamedDimsArrays evaluation_order" begin - a = nameddims(randn(3, 3), (:i, :j)) - b = nameddims(randn(3, 3), (:j, :k)) - la, lb = lazy.((a, b)) - expr = la * lb - - # Time complexity for a known mul. - @test LazyNamedDimsArrays.evaluation_time_complexity(expr) > 0 - - # Flatten of a `Mul` of `Mul`s. - c = nameddims(randn(3, 3), (:k, :i)) - lc = lazy(c) - nested = (la * lb) * lc - flat = LazyNamedDimsArrays.flatten_expression(nested) - @test ismul(flat) - @test length(arguments(flat)) == 3 - - # `flatten_expression` is identity on leaves. - @test LazyNamedDimsArrays.flatten_expression(la) === la - - # `optimize_evaluation_order` on a leaf is identity. - @test LazyNamedDimsArrays.optimize_evaluation_order(la) === la - - # `optimize_contraction_order` with eager picks an ordering. - eager = Algorithm"eager"() - flat_expr = LazyNamedDimsArrays.flatten_expression((la * lb) * lc) - @test LazyNamedDimsArrays.optimize_evaluation_order(eager, flat_expr) isa - LazyNamedDimsArray - - # Time-complexity for scalar*tensor and tensor*scalar. 
- n = nameddims(randn(3, 3), (:i, :j)) - @test LazyNamedDimsArrays.time_complexity(*, 2.0, n) > 0 - @test LazyNamedDimsArrays.time_complexity(*, n, 2.0) > 0 - - # Time complexity for elementwise +. - n2 = nameddims(randn(3, 3), (:i, :j)) - @test LazyNamedDimsArrays.time_complexity(+, n, n2) > 0 - end - - # --------------------------------------------------------------------------- - # `nameddimsarraysextensions._hash` fallback for non-NamedDimsArray. - # --------------------------------------------------------------------------- - @testset "_hash fallback" begin - @test LazyNamedDimsArrays._hash(42, UInt64(0)) == hash(42, UInt64(0)) - @test LazyNamedDimsArrays._hash("x", UInt64(0)) == hash("x", UInt64(0)) - end - - # --------------------------------------------------------------------------- - # `generic_map` for arrays / dicts / sets. - # --------------------------------------------------------------------------- - @testset "generic_map" begin - @test LazyNamedDimsArrays.generic_map(x -> x + 1, [1, 2, 3]) == [2, 3, 4] - - d = Dict(:a => 1, :b => 2) - md = LazyNamedDimsArrays.generic_map(x -> x * 10, d) - @test md isa Dict - @test md[:a] == 10 - @test md[:b] == 20 - - ms = LazyNamedDimsArrays.generic_map(x -> x * 2, Set([1, 2, 3])) - @test ms == Set([2, 4, 6]) - end - - # --------------------------------------------------------------------------- - # `Mul` core hooks. - # --------------------------------------------------------------------------- - @testset "Mul / Applied basics" begin - a = lazy(nameddims(randn(2, 2), (:i, :j))) - b = lazy(nameddims(randn(2, 2), (:j, :i))) - m = Mul([a, b]) - - @test arguments(m) == [a, b] - @test operation(m) ≡ * - @test iscall(m) - @test isexpr(m) - @test head(m) ≡ * - - # `show` for an `Applied` writes parens-joined arguments. - @test occursin("*", sprint(show, m)) - - # Hashing of equal `Mul`s. 
- m2 = Mul([a, b]) - @test hash(m) == hash(m2) - end - - # --------------------------------------------------------------------------- - # AlgorithmsInterfaceExtensions: `NonIterativeAlgorithm` fallback `solve!`. - # --------------------------------------------------------------------------- - @testset "NonIterativeAlgorithm fallback solve!" begin - problem = _DummyProblem() - algorithm = _DummyNonIter() - state = AI.initialize_state(problem, algorithm; iterate = [0.0]) - @test_throws Exception AI.solve!(problem, algorithm, state) - end - - # --------------------------------------------------------------------------- - # Latent-bug catchers — these tests are currently expected to FAIL. - # They exercise code paths whose source references variables that aren't - # defined in the function body. They exist to surface those bugs the next - # time someone runs the suite, not to lock in current (buggy) behavior. - # --------------------------------------------------------------------------- - @testset "siteaxes / sitenames (latent UndefVarError)" begin - # Build a TN where each tensor has both link indices (one per neighbor) - # and a "site" index that no neighbor shares. - g = named_grid((3,)) - site_idx = Dict(v => Index(2) for v in vertices(g)) - link = Dict(e => Index(2) for e in edges(g)) - link = merge(link, Dict(reverse(e) => link[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = (site_idx[v], (link[e] for e in incident_edges(g, v))...) - return randn(is) - end - - e = first(edges(tn)) - - # Both functions reference `v` inside their `for v′ in neighbors(tn, v)` - # loop, but `v` is never defined in either body — only `edge` is in - # scope. Calling them currently throws `UndefVarError(:v)`. - # The expected (post-fix) behavior is to return a non-empty collection - # of the site axes / site names at the edge endpoints, so we assert that - # the call succeeds and returns something sensible. 
- sax = ITensorNetworksNext.siteaxes(tn, e) - @test sax isa AbstractVector || sax isa AbstractSet || sax isa Tuple - @test !isempty(sax) - - snm = ITensorNetworksNext.sitenames(tn, e) - @test snm isa AbstractVector || snm isa AbstractSet || snm isa Tuple - @test !isempty(snm) - end -end diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl index aaf96e8..08e241c 100644 --- a/test/test_tensornetwork.jl +++ b/test/test_tensornetwork.jl @@ -1,8 +1,9 @@ using DataGraphs: assigned_edge_data, assigned_vertex_data, underlying_graph, vertex_data -using Graphs: dst, edges, has_edge, ne, nv, src, vertices +using Graphs: dst, edges, edgetype, has_edge, ne, nv, src, vertices using ITensorBase: Index using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray -using ITensorNetworksNext: TensorNetwork +using ITensorNetworksNext: + TensorNetwork, linkaxes, linkinds, linknames, siteaxes, siteinds, sitenames using NamedGraphs.GraphsExtensions: vertextype using NamedGraphs.NamedGraphGenerators: named_grid using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, QuotientVertex, departition, @@ -12,6 +13,77 @@ using NamedGraphs: convert_vertextype, similar_graph using Test: @test, @test_throws, @testset @testset "`TensorNetwork`" begin + @testset "Basics" begin + g = named_grid((2, 2)) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + # `iterate` works (delegates to `vertex_data`). + @test !isempty(collect(tn)) + # `keys` returns vertices. + @test issetequal(keys(tn), vertices(tn)) + # `eltype` matches the eltype of the vertex data. + @test eltype(tn) === eltype(vertex_data(tn)) + # `is_directed` is `false` for AbstractTensorNetwork. + @test !Graphs.is_directed(typeof(tn)) + + # `show` MIME and default both succeed and mention vertices/edges. 
+ s_plain = sprint(show, MIME"text/plain"(), tn) + @test occursin("vertices", s_plain) + @test occursin("edge", s_plain) + s_default = sprint(show, tn) + @test occursin("vertices", s_default) + + # `setindex!` for edges is intentionally unimplemented. + e = first(edges(tn)) + @test_throws ErrorException tn[e] = randn(2, 2) + @test_throws ErrorException tn[src(e) => dst(e)] = randn(2, 2) + end + + @testset "link and site functions" begin + g = named_path_graph(3) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn((s[v], is...)) + end + + E = edgetype(tn) + @test linkinds(tn, 1 => 2) == [l[E(1 => 2)]] + @test linkinds(tn, E(1 => 2)) == [l[E(1 => 2)]] + + @test linkaxes(tn, 1 => 2) == [l[E(1 => 2)]] + @test linkaxes(tn, E(1 => 2)) == [l[E(1 => 2)]] + + @test linknames(tn, 1 => 2) == [l[E(1 => 2)].name] + @test linknames(tn, E(1 => 2)) == [l[E(1 => 2)].name] + + @test siteinds(tn, 1) == [s[1]] + @test siteaxes(tn, 2) == [s[2]] + @test sitenames(tn, 3) == [s[3].name] + end + + @testset "`subgraph`" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + sub_vs = [(1,), (2,)] + subtn = subgraph(tn, sub_vs) + @test subtn isa TensorNetwork + @test issetequal(vertices(subtn), sub_vs) + @test has_edge(subtn, (1,) => (2,)) + end + + @testset "DataGraphs/NamedGraphs interface" begin dims = (3, 3) g = named_grid(dims) From 44805cc3ac7b85153ffe5031782dcc213f28e6d7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 13:29:48 -0400 Subject: [PATCH 69/86] Further BP test improvements --- test/test_beliefpropagation.jl | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) 
diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index d9112b9..2d166ac 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,9 +3,10 @@ using Dictionaries: Dictionary, set! using Graphs: AbstractGraph, dst, edges, src, vertices using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, - edge_scalar, factor, factor_type, factors, incoming_messages, linkinds, message, - message_type, messages, region_scalar, scalar, setfactor!, setmessage!, setmessages!, - vertex_scalar, vertex_scalars + adapt_factors, adapt_messages, edge_scalar, factor, factor_type, factors, + incoming_messages, linkinds, map_factors, map_messages, message, message_type, messages, + region_scalar, scalar, setfactor!, setmessage!, setmessages!, subgraph, vertex_scalar, + vertex_scalars using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype @@ -115,27 +116,29 @@ end @test length(in_msgs_filtered) == 1 @test only(in_msgs_filtered) == bpc[3 => 2] - # `factor_type` / `message_type` resolve to concrete types. - @test factor_type(bpc) <: ITensor - @test message_type(bpc) <: ITensor - # `map_messages` and `map_factors` produce independent caches. 
- bpc_doubled = ITensorNetworksNext.map_messages(m -> 2 .* m, bpc) - @test !(bpc_doubled === bpc) + bpc_again = map_messages(identity, bpc) + @test bpc_again !== bpc + @test bpc_again == bpc + + bpc_doubled = map_messages(m -> 2 .* m, bpc) + @test bpc_doubled != bpc @test message(bpc_doubled, 1 => 2) ≈ 2 .* message(bpc, 1 => 2) @test message(bpc_doubled, 2 => 3) ≈ 2 .* message(bpc, 2 => 3) - bpc_scaled = ITensorNetworksNext.map_factors(f -> f .* 2, bpc) + bpc_again = map_factors(identity, bpc) + @test bpc_again !== bpc + @test bpc_again == bpc + + bpc_scaled = map_factors(f -> f .* 2, bpc) @test !(bpc_scaled === bpc) for vv in vertices(bpc_scaled) @test factor(bpc_scaled, vv) ≈ factor(bpc, vv) .* 2 end # `adapt_factors` and `adapt_messages` should at least be callable. - @test ITensorNetworksNext.adapt_factors(identity, bpc) isa - BeliefPropagationCache - @test ITensorNetworksNext.adapt_messages(identity, bpc) isa - BeliefPropagationCache + @test adapt_factors(identity, bpc) isa BeliefPropagationCache + @test adapt_messages(identity, bpc) isa BeliefPropagationCache end @testset "subgraph" begin From 1950a805913e0df1326216dad40703954605c9db Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 13:30:08 -0400 Subject: [PATCH 70/86] Fix incomplete `sitenames` and `siteaxes` definitions. 
--- src/abstracttensornetwork.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 0cb997f..ac9c9c3 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -85,15 +85,15 @@ function siteinds(tn::AbstractGraph, v) end return s end -function siteaxes(tn::AbstractGraph, edge::AbstractEdge) - s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) +function siteaxes(tn::AbstractGraph, v) + s = axes(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractGraph, edge::AbstractEdge) - s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) +function sitenames(tn::AbstractGraph, v) + s = dimnames(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) end From 20dca72f68dde051fe74ec7b8108463da8c90965 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 13:30:38 -0400 Subject: [PATCH 71/86] Remove `default_message` and other fixes. --- .../abstractbeliefpropagationcache.jl | 10 ++--- .../beliefpropagationcache.jl | 43 +++++++++---------- 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0a3e28c..e39d02f 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -79,17 +79,17 @@ function edge_scalars( end function region_scalar(bpc::AbstractGraph, region) - return mapreduce(ind -> _graph_index_scalar(bpc, ind), *, region) + return mapreduce(ind -> _graph_index_scalar(bpc, to_graph_index(bpc, ind)), *, region) end function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) - b_edges = !isempty(ignore_edges) ? 
setdiff(b_edges, ignore_edges) : b_edges + if !isempty(ignore_edges) + b_edges = setdiff(b_edges, to_graph_index(bp_cache, ignore_edges)) + end return messages(bp_cache, b_edges) end -default_messages(::AbstractGraph) = not_implemented() - #Adapt interface for changing device map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) function map_messages!(f, bp_cache, es = edges(bp_cache)) @@ -112,7 +112,7 @@ adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end -factor_type(bpc::AbstractBeliefPropagationCache) = typeof(bpc) +factor_type(bpc::AbstractBeliefPropagationCache) = factor_type(typeof(bpc)) factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 0971303..83ed700 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -3,10 +3,11 @@ using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, using Dictionaries: Dictionary, delete!, getindices, set! 
using Graphs: AbstractGraph, connected_components, is_directed, is_tree using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: - default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.GraphsExtensions: IsDirected, default_root_vertex, directed_graph, + forest_cover, post_order_dfs_edges, undirected_graph, vertextype using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices +using SimpleTraits: SimpleTraits, @traitfn, Not struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} @@ -18,8 +19,8 @@ struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: factors::Dictionary, messages::Dictionary ) - # Ensure the graph is directed, if not make it directed. - digraph = is_directed(graph) ? graph : directed_graph(graph) + # Ensure the graph is directed and if not, make it directed. 
+ digraph = directed_graph(graph) V = keytype(factors) VD = eltype(factors) @@ -29,9 +30,6 @@ struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) - for edge in edges(bpc) - get!(() -> default_message(bpc, edge), messages, edge) - end return bpc end end @@ -71,13 +69,24 @@ function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) return BeliefPropagationCache(MT, graph, factors) end -function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) - messages = Dictionary{edgetype(graph), MT}() +@traitfn function BeliefPropagationCache( + f::Function, + graph::AbstractGraph::!(IsDirected), + factors::Dictionary + ) + return BeliefPropagationCache(f, directed_graph(graph), factors) +end +@traitfn function BeliefPropagationCache( + f::Function, + graph::AbstractGraph::IsDirected, + factors::Dictionary + ) + messages = map(f, Indices(edges(graph))) return BeliefPropagationCache(graph, factors, messages) end -function BeliefPropagationCache(f::Function, graph::AbstractGraph, factors::Dictionary) - messages = map(f, Indices(edges(graph))) +function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) + messages = Dictionary{edgetype(graph), MT}() return BeliefPropagationCache(graph, factors, messages) end @@ -148,18 +157,6 @@ function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) return BeliefPropagationCache(quotient_view, factors, messages) end -function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), bpc[src(edge)], bpc[dst(edge)]) -end -function default_message(T::Type, src, dst) - array = ones(Tuple(inds(src) ∩ inds(dst))) - return convert(T, array) -end -function default_message(T::Type{<:LazyNamedDimsArray}, src, dst) - message = default_message(parenttype(T), src, dst) - return convert(T, lazy(message)) -end - 
NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex # When getting data according the quotient vertices, take a lazy contraction. function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) From 8e18614cbff12f2d7b57c0f02308bc91fbc62aae Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 13:39:06 -0400 Subject: [PATCH 72/86] Fix test imports --- test/test_beliefpropagation.jl | 4 ++-- test/test_tensornetwork.jl | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 2d166ac..244d780 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,6 +1,6 @@ using DiagonalArrays: δ using Dictionaries: Dictionary, set! -using Graphs: AbstractGraph, dst, edges, src, vertices +using Graphs: AbstractGraph, dst, edges, has_edge, src, vertices using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_factors, adapt_messages, edge_scalar, factor, factor_type, factors, @@ -10,7 +10,7 @@ using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNe using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype -using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid +using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid, named_path_graph using Test: @test, @testset function spin_ice_tensornetwork(g) diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl index 08e241c..ac34332 100644 --- a/test/test_tensornetwork.jl +++ b/test/test_tensornetwork.jl @@ -1,11 +1,11 @@ using DataGraphs: assigned_edge_data, assigned_vertex_data, underlying_graph, vertex_data -using Graphs: dst, edges, edgetype, has_edge, ne, nv, src, vertices +using Graphs: dst, edges, edgetype, has_edge, 
ne, nv, src, vertices, is_directed using ITensorBase: Index using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray using ITensorNetworksNext: TensorNetwork, linkaxes, linkinds, linknames, siteaxes, siteinds, sitenames -using NamedGraphs.GraphsExtensions: vertextype -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: incident_edges, subgraph, vertextype +using NamedGraphs.NamedGraphGenerators: named_grid, named_path_graph using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, QuotientVertex, departition, partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, quotientvertices @@ -27,7 +27,7 @@ using Test: @test, @test_throws, @testset # `eltype` matches the eltype of the vertex data. @test eltype(tn) === eltype(vertex_data(tn)) # `is_directed` is `false` for AbstractTensorNetwork. - @test !Graphs.is_directed(typeof(tn)) + @test !is_directed(typeof(tn)) # `show` MIME and default both succeed and mention vertices/edges. s_plain = sprint(show, MIME"text/plain"(), tn) @@ -83,7 +83,6 @@ using Test: @test, @test_throws, @testset @test has_edge(subtn, (1,) => (2,)) end - @testset "DataGraphs/NamedGraphs interface" begin dims = (3, 3) g = named_grid(dims) From ef7e6595cf2e4321cd61ae9c9b5cd75d61125acb Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 13:54:43 -0400 Subject: [PATCH 73/86] Formatting. 
--- test/test_tensornetwork.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl index ac34332..07c618f 100644 --- a/test/test_tensornetwork.jl +++ b/test/test_tensornetwork.jl @@ -1,5 +1,5 @@ using DataGraphs: assigned_edge_data, assigned_vertex_data, underlying_graph, vertex_data -using Graphs: dst, edges, edgetype, has_edge, ne, nv, src, vertices, is_directed +using Graphs: dst, edges, edgetype, has_edge, is_directed, ne, nv, src, vertices using ITensorBase: Index using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray using ITensorNetworksNext: From fbd0331f01cbc39885557b37b2a58bda005ed292 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 29 Apr 2026 14:41:07 -0400 Subject: [PATCH 74/86] Fix and test tensor network graph manipulation functions. --- src/abstracttensornetwork.jl | 7 ++++++- src/tensornetwork.jl | 6 ++++++ test/test_tensornetwork.jl | 18 ++++++++++++++++-- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index ac9c9c3..121073d 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -160,7 +160,12 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. function fix_edges!(tn::AbstractGraph, v) - rem_edges!(tn, incident_edges(tn, v)) + for e in incident_edges(tn, v) + # Remove an edge if there is no index on that edge. 
+ if isempty(linkinds(tn, e)) + rem_edge!(tn, e) + end + end add_missing_edges!(tn, v) return tn end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 80f81a0..5357b5f 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -44,6 +44,12 @@ function TensorNetwork{V, VD, UG, Tensors}( return _TensorNetwork(graph, Tensors()) end +function Graphs.rem_vertex!(tn::TensorNetwork, v) + delete!(tn.tensors, v) + rem_vertex!(tn.underlying_graph, v) + return tn +end + # DataGraphs interface DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl index 07c618f..3b4211b 100644 --- a/test/test_tensornetwork.jl +++ b/test/test_tensornetwork.jl @@ -1,9 +1,10 @@ using DataGraphs: assigned_edge_data, assigned_vertex_data, underlying_graph, vertex_data -using Graphs: dst, edges, edgetype, has_edge, is_directed, ne, nv, src, vertices +using Graphs: add_edge!, add_vertex!, dst, edges, edgetype, has_edge, has_vertex, + is_directed, ne, nv, rem_vertex!, src, vertices using ITensorBase: Index using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray using ITensorNetworksNext: - TensorNetwork, linkaxes, linkinds, linknames, siteaxes, siteinds, sitenames + TensorNetwork, fix_edges!, linkaxes, linkinds, linknames, siteaxes, siteinds, sitenames using NamedGraphs.GraphsExtensions: incident_edges, subgraph, vertextype using NamedGraphs.NamedGraphGenerators: named_grid, named_path_graph using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, QuotientVertex, departition, @@ -40,6 +41,19 @@ using Test: @test, @test_throws, @testset e = first(edges(tn)) @test_throws ErrorException tn[e] = randn(2, 2) @test_throws ErrorException tn[src(e) => dst(e)] = randn(2, 2) + + rem_vertex!(tn, (2, 2)) + @test !has_vertex(tn, (2, 2)) + add_vertex!(tn, (2, 2)) + @test has_vertex(tn, (2, 2)) + @test !isassigned(tn, (2, 2)) + + # Test `fix_edges!` removes edges where there is no link index + t = 
randn(s[(2, 2)]) + tn[(2, 2)] = t + add_edge!(tn.underlying_graph, (1, 2) => (2, 2)) + fix_edges!(tn, (2, 2)) + @test !has_edge(tn, (1, 2) => (2, 2)) end @testset "link and site functions" begin From e7d69c62e4514514009e0bf05a1766f7633b8eb0 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 1 May 2026 10:30:54 -0400 Subject: [PATCH 75/86] Simplify `factors` and `messages` methods on `AbstractGraph` --- src/beliefpropagation/messagecache.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index ec2e1f1..ee5c5f5 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -121,9 +121,7 @@ function factors(all_factors, vertices) end # Specific for graphs -function factors(all_factors::AbstractGraph) - return map(vertex -> factor(all_factors, vertex), vertices(all_factors)) -end +factors(all_factors::AbstractGraph) = factors(all_factors, vertices(all_factors)) message(_messages, _edge) = not_implemented() message(messages::AbstractGraph, edge) = messages[edge] @@ -133,9 +131,7 @@ function messages(all_messages, edges) end # Specific for graphs -function messages(all_messages::AbstractGraph) - return map(edge -> message(all_messages, edge), edges(all_messages)) -end +messages(all_messages::AbstractGraph) = messages(all_messages, edges(all_messages)) # Specific to the concrete type. messages(cache::MessageCache) = cache.messages From 38c391aefb9721346dcddcf1a5c13f3393a949a3 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 30 Apr 2026 16:33:30 -0400 Subject: [PATCH 76/86] Refactor `BeliefPropagationCache` -> `MessageCache`, remove abstract type; other small changes. 
--- src/ITensorNetworksNext.jl | 3 +- .../abstractbeliefpropagationcache.jl | 138 ---------- .../beliefpropagationcache.jl | 171 ------------ .../beliefpropagationproblem.jl | 67 ++--- src/beliefpropagation/messagecache.jl | 247 ++++++++++++++++++ test/test_beliefpropagation.jl | 80 +++--- 6 files changed, 313 insertions(+), 393 deletions(-) delete mode 100644 src/beliefpropagation/abstractbeliefpropagationcache.jl delete mode 100644 src/beliefpropagation/beliefpropagationcache.jl create mode 100644 src/beliefpropagation/messagecache.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index d3c5c21..dd7dc50 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,8 +9,7 @@ include("contract_network.jl") include("sweeping/utils.jl") include("sweeping/eigenproblem.jl") -include("beliefpropagation/abstractbeliefpropagationcache.jl") -include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/messagecache.jl") include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl deleted file mode 100644 index e39d02f..0000000 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ /dev/null @@ -1,138 +0,0 @@ -using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data -using Graphs: AbstractEdge, AbstractGraph -using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent -using NamedGraphs: AbstractEdges, AbstractVertices, to_graph_index - -messages(bpc::AbstractDataGraph) = edge_data(bpc) -messages(bpc::AbstractGraph, edges) = map(e -> message(bpc, e), edges) - -message(bpc::AbstractGraph, edge) = messages(bpc)[edge] - -deletemessage!(bpc::AbstractGraph, edge) = not_implemented() - -function deletemessages!(bpc::AbstractGraph, edges = edges(bpc)) - for e in edges - deletemessage!(bpc, e) - end - 
return bpc -end - -# Fallback; assume `setindex!` is implemented. -function setmessage!(bpc::AbstractGraph, edge, message) - bpc[edge] = message - return bpc -end -function setmessages!(bpc::AbstractGraph, messages) - for (key, val) in messages - setmessage!(bpc, key, val) - end - return bpc -end -function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) - for e in edges - setmessage!(bpc_dst, e, message(bpc_src, e)) - end - return bpc_dst -end - -factors(bpc::AbstractDataGraph) = vertex_data(bpc) -factors(bpc::AbstractGraph, vertices) = map(v -> factor(bpc, v), vertices) - -factor(bpc::AbstractGraph, vertex) = bpc[vertex] - -function setfactor!(bpc::AbstractGraph, vertex, factor) - bpc[vertex] = factor - return bpc -end - -# Internal convenience only -_graph_index_scalar(bpc::AbstractGraph, vertex) = vertex_scalar(bpc, vertex) -_graph_index_scalar(bpc::AbstractGraph, edge::AbstractEdge) = edge_scalar(bpc, edge) - -function edge_scalar(bp_cache::AbstractGraph, edge; kwargs...) - m1s = messages(bp_cache, [edge]) - m2s = messages(bp_cache, [reverse(edge)]) - return contract_network(vcat(m1s, m2s); kwargs...)[] -end - -function vertex_scalar(bp_cache::AbstractGraph, vertex; kwargs...) 
- messages = incoming_messages(bp_cache, vertex) - state = factors(bp_cache, [vertex]) - - return contract_network(vcat(messages, state); kwargs...)[] -end - -message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) -message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) -message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) - -function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) - return map(v -> vertex_scalar(bp_cache, v), vertices) -end - -function edge_scalars( - bp_cache::AbstractGraph, - edges = edges(undirected_graph(underlying_graph(bp_cache))) - ) - return map(e -> edge_scalar(bp_cache, e), edges) -end - -function region_scalar(bpc::AbstractGraph, region) - return mapreduce(ind -> _graph_index_scalar(bpc, to_graph_index(bpc, ind)), *, region) -end - -function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) - b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) - if !isempty(ignore_edges) - b_edges = setdiff(b_edges, to_graph_index(bp_cache, ignore_edges)) - end - return messages(bp_cache, b_edges) -end - -#Adapt interface for changing device -map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) -function map_messages!(f, bp_cache, es = edges(bp_cache)) - for e in es - setmessage!(bp_cache, e, f(message(bp_cache, e))) - end - return bp_cache -end - -map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) -function map_factors!(f, bp_cache, vs = vertices(bp_cache)) - for v in vs - setfactor!(bp_cache, v, f(factor(bp_cache, v))) - end - return bp_cache -end - -adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) -adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) - -abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end - -factor_type(bpc::AbstractBeliefPropagationCache) = 
factor_type(typeof(bpc)) -factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD - -message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) -message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED - -function logscalar(bpc::AbstractBeliefPropagationCache) - numerator_terms = vertex_scalars(bpc) - denominator_terms = edge_scalars(bpc) - - if any(t -> real(t) < 0, numerator_terms) - numerator_terms = complex.(numerator_terms) - end - if any(t -> real(t) < 0, denominator_terms) - denominator_terms = complex.(denominator_terms) - end - - if any(iszero, denominator_terms) - return -Inf - end - - return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) -end -scalar(bp_cache::AbstractBeliefPropagationCache) = exp(logscalar(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl deleted file mode 100644 index 83ed700..0000000 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ /dev/null @@ -1,171 +0,0 @@ -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, - set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type -using Dictionaries: Dictionary, delete!, getindices, set! -using Graphs: AbstractGraph, connected_components, is_directed, is_tree -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: IsDirected, default_root_vertex, directed_graph, - forest_cover, post_order_dfs_edges, undirected_graph, vertextype -using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph -using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices -using SimpleTraits: SimpleTraits, @traitfn, Not - -struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: - AbstractBeliefPropagationCache{V, VD, ED} - underlying_graph::G # we only use this for the edges. 
- factors::Dictionary{V, VD} - messages::Dictionary{E, ED} - function BeliefPropagationCache( - graph::AbstractGraph, - factors::Dictionary, - messages::Dictionary - ) - # Ensure the graph is directed and if not, make it directed. - digraph = directed_graph(graph) - - V = keytype(factors) - VD = eltype(factors) - - E = keytype(messages) - ED = eltype(messages) - - bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) - - return bpc - end -end - -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph - -function DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) - return haskey(bpc.factors, vertex) -end -DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) - -DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] -function DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) - return bpc.messages[edge] -end - -function DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) - return set!(bpc.factors, vertex, val) -end -function DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) - return set!(bpc.messages, edge, val) -end - -# These two methods assume `network` behaves llike a tensor network -# (could be e.g. a QuotientView) otherwise how would one know what the factors should be. 
-function BeliefPropagationCache(network::AbstractGraph) - graph = underlying_graph(network) - return BeliefPropagationCache(graph, copy(vertex_data(network))) -end -function BeliefPropagationCache(callable::Base.Callable, network::AbstractGraph) - graph = underlying_graph(network) - return BeliefPropagationCache(callable, graph, copy(vertex_data(network))) -end - -function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) - MT = eltype(factors) - return BeliefPropagationCache(MT, graph, factors) -end - -@traitfn function BeliefPropagationCache( - f::Function, - graph::AbstractGraph::!(IsDirected), - factors::Dictionary - ) - return BeliefPropagationCache(f, directed_graph(graph), factors) -end -@traitfn function BeliefPropagationCache( - f::Function, - graph::AbstractGraph::IsDirected, - factors::Dictionary - ) - messages = map(f, Indices(edges(graph))) - return BeliefPropagationCache(graph, factors, messages) -end - -function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) - messages = Dictionary{edgetype(graph), MT}() - return BeliefPropagationCache(graph, factors, messages) -end - -function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache( - copy(bp_cache.underlying_graph), - copy(bp_cache.factors), - copy(bp_cache.messages) - ) -end - -# TODO: This needs to go in GraphsExtensions -function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) - # All we care about are the edges so the type of the graph doesnt matter - g = NamedGraph(vertices(gi)) - add_edges!(g, edges(gi)) - forests = forest_cover(g) - rv = edgetype(g)[] - for forest in forests - trees = [forest[Vertices(vs)] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) 
- end - end - return rv -end - -function induced_subgraph_bpcache(graph, subvertices) - underlying_subgraph, vlist = - Graphs.induced_subgraph(underlying_graph(graph), subvertices) - - assigned = v -> isassigned(graph, v) - - assigned_subvertices = Iterators.filter(assigned, subvertices) - assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) - - factors = getindices(vertex_data(graph), Indices(assigned_subvertices)) - messages = getindices(edge_data(graph), Indices(assigned_subedges)) - - subgraph = BeliefPropagationCache(underlying_subgraph, factors, messages) - - return subgraph, vlist -end - -function NamedGraphs.induced_subgraph_from_vertices( - graph::BeliefPropagationCache, - subvertices - ) - return induced_subgraph_bpcache(graph, subvertices) -end - -## PartitionedGraphs - -function PartitionedGraphs.partitioned_vertices(bpc::BeliefPropagationCache) - return partitioned_vertices(bpc.underlying_graph) -end - -# Take a QuotientView of the underlying graph. -function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - graph = underlying_graph(bpc) - - quotient_view = QuotientView(graph) - - factors = map(v -> bpc[QuotientVertex(v)], Indices(vertices(quotient_view))) - messages = map(e -> bpc[QuotientEdge(e)], Indices(edges(quotient_view))) - - return BeliefPropagationCache(quotient_view, factors, messages) -end - -NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex -# When getting data according the quotient vertices, take a lazy contraction. 
-function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) - data = collect(map(v -> tn[v], vertices(tn, vertex))) - return mapreduce(lazy, *, data) -end -function DataGraphs.is_graph_index_assigned( - tn::BeliefPropagationCache, - vertex::QuotientVertex - ) - return isassigned(tn, Vertices(vertices(tn, vertex))) -end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index e4a1a00..729c10e 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -8,13 +8,12 @@ using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -@kwdef struct StopWhenConverged{Tol <: Real} <: AI.StoppingCriterion - tol::Tol = 0.0 +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 end -@kwdef mutable struct StopWhenConvergedState{Iterate, Delta <: Real} <: - AI.StoppingCriterionState - delta::Delta = Inf +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf at_iteration::Int = -1 previous_iterate::Iterate end @@ -40,8 +39,6 @@ function AI.is_finished!( c::StopWhenConverged, st::StopWhenConvergedState ) - - # maxdiff = 0.0 initially, so skip this the first time. iterate = state.iterate previous_iterate = st.previous_iterate @@ -49,6 +46,7 @@ function AI.is_finished!( st.previous_iterate = copy(iterate) + # maxdiff = 0.0 initially, so skip this the first time. 
state.iteration == 0 && return false st.delta = delta @@ -71,13 +69,13 @@ function AI.is_finished( return st.delta < c.tol end -struct BeliefPropagationProblem{Network} <: AIE.Problem - network::Network +struct BeliefPropagationProblem{Factors} <: AIE.Problem + factors::Factors end function iterate_diff( - cache1::AbstractBeliefPropagationCache, - cache2::AbstractBeliefPropagationCache + cache1::MessageCache, + cache2::MessageCache ) return maximum(edges(cache1)) do edge m1 = cache1[edge] @@ -145,7 +143,7 @@ function AIE.set_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, state::AIE.DefaultState, - cache::AbstractBeliefPropagationCache + cache::MessageCache ) state.iterate = cache @@ -153,18 +151,19 @@ function AIE.set_substate!( end function AI.solve!( - ::BeliefPropagationProblem, + problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, - cache::AbstractBeliefPropagationCache; kwargs... + cache::MessageCache; + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) ) edge = algorithm.edge vertex = src(edge) - messages = incoming_messages(cache, vertex; ignore_edges = [reverse(edge)]) - tensors = vcat([factor(cache, vertex)], messages) + messages = incoming_messages(cache, vertex; ignore_edges = [reverse(edge)]) + factors = vcat([factor(problem.factors, vertex)], messages) - new_message = contract_network(tensors; algorithm.contraction_alg) + new_message = contract_network(factors; algorithm.contraction_alg) if algorithm.normalize message_norm = sum(new_message) @@ -178,32 +177,22 @@ function AI.solve!( return cache end -function beliefpropagation(network; kwargs...) - return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +function beliefpropagation(network::AbstractGraph, messages::Dictionary; kwargs...) + cache = MessageCache(messages, network) + return beliefpropagation(network, cache; kwargs...) 
end -function beliefpropagation( - cache::AbstractBeliefPropagationCache, - network = nothing; - kwargs... - ) +function beliefpropagation(network, cache; kwargs...) problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) state = AI.solve(problem, algorithm; iterate = cache) - return state.iterate + return state.iterate # -> typeof(cache) end -function select_algorithm( - ::typeof(beliefpropagation), - cache::AbstractBeliefPropagationCache; - edges = forest_cover_edge_sequence(cache), - maxiter = is_tree(cache) ? 1 : nothing, - tol = nothing, - kwargs... - ) +function default_stopping_criterion(::typeof(beliefpropagation); maxiter, tol) if isnothing(maxiter) throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end @@ -214,6 +203,18 @@ function select_algorithm( stopping_criterion = stopping_criterion | StopWhenConverged(tol) end + return stopping_criterion +end + +function select_algorithm( + alg::typeof(beliefpropagation), + cache::MessageCache; + edges = forest_cover_edge_sequence(cache), + maxiter = is_tree(cache) ? 1 : nothing, + tol = nothing, + stopping_criterion = default_stopping_criterion(alg; maxiter, tol), + kwargs... + ) extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl new file mode 100644 index 0000000..ec2e1f1 --- /dev/null +++ b/src/beliefpropagation/messagecache.jl @@ -0,0 +1,247 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type +using Dictionaries: Dictionary, delete!, getindices, set! 
+using Graphs: AbstractGraph, connected_components, is_directed, is_tree +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype +using NamedGraphs.GraphsExtensions: IsDirected, default_root_vertex, directed_graph, + forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph +using NamedGraphs: + NamedDiGraph, Vertices, convert_vertextype, parent_graph_indices, to_graph_index + +struct MessageCache{MT, V, E} <: AbstractDataGraph{V, Nothing, MT} + messages::Dictionary{E, MT} + underlying_graph::NamedDiGraph{V} + global function _MessageCache( + messages::Dictionary{E, MT}, + underlying_graph::NamedDiGraph{V} + ) where {MT, V, E} + return new{MT, V, E}(messages, underlying_graph) + end +end + +DataGraphs.underlying_graph(c::MessageCache) = c.underlying_graph + +DataGraphs.is_vertex_assigned(::MessageCache, _) = false +DataGraphs.is_edge_assigned(c::MessageCache, edge) = haskey(c.messages, edge) + +function DataGraphs.get_edge_data(c::MessageCache, edge::AbstractEdge) + return c.messages[edge] +end +function DataGraphs.set_edge_data!(c::MessageCache, val, edge) + return set!(c.messages, edge, val) +end + +# Utility function for constructing a directed graph with existing edges + all reverses. 
+function _message_cache_underlying_graph(graph::AbstractGraph) + digraph = similar_graph(NamedDiGraph, vertices(graph)) + for edge in edges(graph) + add_edge!(digraph, edge) + if !is_directed(graph) + add_edge!(digraph, reverse(edge)) + end + end + return digraph +end + +MessageCache(::UndefInitializer, graph::AbstractGraph) = MessageCache{Any}(undef, graph) + +function MessageCache{ED}(::UndefInitializer, graph::AbstractGraph) where {ED} + messages = Dictionary{edgetype(graph), ED}() + return MessageCache(messages, graph) +end + +function MessageCache(f::Function, graph::AbstractGraph) + digraph = _message_cache_underlying_graph(graph) + messages = map(f, Indices(edges(digraph))) + return MessageCache(messages, digraph) +end + +function MessageCache(messages, graph::AbstractGraph) + digraph = _message_cache_underlying_graph(graph) + return _MessageCache(Dictionary(messages), digraph) # Call the inner constructor. +end + +function Base.copy(cache::MessageCache) + return MessageCache(copy(cache.messages), copy(cache.underlying_graph)) +end + +function Base.:(==)(cache1::MessageCache, cache2::MessageCache) + if cache1.underlying_graph != cache2.underlying_graph + return false + elseif cache1.messages != cache2.messages + return false + end + return true +end + +function NamedGraphs.induced_subgraph_from_vertices(cache::MessageCache, subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(cache.underlying_graph, subvertices) + + assigned = v -> isassigned(cache, v) + + assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) + + messages = getindices(edge_data(cache), Indices(assigned_subedges)) + + subgraph = MessageCache(messages, underlying_subgraph) + + return subgraph, vlist +end + +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = NamedGraph(vertices(gi)) + 
add_edges!(g, edges(gi)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return rv +end + +# =============================== message/factor interface =============================== # + +message_type(::Type) = not_implemented() +message_type(cache) = message_type(typeof(cache)) +message_type(T::Type{<:MessageCache}) = edge_data_type(T) + +factor(_factors, _vertex) = not_implemented() +factor(factors::AbstractGraph, vertex) = factors[vertex] + +function factors(all_factors, vertices) + return map(vertex -> factor(all_factors, vertex), vertices) +end + +# Specific for graphs +function factors(all_factors::AbstractGraph) + return map(vertex -> factor(all_factors, vertex), vertices(all_factors)) +end + +message(_messages, _edge) = not_implemented() +message(messages::AbstractGraph, edge) = messages[edge] + +function messages(all_messages, edges) + return map(edge -> message(all_messages, edge), edges) +end + +# Specific for graphs +function messages(all_messages::AbstractGraph) + return map(edge -> message(all_messages, edge), edges(all_messages)) +end + +# Specific to the concrete type. 
+messages(cache::MessageCache) = cache.messages + +function incoming_messages(cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(cache, [vertices;]; dir = :in) + if !isempty(ignore_edges) + b_edges = setdiff(b_edges, to_graph_index(cache, ignore_edges)) + end + return messages(cache, b_edges) +end + +function setmessage!(cache::AbstractGraph, edge, message) + cache[edge] = message + return cache +end +function setmessages!(cache::AbstractGraph, messages) + for (key, val) in messages + setmessage!(cache, key, val) + end + return cache +end +function setmessages!(cache_dst::AbstractGraph, cache_src::AbstractGraph, edges) + for e in edges + setmessage!(cache_dst, e, message(cache_src, e)) + end + return cache_dst +end + +# =================================== adapt interface ==================================== # + +map_messages(f, cache, es = edges(cache)) = map_messages!(f, copy(cache), es) +function map_messages!(f, cache, es = edges(cache)) + for e in es + setmessage!(cache, e, f(message(cache, e))) + end + return cache +end + +adapt_messages(to, cache, es = edges(cache)) = map_messages(adapt(to), cache, es) + +# ===================================== contraction ====================================== # + +function vertex_scalar(factors, messages, vertex; kwargs...) + in_messages = incoming_messages(messages, vertex) + state = [factor(factors, vertex)] + return contract_network(vcat(in_messages, state); kwargs...)[] +end + +vertex_scalars(factors, messages) = vertex_scalars(factors, messages, keys(factors)) +function vertex_scalars(factors::AbstractGraph, messages) + return vertex_scalars(factors, messages, vertices(factors)) +end +function vertex_scalars(factors, messages, vertices) + return map(v -> vertex_scalar(factors, messages, v), vertices) +end + +function edge_scalar(cache, edge; kwargs...) 
+ m1s = messages(cache, [edge]) + m2s = messages(cache, [reverse(edge)]) + return contract_network(vcat(m1s, m2s); kwargs...)[] +end + +edge_scalars(cache) = edge_scalars(cache, keys(cache)) +edge_scalars(cache::AbstractGraph) = edge_scalars(cache, edges(cache)) + +function edge_scalars(cache, edges) + processed = Set{eltype(edges)}() + + T = Base.promote_op(edge_scalar, typeof(cache), eltype(edges)) + + scalars = T[] + + # Ignore repeated edges and their reverses. + for e in edges + if e in processed || reverse(e) in processed + continue + end + push!(processed, e) + push!(scalars, edge_scalar(cache, e)) + end + + return scalars +end + +function region_scalar(factors, messages, region) + return mapreduce(vertex -> vertex_scalar(factors, messages, vertex), *, region) +end + +# We need a graph structure here, so assume `factors` is a graph. +function logscalar(factors, messages) + numerator_terms = vertex_scalars(factors, messages) + denominator_terms = edge_scalars(messages) + + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + if any(iszero, denominator_terms) + return -Inf + end + + return sum(log.(numerator_terms)) - sum(log.(denominator_terms)) +end + +scalar(factors, messages) = exp(logscalar(factors, messages)) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 244d780..13b9af2 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,12 +1,12 @@ +using DataGraphs: edge_data using DiagonalArrays: δ using Dictionaries: Dictionary, set! 
using Graphs: AbstractGraph, dst, edges, has_edge, src, vertices using ITensorBase: ITensor, Index, noprime, prime -using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, - adapt_factors, adapt_messages, edge_scalar, factor, factor_type, factors, - incoming_messages, linkinds, map_factors, map_messages, message, message_type, messages, - region_scalar, scalar, setfactor!, setmessage!, setmessages!, subgraph, vertex_scalar, - vertex_scalars +using ITensorNetworksNext: ITensorNetworksNext, MessageCache, TensorNetwork, adapt_messages, + edge_scalar, factor, factors, incoming_messages, linkinds, map_messages, message, + message_type, messages, region_scalar, scalar, setmessage!, setmessages!, subgraph, + vertex_scalar, vertex_scalars using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype @@ -37,8 +37,8 @@ function spin_ice_tensornetwork(g) return TensorNetwork(g, ts) end -@testset "BeliefPropagation" begin - @testset "`BeliefPropagationCache`" begin +@testset "Belief propagation" begin + @testset "`MessageCache`" begin @testset "Basics" begin dims = (3, 3) g = named_grid(dims) @@ -50,24 +50,19 @@ end return randn(Tuple(is)) end - bpc = BeliefPropagationCache(tn) do edge + # By default for graphs, assume factors refers to the vertex data + @test length(factors(tn)) == 9 + @test factor(tn, (1, 1)) == tn[(1, 1)] + + bpc = MessageCache(tn) do edge return "$(src(edge)) => $(dst(edge))" end - @test factor_type(bpc) <: ITensor @test message_type(bpc) <: String - @test length(factors(bpc)) == 9 @test length(messages(bpc)) == 2 * length(edges(g)) - @test bpc[(2, 2)] == tn[(2, 2)] - @test factor(bpc, (1, 1)) == tn[(1, 1)] @test bpc[(1, 1) => (1, 2)] == "(1, 1) => (1, 2)" @test message(bpc, (2, 1) => (1, 1)) == "(2, 1) => (1, 1)" - # set factor - f = factor(bpc, (1, 1)) - setfactor!(bpc, (1, 1), 2 * f) - @test factor(bpc, (1, 1)) == 2 * f - # set message 
setmessage!(bpc, (1, 1) => (1, 2), "new message") @test message(bpc, (1, 1) => (1, 2)) == "new message" @@ -77,7 +72,7 @@ end @test message(bpc, (1, 2) => (2, 2)) == "m1" @test message(bpc, (2, 2) => (2, 3)) == "m2" - bpc_dst = BeliefPropagationCache(tn) do edge + bpc_dst = MessageCache(tn) do edge return "" end setmessages!(bpc_dst, bpc, [(1, 2) => (2, 2), (2, 2) => (2, 3)]) @@ -94,18 +89,16 @@ end return randn(ComplexF32, Tuple(is)) end - bpc = BeliefPropagationCache(tn) do edge + bpc = MessageCache(tn) do edge return ones(Float64, Tuple(linkinds(tn, edge))) end # Vertex/edge/region scalars. - @test vertex_scalar(bpc, 2) isa ComplexF64 + @test vertex_scalar(tn, bpc, 2) isa ComplexF64 @test edge_scalar(bpc, 1 => 2) isa Float64 - @test region_scalar(bpc, [1]) == vertex_scalar(bpc, 1) - @test region_scalar(bpc, [1 => 2]) == edge_scalar(bpc, 1 => 2) - @test region_scalar(bpc, [2 => 1]) == edge_scalar(bpc, 1 => 2) - @test region_scalar(bpc, [1, 2, 3]) == prod(vertex_scalars(bpc)) + @test region_scalar(tn, bpc, [1]) == vertex_scalar(tn, bpc, 1) + @test region_scalar(tn, bpc, [2, 3]) == prod(vertex_scalars(tn, bpc, [2, 3])) # `incoming_messages` excludes specified edges. in_msgs = incoming_messages(bpc, 2) @@ -126,19 +119,7 @@ end @test message(bpc_doubled, 1 => 2) ≈ 2 .* message(bpc, 1 => 2) @test message(bpc_doubled, 2 => 3) ≈ 2 .* message(bpc, 2 => 3) - bpc_again = map_factors(identity, bpc) - @test bpc_again !== bpc - @test bpc_again == bpc - - bpc_scaled = map_factors(f -> f .* 2, bpc) - @test !(bpc_scaled === bpc) - for vv in vertices(bpc_scaled) - @test factor(bpc_scaled, vv) ≈ factor(bpc, vv) .* 2 - end - - # `adapt_factors` and `adapt_messages` should at least be callable. 
- @test adapt_factors(identity, bpc) isa BeliefPropagationCache - @test adapt_messages(identity, bpc) isa BeliefPropagationCache + @test adapt_messages(identity, bpc) == bpc end @testset "subgraph" begin @@ -149,13 +130,13 @@ end is = map(e -> l[e], incident_edges(g, v)) return randn(Tuple(is)) end - bpc = BeliefPropagationCache(tn) do edge + bpc = MessageCache(tn) do edge return ones(Tuple(linkinds(tn, edge))) end sub_vs = [(1,), (2,)] subbpc = subgraph(bpc, sub_vs) - @test subbpc isa BeliefPropagationCache + @test subbpc isa MessageCache @test issetequal(vertices(subbpc), sub_vs) @test has_edge(subbpc, (1,) => (2,)) end @@ -168,10 +149,10 @@ end return randn(Tuple(is)) end - bpc1 = BeliefPropagationCache(tn) do edge + bpc1 = MessageCache(tn) do edge return ones(Tuple(linkinds(tn, edge))) end - bpc2 = BeliefPropagationCache(tn) do edge + bpc2 = MessageCache(tn) do edge return ones(Tuple(linkinds(tn, edge))) end @@ -193,11 +174,11 @@ end return randn(T, Tuple(is)) end - bpc = BeliefPropagationCache(tn) do edge + bpc = MessageCache(tn) do edge return ones(T, Tuple(linkinds(tn, edge))) end - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = scalar(bpc) + bpc = ITensorNetworksNext.beliefpropagation(tn, bpc; maxiter = 1) + z_bp = scalar(tn, bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -211,11 +192,11 @@ end return randn(T, Tuple(is)) end - bpc = BeliefPropagationCache(tn) do edge + bpc = MessageCache(tn) do edge return ones(T, Tuple(linkinds(tn, edge))) end - bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) - z_bp = scalar(bpc) + bpc = ITensorNetworksNext.beliefpropagation(tn, bpc; maxiter = 1) + z_bp = scalar(tn, bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -226,18 +207,19 @@ end g = named_grid(dims; periodic = true) tn = spin_ice_tensornetwork(g) - bpc = ITensorNetworksNext.BeliefPropagationCache(tn) do edge + bpc = ITensorNetworksNext.MessageCache(tn) do edge # 
Use `rand` so messages have positive elements. return rand(T, Tuple(linkinds(tn, edge))) end bpc = ITensorNetworksNext.beliefpropagation( + tn, bpc; tol = 1.0e-10, maxiter = 10 ) - z_bp = scalar(bpc) + z_bp = scalar(tn, bpc) @test z_bp ≈ 1.5^(n^2) end From 2c8b45057b630341842cbb72f13d65bed836c5ac Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 1 May 2026 10:31:35 -0400 Subject: [PATCH 77/86] Allow a custom stopping criteria input into `beliefpropagation` using the `stopping_criterion` kwarg. --- .../beliefpropagationproblem.jl | 44 ++++++++++++++----- test/test_beliefpropagation.jl | 16 ++++--- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 729c10e..d09ee3c 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -185,36 +185,58 @@ end function beliefpropagation(network, cache; kwargs...) problem = BeliefPropagationProblem(network) - algorithm = select_algorithm(beliefpropagation, cache; kwargs...) + algorithm = select_algorithm(beliefpropagation, network, cache; kwargs...) 
state = AI.solve(problem, algorithm; iterate = cache) return state.iterate # -> typeof(cache) end -function default_stopping_criterion(::typeof(beliefpropagation); maxiter, tol) - if isnothing(maxiter) - throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) - end +function default_stopping_criterion( + ::typeof(beliefpropagation); + maxiter, + tol, + stopping_criterion + ) + base_stopping_criterion = AI.StopAfterIteration(maxiter) - stopping_criterion = AI.StopAfterIteration(maxiter) + if !isnothing(stopping_criterion) + base_stopping_criterion |= stopping_criterion + end if !isnothing(tol) - stopping_criterion = stopping_criterion | StopWhenConverged(tol) + base_stopping_criterion |= StopWhenConverged(tol) end - return stopping_criterion + return base_stopping_criterion end function select_algorithm( - alg::typeof(beliefpropagation), + ::typeof(beliefpropagation), + network::AbstractGraph, cache::MessageCache; edges = forest_cover_edge_sequence(cache), - maxiter = is_tree(cache) ? 1 : nothing, + maxiter = is_tree(network) ? 1 : nothing, tol = nothing, - stopping_criterion = default_stopping_criterion(alg; maxiter, tol), + stopping_criterion = nothing, kwargs... ) + if isnothing(maxiter) + throw( + ArgumentError( + "`maxiter` must be specified for non-tree graphs, even when + `stopping_criterion` is provided." + ) + ) + end + + stopping_criterion = default_stopping_criterion( + beliefpropagation; + maxiter, + tol, + stopping_criterion + ) + extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 13b9af2..471add5 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,12 +1,13 @@ +import AlgorithmsInterface as AI using DataGraphs: edge_data using DiagonalArrays: δ using Dictionaries: Dictionary, set! 
using Graphs: AbstractGraph, dst, edges, has_edge, src, vertices using ITensorBase: ITensor, Index, noprime, prime -using ITensorNetworksNext: ITensorNetworksNext, MessageCache, TensorNetwork, adapt_messages, - edge_scalar, factor, factors, incoming_messages, linkinds, map_messages, message, - message_type, messages, region_scalar, scalar, setmessage!, setmessages!, subgraph, - vertex_scalar, vertex_scalars +using ITensorNetworksNext: ITensorNetworksNext, MessageCache, StopWhenConverged, + TensorNetwork, adapt_messages, edge_scalar, factor, factors, incoming_messages, + linkinds, map_messages, message, message_type, messages, region_scalar, scalar, + setmessage!, setmessages!, subgraph, vertex_scalar, vertex_scalars using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype @@ -211,12 +212,15 @@ end # Use `rand` so messages have positive elements. return rand(T, Tuple(linkinds(tn, edge))) end + + stopping_criterion = StopWhenConverged(tol = 1.0e-10) + bpc = ITensorNetworksNext.beliefpropagation( tn, bpc; - tol = 1.0e-10, - maxiter = 10 + maxiter = 10, + stopping_criterion ) z_bp = scalar(tn, bpc) From bc22b67b2831c660acaec7b256b6bbbba856d968 Mon Sep 17 00:00:00 2001 From: Jack Dunham <72548217+jack-dunham@users.noreply.github.com> Date: Fri, 1 May 2026 09:54:16 -0400 Subject: [PATCH 78/86] Hard code edge type in `MessageCache`. 
Co-authored-by: Matthew Fishman --- src/beliefpropagation/messagecache.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index ec2e1f1..fc8e3c6 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -9,8 +9,8 @@ using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph using NamedGraphs: NamedDiGraph, Vertices, convert_vertextype, parent_graph_indices, to_graph_index -struct MessageCache{MT, V, E} <: AbstractDataGraph{V, Nothing, MT} - messages::Dictionary{E, MT} +struct MessageCache{MT, V} <: AbstractDataGraph{V, Nothing, MT} + messages::Dictionary{NamedEdge{V}, MT} underlying_graph::NamedDiGraph{V} global function _MessageCache( messages::Dictionary{E, MT}, From 115dcff8af44a8c3c93f7d61834c6d505480ed23 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 1 May 2026 10:36:48 -0400 Subject: [PATCH 79/86] Remove `MessageCache` undef initializer. 
--- src/beliefpropagation/messagecache.jl | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index 07fb748..6f46116 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -12,12 +12,6 @@ using NamedGraphs: struct MessageCache{MT, V} <: AbstractDataGraph{V, Nothing, MT} messages::Dictionary{NamedEdge{V}, MT} underlying_graph::NamedDiGraph{V} - global function _MessageCache( - messages::Dictionary{E, MT}, - underlying_graph::NamedDiGraph{V} - ) where {MT, V, E} - return new{MT, V, E}(messages, underlying_graph) - end end DataGraphs.underlying_graph(c::MessageCache) = c.underlying_graph @@ -44,13 +38,6 @@ function _message_cache_underlying_graph(graph::AbstractGraph) return digraph end -MessageCache(::UndefInitializer, graph::AbstractGraph) = MessageCache{Any}(undef, graph) - -function MessageCache{ED}(::UndefInitializer, graph::AbstractGraph) where {ED} - messages = Dictionary{edgetype(graph), ED}() - return MessageCache(messages, graph) -end - function MessageCache(f::Function, graph::AbstractGraph) digraph = _message_cache_underlying_graph(graph) messages = map(f, Indices(edges(digraph))) @@ -59,7 +46,7 @@ end function MessageCache(messages, graph::AbstractGraph) digraph = _message_cache_underlying_graph(graph) - return _MessageCache(Dictionary(messages), digraph) # Call the inner constructor. + return MessageCache(Dictionary(messages), digraph) # Call the inner constructor. end function Base.copy(cache::MessageCache) From 1deba51dccc3bc38cc43ea31e91df5db0c94a99e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 1 May 2026 10:54:53 -0400 Subject: [PATCH 80/86] Rename argument names to be more consistent. 
--- src/beliefpropagation/messagecache.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index 6f46116..e7da293 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -100,25 +100,25 @@ message_type(::Type) = not_implemented() message_type(cache) = message_type(typeof(cache)) message_type(T::Type{<:MessageCache}) = edge_data_type(T) -factor(_factors, _vertex) = not_implemented() -factor(factors::AbstractGraph, vertex) = factors[vertex] +factor(_cache, _vertex) = not_implemented() +factor(cache::AbstractGraph, vertex) = cache[vertex] -function factors(all_factors, vertices) - return map(vertex -> factor(all_factors, vertex), vertices) +function factors(cache, vertices) + return map(vertex -> factor(cache, vertex), vertices) end # Specific for graphs -factors(all_factors::AbstractGraph) = factors(all_factors, vertices(all_factors)) +factors(cache::AbstractGraph) = factors(cache, vertices(cache)) -message(_messages, _edge) = not_implemented() -message(messages::AbstractGraph, edge) = messages[edge] +message(_cache, _edge) = not_implemented() +message(cache::AbstractGraph, edge) = cache[edge] -function messages(all_messages, edges) - return map(edge -> message(all_messages, edge), edges) +function messages(cache, edges) + return map(edge -> message(cache, edge), edges) end # Specific for graphs -messages(all_messages::AbstractGraph) = messages(all_messages, edges(all_messages)) +messages(cache::AbstractGraph) = messages(cache, edges(cache)) # Specific to the concrete type. messages(cache::MessageCache) = cache.messages From 80bbc99bc7096c01b1f436649852f3af9cbfb1ce Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 4 May 2026 16:24:09 -0400 Subject: [PATCH 81/86] Simplify `MessageCache` interface. 
--- .../beliefpropagationproblem.jl | 25 +- src/beliefpropagation/messagecache.jl | 279 ++++++++++-------- test/test_beliefpropagation.jl | 145 +++++---- 3 files changed, 239 insertions(+), 210 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index d09ee3c..dbe91e2 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -158,12 +158,10 @@ function AI.solve!( ) edge = algorithm.edge - vertex = src(edge) + messages = collect(incoming_messages(cache, edge)) + factor = problem.factors[src(edge)] - messages = incoming_messages(cache, vertex; ignore_edges = [reverse(edge)]) - factors = vcat([factor(problem.factors, vertex)], messages) - - new_message = contract_network(factors; algorithm.contraction_alg) + new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) if algorithm.normalize message_norm = sum(new_message) @@ -172,26 +170,27 @@ function AI.solve!( end end - setmessage!(cache, edge, new_message) + cache[edge] = new_message return cache end -function beliefpropagation(network::AbstractGraph, messages::Dictionary; kwargs...) - cache = MessageCache(messages, network) - return beliefpropagation(network, cache; kwargs...) -end +function beliefpropagation(factors, messages; kwargs...) + problem = BeliefPropagationProblem(factors) -function beliefpropagation(network, cache; kwargs...) - problem = BeliefPropagationProblem(network) + cache = initialize_cache(beliefpropagation, factors, messages) - algorithm = select_algorithm(beliefpropagation, network, cache; kwargs...) + algorithm = select_algorithm(beliefpropagation, factors, cache; kwargs...) state = AI.solve(problem, algorithm; iterate = cache) return state.iterate # -> typeof(cache) end +# Use a `MessageCache` by default. 
Note if `messages` is already a `MessageCache` this +# will make a copy of the the existing cache, thus protecting the original from mutation. +initialize_cache(::typeof(beliefpropagation), _factors, messages) = MessageCache(messages) + function default_stopping_criterion( ::typeof(beliefpropagation); maxiter, diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index e7da293..dc20a66 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -3,18 +3,72 @@ using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, using Dictionaries: Dictionary, delete!, getindices, set! using Graphs: AbstractGraph, connected_components, is_directed, is_tree using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: IsDirected, default_root_vertex, directed_graph, - forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.GraphsExtensions: IsDirected, boundary_edges, default_root_vertex, + directed_graph, forest_cover, in_incident_edges, post_order_dfs_edges, undirected_graph, + vertextype using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph -using NamedGraphs: - NamedDiGraph, Vertices, convert_vertextype, parent_graph_indices, to_graph_index +using NamedGraphs: NamedDiGraph, Vertices, convert_vertextype, ordered_vertices, + parent_graph_indices, position_graph, to_graph_index, vertex_positions -struct MessageCache{MT, V} <: AbstractDataGraph{V, Nothing, MT} - messages::Dictionary{NamedEdge{V}, MT} +struct MessageCache{T, V} <: AbstractDataGraph{V, Nothing, T} + messages::Dictionary{NamedEdge{V}, T} underlying_graph::NamedDiGraph{V} + function MessageCache{T, V}(::UndefInitializer, vertices) where {T, V} + messages = Dictionary{NamedEdge{V}, T}() + underlying_graph = NamedDiGraph{V}(vertices) + return new{T, V}(messages, underlying_graph) + end +end + +# single type 
parameter version of the inner constructor +function MessageCache{T}(::UndefInitializer, vertices) where {T} + return MessageCache{T, eltype(vertices)}(undef, vertices) +end + +# compatibility with generic key-val iterables +Base.keytype(c::MessageCache) = keytype(typeof(c)) +Base.keytype(::Type{<:MessageCache{T, V}}) where {T, V} = NamedEdge{V} + +Base.valtype(c::MessageCache) = valtype(typeof(c)) +Base.valtype(::Type{<:MessageCache{T}}) where {T} = T + +Base.keys(cache::MessageCache) = edges(cache) + +MessageCache(messages) = MessageCache{valtype(messages)}(messages) + +function MessageCache{T}(messages) where {T} + V = vertextype(keytype(messages)) + return MessageCache{T, V}(messages) +end + +# `messages` is any iterable data structure, where `keys(messages)` are edges +# and the values are the messages on those edges. +function MessageCache{T, V}(messages) where {T, V} + edges = keys(messages) + vertices = union(src.(edges), dst.(edges)) + cache = MessageCache{T, V}(undef, vertices) + add_edges!(cache.underlying_graph, edges) + copyto!(cache, messages) + return cache +end + +messagecache(pairs) = MessageCache(Dict(pairs)) + +# ================================ NamedGraphs interface ================================= # +function NamedGraphs.add_edge!(c::MessageCache, edge) + add_edge!(c.underlying_graph, edge) + return c +end + +function NamedGraphs.rem_edge!(c::MessageCache, edge) + delete!(c.messages, to_graph_index(c, edge)) + rem_edge!(c.underlying_graph, edge) + return c end -DataGraphs.underlying_graph(c::MessageCache) = c.underlying_graph +# ================================= DataGraphs interface ================================= # + +DataGraphs.underlying_graph(cache::MessageCache) = cache.underlying_graph DataGraphs.is_vertex_assigned(::MessageCache, _) = false DataGraphs.is_edge_assigned(c::MessageCache, edge) = haskey(c.messages, edge) @@ -26,43 +80,20 @@ function DataGraphs.set_edge_data!(c::MessageCache, val, edge) return set!(c.messages, edge, val) 
end -# Utility function for constructing a directed graph with existing edges + all reverses. -function _message_cache_underlying_graph(graph::AbstractGraph) - digraph = similar_graph(NamedDiGraph, vertices(graph)) - for edge in edges(graph) - add_edge!(digraph, edge) - if !is_directed(graph) - add_edge!(digraph, reverse(edge)) - end - end - return digraph -end - -function MessageCache(f::Function, graph::AbstractGraph) - digraph = _message_cache_underlying_graph(graph) - messages = map(f, Indices(edges(digraph))) - return MessageCache(messages, digraph) -end +Base.copy(cache::MessageCache) = MessageCache(copy(cache.messages)) -function MessageCache(messages, graph::AbstractGraph) - digraph = _message_cache_underlying_graph(graph) - return MessageCache(Dictionary(messages), digraph) # Call the inner constructor. -end +function Base.:(==)(cache1::MessageCache, cache2::MessageCache) + ug1 = cache1.underlying_graph + ug2 = cache2.underlying_graph -function Base.copy(cache::MessageCache) - return MessageCache(copy(cache.messages), copy(cache.underlying_graph)) -end + ms1 = cache1.messages + ms2 = cache2.messages -function Base.:(==)(cache1::MessageCache, cache2::MessageCache) - if cache1.underlying_graph != cache2.underlying_graph - return false - elseif cache1.messages != cache2.messages - return false - end - return true + return (ug1 == ug2 && ms1 == ms2) end function NamedGraphs.induced_subgraph_from_vertices(cache::MessageCache, subvertices) + # TODO: once we have `subgraph_edges` in `NamedGraphs`, simplify this. 
underlying_subgraph, vlist = Graphs.induced_subgraph(cache.underlying_graph, subvertices) @@ -70,102 +101,74 @@ function NamedGraphs.induced_subgraph_from_vertices(cache::MessageCache, subvert assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) - messages = getindices(edge_data(cache), Indices(assigned_subedges)) + messages = getindices(cache.messages, Indices(assigned_subedges)) - subgraph = MessageCache(messages, underlying_subgraph) - - return subgraph, vlist -end - -# TODO: This needs to go in GraphsExtensions -function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) - # All we care about are the edges so the type of the graph doesnt matter - g = NamedGraph(vertices(gi)) - add_edges!(g, edges(gi)) - forests = forest_cover(g) - rv = edgetype(g)[] - for forest in forests - trees = [forest[Vertices(vs)] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) - end - end - return rv + return MessageCache(messages), vlist end -# =============================== message/factor interface =============================== # - -message_type(::Type) = not_implemented() -message_type(cache) = message_type(typeof(cache)) -message_type(T::Type{<:MessageCache}) = edge_data_type(T) - -factor(_cache, _vertex) = not_implemented() -factor(cache::AbstractGraph, vertex) = cache[vertex] - -function factors(cache, vertices) - return map(vertex -> factor(cache, vertex), vertices) +# see: copyto!(dest, src) for analogous behaviour to 2 argument method +# see: copyto!(dest, Rdest::CartesianIndices, src, Rsrc::CartesianIndices) +# for analogous behaviour to 3 argument method. +# TODO: these can be made generic for `AbtractDataGraph` in `DataGraphs.jl` +function copyto!_messagecache( + cache_dst::MessageCache, + cache_src, + inds = nothing + ) + inds = isnothing(inds) ? 
Indices(keys(cache_src)) : Indices(inds) + view(edge_data(cache_dst), inds) .= view(cache_src, inds) + return cache_dst end -# Specific for graphs -factors(cache::AbstractGraph) = factors(cache, vertices(cache)) - -message(_cache, _edge) = not_implemented() -message(cache::AbstractGraph, edge) = cache[edge] - -function messages(cache, edges) - return map(edge -> message(cache, edge), edges) +function Base.copyto!( + cache_dst::MessageCache, + cache_src::AbstractDataGraph, + inds = nothing + ) + copyto!_messagecache(cache_dst, edge_data(cache_src), inds) + return cache_dst end -# Specific for graphs -messages(cache::AbstractGraph) = messages(cache, edges(cache)) - -# Specific to the concrete type. -messages(cache::MessageCache) = cache.messages - -function incoming_messages(cache::AbstractGraph, vertices; ignore_edges = []) - b_edges = boundary_edges(cache, [vertices;]; dir = :in) - if !isempty(ignore_edges) - b_edges = setdiff(b_edges, to_graph_index(cache, ignore_edges)) - end - return messages(cache, b_edges) +function Base.copyto!( + cache_dst::MessageCache, + dictionary_src::Dictionary, + inds = nothing + ) + copyto!_messagecache(cache_dst, dictionary_src, inds) + return cache_dst end -function setmessage!(cache::AbstractGraph, edge, message) - cache[edge] = message - return cache -end -function setmessages!(cache::AbstractGraph, messages) - for (key, val) in messages - setmessage!(cache, key, val) - end - return cache -end -function setmessages!(cache_dst::AbstractGraph, cache_src::AbstractGraph, edges) - for e in edges - setmessage!(cache_dst, e, message(cache_src, e)) +function Base.copyto!( + cache_dst::MessageCache, + dict_src::Dict, + inds = keys(dict_src) + ) + for key in inds + cache_dst[key] = dict_src[key] end return cache_dst end -# =================================== adapt interface ==================================== # +# ===================================== contraction ====================================== # -map_messages(f, cache, es = 
edges(cache)) = map_messages!(f, copy(cache), es) -function map_messages!(f, cache, es = edges(cache)) - for e in es - setmessage!(cache, e, f(message(cache, e))) - end - return cache +function incoming_messages(cache::AbstractGraph, pair::Pair) + edge = to_graph_index(cache, pair) + return incoming_messages(cache, edge) +end +function incoming_messages(cache::AbstractGraph, edge::AbstractEdge) + inds = Indices(in_incident_edges(cache, src(edge))) + return getindices(cache, filter(e -> e != reverse(edge), inds)) end -adapt_messages(to, cache, es = edges(cache)) = map_messages(adapt(to), cache, es) - -# ===================================== contraction ====================================== # +function environment_messages(cache::AbstractGraph, vertices) + inds = Indices(boundary_edges(cache, vertices; dir = :in)) + return getindices(cache, inds) +end function vertex_scalar(factors, messages, vertex; kwargs...) - in_messages = incoming_messages(messages, vertex) - state = [factor(factors, vertex)] - return contract_network(vcat(in_messages, state); kwargs...)[] + in_messages = environment_messages(messages, [vertex]) + tensors = vcat([factors[vertex]], collect(in_messages)) + return contract_network(tensors; kwargs...)[] end vertex_scalars(factors, messages) = vertex_scalars(factors, messages, keys(factors)) @@ -177,13 +180,12 @@ function vertex_scalars(factors, messages, vertices) end function edge_scalar(cache, edge; kwargs...) 
- m1s = messages(cache, [edge]) - m2s = messages(cache, [reverse(edge)]) - return contract_network(vcat(m1s, m2s); kwargs...)[] + m1 = cache[edge] + m2 = cache[reverse(edge)] + return contract_network([m1, m2]; kwargs...)[] end edge_scalars(cache) = edge_scalars(cache, keys(cache)) -edge_scalars(cache::AbstractGraph) = edge_scalars(cache, edges(cache)) function edge_scalars(cache, edges) processed = Set{eltype(edges)}() @@ -228,3 +230,40 @@ function logscalar(factors, messages) end scalar(factors, messages) = exp(logscalar(factors, messages)) + +# TODO: This needs to go in NamedGraphs.GraphsExtensions +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = similar_graph(NamedGraph, vertices(gi)) + add_edges!(g, edges(gi)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return rv +end + +# ======================================= printing ======================================= # + +# TODO: This is the definition for the proposed `DataGraphs.AbstractEdgeDataGraph`. 
+function Base.show(io::IO, mime::MIME"text/plain", graph::MessageCache) + println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") + show(io, mime, vertices(graph)) + println(io, "\n") + println(io, "and $(ne(graph)) edge(s):") + for e in edges(graph) + show(io, mime, e) + println(io) + end + println(io) + println(io, "with edge data:") + show(io, mime, edge_data(graph)) + return nothing +end + +Base.show(io::IO, graph::MessageCache) = show(io, MIME"text/plain"(), graph) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 471add5..ec68655 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,17 +1,17 @@ import AlgorithmsInterface as AI -using DataGraphs: edge_data +using DataGraphs: edge_data, edge_data_type using DiagonalArrays: δ -using Dictionaries: Dictionary, set! +using Dictionaries: Dictionary, dictionary, set! using Graphs: AbstractGraph, dst, edges, has_edge, src, vertices using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: ITensorNetworksNext, MessageCache, StopWhenConverged, - TensorNetwork, adapt_messages, edge_scalar, factor, factors, incoming_messages, - linkinds, map_messages, message, message_type, messages, region_scalar, scalar, - setmessage!, setmessages!, subgraph, vertex_scalar, vertex_scalars + TensorNetwork, edge_scalar, incoming_messages, linkinds, messagecache, region_scalar, + scalar, subgraph, vertex_scalar, vertex_scalars using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name -using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype +using NamedGraphs.GraphsExtensions: all_edges, arranged_edges, incident_edges, vertextype using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid, named_path_graph +using NamedGraphs: NamedEdge using Test: @test, @testset function spin_ice_tensornetwork(g) @@ -51,35 +51,38 @@ end return randn(Tuple(is)) end - # By default for graphs, assume factors refers to the vertex 
data - @test length(factors(tn)) == 9 - @test factor(tn, (1, 1)) == tn[(1, 1)] - - bpc = MessageCache(tn) do edge - return "$(src(edge)) => $(dst(edge))" - end + bpc = messagecache( + edge => "$(src(edge)) => $(dst(edge))" for edge in all_edges(g) + ) - @test message_type(bpc) <: String - @test length(messages(bpc)) == 2 * length(edges(g)) + @test valtype(bpc) <: String + @test edge_data_type(bpc) <: String + @test valtype(bpc) === edge_data_type(bpc) + @test length(edge_data(bpc)) == 2 * length(edges(g)) @test bpc[(1, 1) => (1, 2)] == "(1, 1) => (1, 2)" - @test message(bpc, (2, 1) => (1, 1)) == "(2, 1) => (1, 1)" # set message - setmessage!(bpc, (1, 1) => (1, 2), "new message") - @test message(bpc, (1, 1) => (1, 2)) == "new message" + bpc[(1, 1) => (1, 2)] = "new message" + @test bpc[(1, 1) => (1, 2)] == "new message" - setmessages!(bpc, Dict(((1, 2) => (2, 2)) => "m1", ((2, 2) => (2, 3)) => "m2")) - @test message(bpc, (1, 1) => (1, 2)) == "new message" - @test message(bpc, (1, 2) => (2, 2)) == "m1" - @test message(bpc, (2, 2) => (2, 3)) == "m2" + pairs = [((1, 2) => (2, 2), "m1"), ((2, 2) => (2, 3), "m2")] - bpc_dst = MessageCache(tn) do edge - return "" - end - setmessages!(bpc_dst, bpc, [(1, 2) => (2, 2), (2, 2) => (2, 3)]) - @test message(bpc_dst, (1, 1) => (1, 2)) == "" - @test message(bpc, (1, 2) => (2, 2)) == "m1" - @test message(bpc, (2, 2) => (2, 3)) == "m2" + new_bpc = copyto!(deepcopy(bpc), Dict(pairs)) + @test new_bpc[(1, 1) => (1, 2)] == "new message" + @test new_bpc[(1, 2) => (2, 2)] == "m1" + @test new_bpc[(2, 2) => (2, 3)] == "m2" + + new_bpc = copyto!(deepcopy(bpc), dictionary(pairs)) + @test new_bpc[(1, 1) => (1, 2)] == "new message" + @test new_bpc[(1, 2) => (2, 2)] == "m1" + @test new_bpc[(2, 2) => (2, 3)] == "m2" + + bpc_dst = messagecache(edge => "" for edge in all_edges(g)) + + copyto!(bpc_dst, bpc, [(1, 2) => (2, 2), (2, 2) => (2, 3)]) + @test bpc_dst[(1, 1) => (1, 2)] == "" + @test bpc_dst[(1, 2) => (2, 2)] == "(1, 2) => (2, 2)" + @test 
bpc_dst[(2, 2) => (2, 3)] == "(2, 2) => (2, 3)" end @testset "Vertex/region scalars" begin g = named_path_graph(3) @@ -90,9 +93,10 @@ end return randn(ComplexF32, Tuple(is)) end - bpc = MessageCache(tn) do edge - return ones(Float64, Tuple(linkinds(tn, edge))) - end + bpc = messagecache( + edge => ones(Float64, Tuple(linkinds(tn, edge))) for + edge in all_edges(g) + ) # Vertex/edge/region scalars. @test vertex_scalar(tn, bpc, 2) isa ComplexF64 @@ -101,26 +105,17 @@ end @test region_scalar(tn, bpc, [1]) == vertex_scalar(tn, bpc, 1) @test region_scalar(tn, bpc, [2, 3]) == prod(vertex_scalars(tn, bpc, [2, 3])) - # `incoming_messages` excludes specified edges. - in_msgs = incoming_messages(bpc, 2) - in_msgs_filtered = incoming_messages( - bpc, 2; ignore_edges = [1 => 2] - ) - @test length(in_msgs) == 2 - @test length(in_msgs_filtered) == 1 - @test only(in_msgs_filtered) == bpc[3 => 2] - - # `map_messages` and `map_factors` produce independent caches. - bpc_again = map_messages(identity, bpc) - @test bpc_again !== bpc - @test bpc_again == bpc + # `incoming_messages` excludes the reverse of the passed edge + in_msgs = incoming_messages(bpc, 2 => 3) + @test length(in_msgs) == 1 + @test only(in_msgs) == bpc[1 => 2] - bpc_doubled = map_messages(m -> 2 .* m, bpc) - @test bpc_doubled != bpc - @test message(bpc_doubled, 1 => 2) ≈ 2 .* message(bpc, 1 => 2) - @test message(bpc_doubled, 2 => 3) ≈ 2 .* message(bpc, 2 => 3) + in_msgs = incoming_messages(bpc, NamedEdge(1 => 2)) + @test length(in_msgs) == 0 - @test adapt_messages(identity, bpc) == bpc + in_msgs = incoming_messages(bpc, NamedEdge(2 => 1)) + @test length(in_msgs) == 1 + @test only(in_msgs) == bpc[3 => 2] end @testset "subgraph" begin @@ -131,9 +126,9 @@ end is = map(e -> l[e], incident_edges(g, v)) return randn(Tuple(is)) end - bpc = MessageCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end + bpc = messagecache( + edge => ones(Tuple(linkinds(tn, edge))) for edge in all_edges(g) + ) sub_vs = [(1,), (2,)] 
subbpc = subgraph(bpc, sub_vs) @@ -150,12 +145,11 @@ end return randn(Tuple(is)) end - bpc1 = MessageCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end - bpc2 = MessageCache(tn) do edge - return ones(Tuple(linkinds(tn, edge))) - end + bpc1 = messagecache( + edge => ones(Tuple(linkinds(tn, edge))) for edge in all_edges(g) + ) + + bpc2 = copy(bpc1) # Identical caches: diff should be ~0. @test ITensorNetworksNext.iterate_diff(bpc1, bpc2) ≈ 0.0 atol = 10 * eps() @@ -164,6 +158,9 @@ end @testset "Algorithm" begin @testset "$T" for T in (Float32, Float64, ComplexF64, BigFloat) + onet = (tn, edge) -> ones(T, Tuple(linkinds(tn, edge))) + randt = (tn, edge) -> rand(T, Tuple(linkinds(tn, edge))) + #Chain of tensors dims = (2, 1) g = named_grid(dims) @@ -175,11 +172,10 @@ end return randn(T, Tuple(is)) end - bpc = MessageCache(tn) do edge - return ones(T, Tuple(linkinds(tn, edge))) - end - bpc = ITensorNetworksNext.beliefpropagation(tn, bpc; maxiter = 1) - z_bp = scalar(tn, bpc) + messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) + + cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + z_bp = scalar(tn, cache) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -193,11 +189,10 @@ end return randn(T, Tuple(is)) end - bpc = MessageCache(tn) do edge - return ones(T, Tuple(linkinds(tn, edge))) - end - bpc = ITensorNetworksNext.beliefpropagation(tn, bpc; maxiter = 1) - z_bp = scalar(tn, bpc) + messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) + + cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + z_bp = scalar(tn, cache) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -208,22 +203,18 @@ end g = named_grid(dims; periodic = true) tn = spin_ice_tensornetwork(g) - bpc = ITensorNetworksNext.MessageCache(tn) do edge - # Use `rand` so messages have positive elements. 
- return rand(T, Tuple(linkinds(tn, edge))) - end + messages = Dict(edge => randt(tn, edge) for edge in all_edges(g)) stopping_criterion = StopWhenConverged(tol = 1.0e-10) - bpc = - ITensorNetworksNext.beliefpropagation( + cache = ITensorNetworksNext.beliefpropagation( tn, - bpc; + messages; maxiter = 10, stopping_criterion ) - z_bp = scalar(tn, bpc) + z_bp = scalar(tn, cache) @test z_bp ≈ 1.5^(n^2) end From 8c5278869307248bdcc251713af78d92dea5aaa6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 6 May 2026 13:16:23 -0400 Subject: [PATCH 82/86] Rename `ennvironment_messages` to `incoming_edge_data`. --- src/beliefpropagation/messagecache.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index dc20a66..25ab319 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -160,13 +160,14 @@ function incoming_messages(cache::AbstractGraph, edge::AbstractEdge) return getindices(cache, filter(e -> e != reverse(edge), inds)) end -function environment_messages(cache::AbstractGraph, vertices) +# TODO: maybe this should be defined in `DataGraphs`. +function incoming_edge_data(cache::AbstractGraph, vertices) inds = Indices(boundary_edges(cache, vertices; dir = :in)) return getindices(cache, inds) end function vertex_scalar(factors, messages, vertex; kwargs...) - in_messages = environment_messages(messages, [vertex]) + in_messages = incoming_edge_data(messages, [vertex]) tensors = vcat([factors[vertex]], collect(in_messages)) return contract_network(tensors; kwargs...)[] end From cdfa57c0743cd0d2757b2afea97a9610468c57cf Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 6 May 2026 13:18:53 -0400 Subject: [PATCH 83/86] Rename `logscalar` to `bethe_free_energy`; remove `scalar`. Previous names were too generic. 
--- src/beliefpropagation/messagecache.jl | 4 +--- test/test_beliefpropagation.jl | 10 +++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index 25ab319..2927758 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -212,7 +212,7 @@ function region_scalar(factors, messages, region) end # We need a graph structure here, so assume `factors` is a graph. -function logscalar(factors, messages) +function bethe_free_energy(factors, messages) numerator_terms = vertex_scalars(factors, messages) denominator_terms = edge_scalars(messages) @@ -230,8 +230,6 @@ function logscalar(factors, messages) return sum(log.(numerator_terms)) - sum(log.(denominator_terms)) end -scalar(factors, messages) = exp(logscalar(factors, messages)) - # TODO: This needs to go in NamedGraphs.GraphsExtensions function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) # All we care about are the edges so the type of the graph doesnt matter diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index ec68655..02151b1 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -5,8 +5,8 @@ using Dictionaries: Dictionary, dictionary, set! 
using Graphs: AbstractGraph, dst, edges, has_edge, src, vertices using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: ITensorNetworksNext, MessageCache, StopWhenConverged, - TensorNetwork, edge_scalar, incoming_messages, linkinds, messagecache, region_scalar, - scalar, subgraph, vertex_scalar, vertex_scalars + TensorNetwork, bethe_free_energy, edge_scalar, incoming_messages, linkinds, + messagecache, region_scalar, subgraph, vertex_scalar, vertex_scalars using LinearAlgebra: LinearAlgebra using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: all_edges, arranged_edges, incident_edges, vertextype @@ -175,7 +175,7 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) - z_bp = scalar(tn, cache) + z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -192,7 +192,7 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) - z_bp = scalar(tn, cache) + z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -214,7 +214,7 @@ end stopping_criterion ) - z_bp = scalar(tn, cache) + z_bp = exp(bethe_free_energy(tn, cache)) @test z_bp ≈ 1.5^(n^2) end From f40500adad022a9080f4c0b03a17fec132bcc4c6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 6 May 2026 14:25:56 -0400 Subject: [PATCH 84/86] Add `messagecache(f, edge)` method. 
--- src/beliefpropagation/messagecache.jl | 1 + test/test_beliefpropagation.jl | 21 +++++++-------------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl index 2927758..cb83610 100644 --- a/src/beliefpropagation/messagecache.jl +++ b/src/beliefpropagation/messagecache.jl @@ -53,6 +53,7 @@ function MessageCache{T, V}(messages) where {T, V} end messagecache(pairs) = MessageCache(Dict(pairs)) +messagecache(f, edges) = messagecache(edge => f(edge) for edge in edges) # ================================ NamedGraphs interface ================================= # function NamedGraphs.add_edge!(c::MessageCache, edge) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 02151b1..01ca6e7 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -51,9 +51,7 @@ end return randn(Tuple(is)) end - bpc = messagecache( - edge => "$(src(edge)) => $(dst(edge))" for edge in all_edges(g) - ) + bpc = messagecache(edge -> "$(src(edge)) => $(dst(edge))", all_edges(g)) @test valtype(bpc) <: String @test edge_data_type(bpc) <: String @@ -77,7 +75,7 @@ end @test new_bpc[(1, 2) => (2, 2)] == "m1" @test new_bpc[(2, 2) => (2, 3)] == "m2" - bpc_dst = messagecache(edge => "" for edge in all_edges(g)) + bpc_dst = messagecache(edge -> "", all_edges(g)) copyto!(bpc_dst, bpc, [(1, 2) => (2, 2), (2, 2) => (2, 3)]) @test bpc_dst[(1, 1) => (1, 2)] == "" @@ -93,10 +91,9 @@ end return randn(ComplexF32, Tuple(is)) end - bpc = messagecache( - edge => ones(Float64, Tuple(linkinds(tn, edge))) for - edge in all_edges(g) - ) + bpc = messagecache(all_edges(g)) do edge + return ones(Float64, Tuple(linkinds(tn, edge))) + end # Vertex/edge/region scalars. 
@test vertex_scalar(tn, bpc, 2) isa ComplexF64 @@ -126,9 +123,7 @@ end is = map(e -> l[e], incident_edges(g, v)) return randn(Tuple(is)) end - bpc = messagecache( - edge => ones(Tuple(linkinds(tn, edge))) for edge in all_edges(g) - ) + bpc = messagecache(edge -> ones(Tuple(linkinds(tn, edge))), all_edges(g)) sub_vs = [(1,), (2,)] subbpc = subgraph(bpc, sub_vs) @@ -145,9 +140,7 @@ end return randn(Tuple(is)) end - bpc1 = messagecache( - edge => ones(Tuple(linkinds(tn, edge))) for edge in all_edges(g) - ) + bpc1 = messagecache(edge -> ones(Tuple(linkinds(tn, edge))), all_edges(g)) bpc2 = copy(bpc1) From a6809bb0db2d86f970e47a68111cca4b7ac785fd Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 6 May 2026 14:26:24 -0400 Subject: [PATCH 85/86] Inline belief propagation algorithm construction. --- .../beliefpropagationproblem.jl | 74 +++++++------------ 1 file changed, 25 insertions(+), 49 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index dbe91e2..50dbc5b 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -175,48 +175,10 @@ function AI.solve!( return cache end -function beliefpropagation(factors, messages; kwargs...) - problem = BeliefPropagationProblem(factors) - - cache = initialize_cache(beliefpropagation, factors, messages) - - algorithm = select_algorithm(beliefpropagation, factors, cache; kwargs...) - - state = AI.solve(problem, algorithm; iterate = cache) - - return state.iterate # -> typeof(cache) -end - -# Use a `MessageCache` by default. Note if `messages` is already a `MessageCache` this -# will make a copy of the the existing cache, thus protecting the original from mutation. 
-initialize_cache(::typeof(beliefpropagation), _factors, messages) = MessageCache(messages) - -function default_stopping_criterion( - ::typeof(beliefpropagation); - maxiter, - tol, - stopping_criterion - ) - base_stopping_criterion = AI.StopAfterIteration(maxiter) - - if !isnothing(stopping_criterion) - base_stopping_criterion |= stopping_criterion - end - - if !isnothing(tol) - base_stopping_criterion |= StopWhenConverged(tol) - end - - return base_stopping_criterion -end - -function select_algorithm( - ::typeof(beliefpropagation), - network::AbstractGraph, - cache::MessageCache; - edges = forest_cover_edge_sequence(cache), - maxiter = is_tree(network) ? 1 : nothing, - tol = nothing, +function beliefpropagation( + factors, messages; + edges = nothing, + maxiter = is_tree(factors) ? 1 : nothing, stopping_criterion = nothing, kwargs... ) @@ -229,19 +191,33 @@ function select_algorithm( ) end - stopping_criterion = default_stopping_criterion( - beliefpropagation; - maxiter, - tol, - stopping_criterion - ) + cache = MessageCache(messages) + problem = BeliefPropagationProblem(factors) + + ## Algorithm construction: + + edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges + + base_stopping_criterion = AI.StopAfterIteration(maxiter) + + if !isnothing(stopping_criterion) + base_stopping_criterion |= stopping_criterion + end + + stopping_criterion = base_stopping_criterion extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - return BeliefPropagation(maxiter; stopping_criterion) do repnum + algorithm = BeliefPropagation(maxiter; stopping_criterion) do repnum return BeliefPropagationSweep(edges) do edge return SimpleMessageUpdate(edge; edge_kwargs[repnum]...) 
end end + + ## + + state = AI.solve(problem, algorithm; iterate = cache) + + return state.iterate # -> typeof(cache) end From 00a23f6be4ec6f17befe1bb8845c032927372a22 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 6 May 2026 16:48:44 -0400 Subject: [PATCH 86/86] Upgrade to registered `AlgorithmsInterface` version; remove `solve!` overloads. --- .../AlgorithmsInterfaceExtensions.jl | 83 ++++--------------- .../beliefpropagationproblem.jl | 7 +- test/test_algorithmsinterfaceextensions.jl | 58 ++----------- test/test_sweeping.jl | 14 ++-- 4 files changed, 31 insertions(+), 131 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index f042dc0..d9edb0d 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -46,61 +46,6 @@ end function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) end -# ============================ solve! ====================================================== - -# Custom version of `solve!` that allows specifying the logger and also overloads -# `increment!` on the problem and algorithm. -function basetypenameof(x) - return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), "."))) -end -default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) -function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) - return Symbol( - default_logging_context_prefix(problem), - default_logging_context_prefix(algorithm) - ) -end -function AI.solve!( - problem::Problem, algorithm::Algorithm, state::State; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs... 
- ) - logger = AI.algorithm_logger() - - context_suffixes = [:Start, :PreStep, :PostStep, :Stop] - contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes)) - - # initialize the state and emit message - AI.initialize_state!(problem, algorithm, state; kwargs...) - AI.emit_message(logger, problem, algorithm, state, contexts[:Start]) - - # main body of the algorithm - while !AI.is_finished!(problem, algorithm, state) - AI.increment!(problem, algorithm, state) - - # logging event between convergence check and algorithm step - AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep]) - - # algorithm step - AI.step!(problem, algorithm, state; logging_context_prefix) - - # logging event between algorithm step and convergence check - AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep]) - end - - # emit message about finished state - AI.emit_message(logger, problem, algorithm, state, contexts[:Stop]) - return state -end - -function AI.solve( - problem::Problem, algorithm::Algorithm; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs... - ) - state = AI.initialize_state(problem, algorithm; kwargs...) - return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) -end # ============================ AlgorithmIterator =========================================== @@ -174,18 +119,12 @@ function set_substate!( return state end -function AI.step!( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State; - logging_context_prefix = Symbol() - ) +function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) # Get the subproblem, subalgorithm, and substate. subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) # Solve the subproblem with the subalgorithm. 
- logging_context_prefix = Symbol( - logging_context_prefix, default_logging_context_prefix(subalgorithm) - ) - AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix) + AI.solve!(subproblem, subalgorithm, substate) # Update the state with the substate. set_substate!(problem, algorithm, state, substate) @@ -247,15 +186,14 @@ function AI.increment!( return state end function AI.step!( - problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState; - logging_context_prefix = Symbol() + problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState ) algorithm_sweep = algorithm.algorithms[state.parent_iteration] state_sweep = AI.initialize_state( problem, algorithm_sweep; state.iterate, iteration = state.child_iteration ) - AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix) + AI.step!(problem, algorithm_sweep, state_sweep) state.iterate = state_sweep.iterate return state end @@ -292,10 +230,17 @@ abstract type NonIterativeAlgorithmState <: State end function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) return DefaultNonIterativeAlgorithmState(; kwargs...) end -function AI.solve!( - problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs... 
+ +function AI.initialize_state!( + problem::Problem, + algorithm::NonIterativeAlgorithm, + state::NonIterativeAlgorithmState ) - return throw(MethodError(AI.solve!, (problem, algorithm, state))) + return state +end + +function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State) + return throw(MethodError(AI.solve_loop!, (problem, algorithm, state))) end @kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 50dbc5b..004e449 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -153,8 +153,7 @@ end function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, - cache::MessageCache; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) + cache::MessageCache ) edge = algorithm.edge @@ -217,7 +216,5 @@ function beliefpropagation( ## - state = AI.solve(problem, algorithm; iterate = cache) - - return state.iterate # -> typeof(cache) + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 44e6a09..6f80527 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -16,16 +16,14 @@ end end function AI.step!( - problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState; - logging_context_prefix = Symbol() + problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState ) state.iterate .+= 1 # Simple increment step return state end function AI.step!( - problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; - kwargs... 
+ problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState ) state.iterate .+= 2 # Different increment step return state @@ -101,22 +99,22 @@ end # Solve with custom initial iterate initial_iterate = [5.0, 10.0] - final_state = AI.solve!( + final_iterate = AI.solve!( problem, algorithm, state; iterate = copy(initial_iterate) ) - @test final_state.iteration == 3 + @test state.iteration == 3 + @test final_iterate == state.iterate # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] - @test final_state.iterate ≈ [8.0, 13.0] + @test state.iterate ≈ [8.0, 13.0] # Test solve without exclamation problem2 = TestProblem([1.0, 2.0]) algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) initial_iterate2 = [5.0, 10.0] - final_state2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) - @test final_state2.iteration == 2 - @test final_state2.iterate ≈ [7.0, 12.0] + final_iterate2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + @test final_iterate2 ≈ [7.0, 12.0] end @testset "DefaultAlgorithmIterator" begin @@ -146,31 +144,6 @@ end @test AI.is_finished!(iterator) end - @testset "with_algorithmlogger" begin - # Test with_algorithmlogger with functions - results = [] - function callback1(problem, algorithm, state) - push!(results, :callback1) - return nothing - end - function callback2(problem, algorithm, state) - push!(results, :callback2) - return nothing - end - - problem = TestProblem([1.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - - # Test with CallbackAction (wrapped functions) - state = AIE.with_algorithmlogger( - :TestProblem_TestAlgorithm_PreStep => callback1, - :TestProblem_TestAlgorithm_PostStep => callback2 - ) do - return AI.solve(problem, algorithm; iterate = [0.0]) - end - @test results == [:callback1, :callback2] - end - @testset "DefaultNestedAlgorithm" begin # Test creating nested algorithm with function nested_alg = 
AIE.nested_algorithm(3) do i @@ -270,21 +243,6 @@ end @test state.iterate ≈ [100.0, 200.0] end - @testset "basetypenameof and default_logging_context_prefix" begin - # Test basetypenameof utility - problem = TestProblem([1.0]) - algorithm = TestAlgorithm() - - prefix_problem = AIE.default_logging_context_prefix(problem) - prefix_algorithm = AIE.default_logging_context_prefix(algorithm) - prefix_combined = AIE.default_logging_context_prefix(problem, algorithm) - - @test prefix_problem isa Symbol - @test prefix_algorithm isa Symbol - @test prefix_combined isa Symbol - @test contains(String(prefix_combined), String(prefix_problem)) - end - @testset "DefaultFlattenedAlgorithm" begin # Create nested algorithms that support max_iterations nested_algs = map(1:3) do i diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl index 215a8b8..01881d9 100644 --- a/test/test_sweeping.jl +++ b/test/test_sweeping.jl @@ -11,7 +11,7 @@ struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm end TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) -function AI.solve!(problem::TestProblem, algorithm::TestRegion, state::AIE.State; kwargs...) 
+function AI.solve_loop!(problem::TestProblem, algorithm::TestRegion, state::AIE.State) new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) state.iterate = [state.iterate; [new_iterate]] return state @@ -28,8 +28,8 @@ end problem = TestProblem() iterate = [] - state = AI.solve(problem, algorithm; iterate) - @test state.iterate == [(; region = "region", foo = 1, bar = 2)] + iterate = AI.solve(problem, algorithm; iterate) + @test iterate == [(; region = "region", foo = 1, bar = 2)] end @testset "Sweep" begin algorithm = AIE.nested_algorithm(3) do i @@ -37,8 +37,8 @@ end end problem = TestProblem() iterate = [] - state = AI.solve(problem, algorithm; iterate) - @test state.iterate == [ + iterate = AI.solve(problem, algorithm; iterate) + @test iterate == [ (; region = "region1", foo = 1, bar = 2), (; region = "region2", foo = 2, bar = 4), (; region = "region3", foo = 3, bar = 6), @@ -52,8 +52,8 @@ end end problem = TestProblem() iterate = [] - state = AI.solve(problem, algorithm; iterate) - @test state.iterate == [ + iterate = AI.solve(problem, algorithm; iterate) + @test iterate == [ (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)),