diff --git a/Project.toml b/Project.toml index 609216e..5455ec2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" -version = "0.3.24" +version = "0.4.0" authors = ["ITensor developers and contributors"] [workspace] @@ -39,15 +39,15 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.2.7" +DataGraphs = "0.4" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" -FunctionImplementations = "0.4" +FunctionImplementations = "0.4.1" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.14.2" -NamedGraphs = "0.6.9, 0.7, 0.8" +NamedDimsArrays = "0.14.3" +NamedGraphs = "0.11" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TensorOperations = "5.3.1" diff --git a/docs/Project.toml b/docs/Project.toml index fbde468..b4f33d5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,5 +10,5 @@ path = ".." [compat] Documenter = "1" ITensorFormatter = "0.2.27" -ITensorNetworksNext = "0.3" +ITensorNetworksNext = "0.4" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index 780f959..4108720 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" path = ".." [compat] -ITensorNetworksNext = "0.3" +ITensorNetworksNext = "0.4" diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 5d9561a..d9edb0d 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -22,12 +22,12 @@ function AI.initialize_state!( end function AI.initialize_state( - problem::Problem, algorithm::Algorithm; kwargs... + problem::Problem, algorithm::Algorithm; iterate, kwargs... 
) stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion + problem, algorithm, algorithm.stopping_criterion; iterate ) - return DefaultState(; stopping_criterion_state, kwargs...) + return DefaultState(; iterate, stopping_criterion_state, kwargs...) end # ============================ DefaultState ================================================ @@ -46,61 +46,6 @@ end function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) end -# ============================ solve! ====================================================== - -# Custom version of `solve!` that allows specifying the logger and also overloads -# `increment!` on the problem and algorithm. -function basetypenameof(x) - return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), "."))) -end -default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) -function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) - return Symbol( - default_logging_context_prefix(problem), - default_logging_context_prefix(algorithm) - ) -end -function AI.solve!( - problem::Problem, algorithm::Algorithm, state::State; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs... - ) - logger = AI.algorithm_logger() - - context_suffixes = [:Start, :PreStep, :PostStep, :Stop] - contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes)) - - # initialize the state and emit message - AI.initialize_state!(problem, algorithm, state; kwargs...) 
- AI.emit_message(logger, problem, algorithm, state, contexts[:Start]) - - # main body of the algorithm - while !AI.is_finished!(problem, algorithm, state) - AI.increment!(problem, algorithm, state) - - # logging event between convergence check and algorithm step - AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep]) - - # algorithm step - AI.step!(problem, algorithm, state; logging_context_prefix) - - # logging event between algorithm step and convergence check - AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep]) - end - - # emit message about finished state - AI.emit_message(logger, problem, algorithm, state, contexts[:Stop]) - return state -end - -function AI.solve( - problem::Problem, algorithm::Algorithm; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs... - ) - state = AI.initialize_state(problem, algorithm; kwargs...) - return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) -end # ============================ AlgorithmIterator =========================================== @@ -151,8 +96,9 @@ end abstract type NestedAlgorithm <: Algorithm end -function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...) +function nested_algorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(f, iterable; kwargs...) end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) @@ -173,18 +119,12 @@ function set_substate!( return state end -function AI.step!( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State; - logging_context_prefix = Symbol() - ) +function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) # Get the subproblem, subalgorithm, and substate. 
subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) # Solve the subproblem with the subalgorithm. - logging_context_prefix = Symbol( - logging_context_prefix, default_logging_context_prefix(subalgorithm) - ) - AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix) + AI.solve!(subproblem, subalgorithm, substate) # Update the state with the substate. set_substate!(problem, algorithm, state, substate) @@ -206,8 +146,8 @@ from a list of stored algorithms. algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) end # ============================ FlattenedAlgorithm ========================================== @@ -246,15 +186,14 @@ function AI.increment!( return state end function AI.step!( - problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState; - logging_context_prefix = Symbol() + problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState ) algorithm_sweep = algorithm.algorithms[state.parent_iteration] state_sweep = AI.initialize_state( problem, algorithm_sweep; state.iterate, iteration = state.child_iteration ) - AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix) + AI.step!(problem, algorithm_sweep, state_sweep) state.iterate = state_sweep.iterate return state end @@ -291,10 +230,17 @@ abstract type NonIterativeAlgorithmState <: State end function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) return DefaultNonIterativeAlgorithmState(; kwargs...) end -function AI.solve!( - problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs... 
+ +function AI.initialize_state!( + problem::Problem, + algorithm::NonIterativeAlgorithm, + state::NonIterativeAlgorithmState ) - return throw(MethodError(AI.solve!, (problem, algorithm, state))) + return state +end + +function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State) + return throw(MethodError(AI.solve_loop!, (problem, algorithm, state))) end @kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index ace4030..dd7dc50 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,7 @@ include("contract_network.jl") include("sweeping/utils.jl") include("sweeping/eigenproblem.jl") +include("beliefpropagation/messagecache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") + end diff --git a/src/LazyNamedDimsArrays/lazynameddimsarray.jl b/src/LazyNamedDimsArrays/lazynameddimsarray.jl index 4cbd3f9..62774a7 100644 --- a/src/LazyNamedDimsArrays/lazynameddimsarray.jl +++ b/src/LazyNamedDimsArrays/lazynameddimsarray.jl @@ -7,7 +7,7 @@ using WrappedUnions: @wrapped union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end -parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A +parenttype(::Type{LazyNamedDimsArray{T, A}}) where {T, A} = A parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T} parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index fb661f5..121073d 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,27 +1,21 @@ -using Adapt: Adapt, adapt, adapt_structure +using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, set_vertex_data!, + underlying_graph, underlying_graph_type, 
vertex_data using Dictionaries: Dictionary -using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, bfs_tree, - center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices -using LinearAlgebra: LinearAlgebra, factorize +using Graphs: Graphs, AbstractEdge, AbstractGraph, add_edge!, add_vertex!, dst, edges, + edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs.GraphsExtensions: - directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype, ⊔ -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using SplitApplyCombine: flatten +using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, similar_graph abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end - -# TODO: Define a generic fallback for `AbstractDataGraph`? -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,7 +30,7 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) 
@@ -49,20 +43,7 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - -# AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end - -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,49 +62,46 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function 
linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) - s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) +function siteaxes(tn::AbstractGraph, v) + s = axes(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) - s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) +function sitenames(tn::AbstractGraph, v) + s = dimnames(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) end return s end -function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set_vertex_data!(tn, value, vertex) return tn end @@ -153,7 +131,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +139,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,14 +153,19 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. 
-function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. -function fix_edges!(tn::AbstractTensorNetwork, v) - rem_edges!(tn, incident_edges(tn, v)) +function fix_edges!(tn::AbstractGraph, v) + for e in incident_edges(tn, v) + # Remove an edge if there is no index on that edge. + if isempty(linkinds(tn, e)) + rem_edge!(tn, e) + end + end add_missing_edges!(tn, v) return tn end @@ -215,28 +198,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. 
function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger} ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..004e449 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,220 @@ +import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using BackendSelection: @Algorithm_str, Algorithm +using DataGraphs: edge_data +using Graphs: AbstractEdge, edges, has_edge, vertices +using LinearAlgebra: norm, normalize +using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph +using NamedGraphs.PartitionedGraphs: quotientvertices + +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 +end + +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf + at_iteration::Int = -1 + previous_iterate::Iterate +end + +function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) +end + +function AI.initialize_state!( + ::AIE.Problem, + ::AIE.Algorithm, + ::StopWhenConverged, + st::StopWhenConvergedState + ) + st.delta = Inf + return st +end + +function AI.is_finished!( + problem::AIE.Problem, + algorithm::AIE.Algorithm, + state::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + iterate = state.iterate + previous_iterate = st.previous_iterate + + delta = iterate_diff(iterate, previous_iterate) + + st.previous_iterate = copy(iterate) + + # maxdiff = 0.0 initially, so skip this the first time. 
+ state.iteration == 0 && return false + + st.delta = delta + + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true + end + + return false +end + +function AI.is_finished( + ::AIE.Problem, + ::AIE.Algorithm, + ::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol +end + +struct BeliefPropagationProblem{Factors} <: AIE.Problem + factors::Factors +end + +function iterate_diff( + cache1::MessageCache, + cache2::MessageCache + ) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + end +end + +@kwdef struct BeliefPropagation{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end + +function BeliefPropagation(f::Function, niterations::Int; kwargs...) + return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) +end + +struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} + edge::E + kwargs::Kwargs +end + +function SimpleMessageUpdate( + edge; + normalize = true, + contraction_alg = Algorithm"exact", + kwargs... + ) + return SimpleMessageUpdate( + edge, + (; normalize, contraction_alg, kwargs...) 
+ ) +end + +function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) + if name in (:edge, :kwargs) + return getfield(alg, name) + else + return getproperty(getfield(alg, :kwargs), name) + end +end + +AI.initialize_state(::BeliefPropagationProblem, ::SimpleMessageUpdate; iterate) = iterate + +struct BeliefPropagationSweep{ + ChildAlgorithm, Algorithms <: AbstractVector{ChildAlgorithm}, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::AI.StopAfterIteration + function BeliefPropagationSweep(; algorithms) + stopping_criterion = AI.StopAfterIteration(length(algorithms)) + return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) + end +end + +function BeliefPropagationSweep(f::Function, edges) + return BeliefPropagationSweep(; algorithms = f.(edges)) +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::BeliefPropagationSweep, + state::AIE.DefaultState, + cache::MessageCache + ) + state.iterate = cache + + return state +end + +function AI.solve!( + problem::BeliefPropagationProblem, + algorithm::SimpleMessageUpdate, + cache::MessageCache + ) + edge = algorithm.edge + + messages = collect(incoming_messages(cache, edge)) + factor = problem.factors[src(edge)] + + new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + + if algorithm.normalize + message_norm = sum(new_message) + if !iszero(message_norm) + new_message /= message_norm + end + end + + cache[edge] = new_message + + return cache +end + +function beliefpropagation( + factors, messages; + edges = nothing, + maxiter = is_tree(factors) ? 1 : nothing, + stopping_criterion = nothing, + kwargs... + ) + if isnothing(maxiter) + throw( + ArgumentError( + "`maxiter` must be specified for non-tree graphs, even when + `stopping_criterion` is provided." + ) + ) + end + + cache = MessageCache(messages) + problem = BeliefPropagationProblem(factors) + + ## Algorithm construction: + + edges = isnothing(edges) ? 
forest_cover_edge_sequence(cache) : edges + + base_stopping_criterion = AI.StopAfterIteration(maxiter) + + if !isnothing(stopping_criterion) + base_stopping_criterion |= stopping_criterion + end + + stopping_criterion = base_stopping_criterion + + extended_kwargs = extend_columns((; kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) + + algorithm = BeliefPropagation(maxiter; stopping_criterion) do repnum + return BeliefPropagationSweep(edges) do edge + return SimpleMessageUpdate(edge; edge_kwargs[repnum]...) + end + end + + ## + + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) +end diff --git a/src/beliefpropagation/messagecache.jl b/src/beliefpropagation/messagecache.jl new file mode 100644 index 0000000..cb83610 --- /dev/null +++ b/src/beliefpropagation/messagecache.jl @@ -0,0 +1,269 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type +using Dictionaries: Dictionary, delete!, getindices, set! 
+using Graphs: AbstractGraph, connected_components, is_directed, is_tree +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype +using NamedGraphs.GraphsExtensions: IsDirected, boundary_edges, default_root_vertex, + directed_graph, forest_cover, in_incident_edges, post_order_dfs_edges, undirected_graph, + vertextype +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph +using NamedGraphs: NamedDiGraph, Vertices, convert_vertextype, ordered_vertices, + parent_graph_indices, position_graph, to_graph_index, vertex_positions + +struct MessageCache{T, V} <: AbstractDataGraph{V, Nothing, T} + messages::Dictionary{NamedEdge{V}, T} + underlying_graph::NamedDiGraph{V} + function MessageCache{T, V}(::UndefInitializer, vertices) where {T, V} + messages = Dictionary{NamedEdge{V}, T}() + underlying_graph = NamedDiGraph{V}(vertices) + return new{T, V}(messages, underlying_graph) + end +end + +# single type parameter version of the inner constructor +function MessageCache{T}(::UndefInitializer, vertices) where {T} + return MessageCache{T, eltype(vertices)}(undef, vertices) +end + +# compatibility with generic key-val iterables +Base.keytype(c::MessageCache) = keytype(typeof(c)) +Base.keytype(::Type{<:MessageCache{T, V}}) where {T, V} = NamedEdge{V} + +Base.valtype(c::MessageCache) = valtype(typeof(c)) +Base.valtype(::Type{<:MessageCache{T}}) where {T} = T + +Base.keys(cache::MessageCache) = edges(cache) + +MessageCache(messages) = MessageCache{valtype(messages)}(messages) + +function MessageCache{T}(messages) where {T} + V = vertextype(keytype(messages)) + return MessageCache{T, V}(messages) +end + +# `messages` is any iterable data structure, where `keys(messages)` are edges +# and the values are the messages on those edges. 
+function MessageCache{T, V}(messages) where {T, V} + edges = keys(messages) + vertices = union(src.(edges), dst.(edges)) + cache = MessageCache{T, V}(undef, vertices) + add_edges!(cache.underlying_graph, edges) + copyto!(cache, messages) + return cache +end + +messagecache(pairs) = MessageCache(Dict(pairs)) +messagecache(f, edges) = messagecache(edge => f(edge) for edge in edges) + +# ================================ NamedGraphs interface ================================= # +function NamedGraphs.add_edge!(c::MessageCache, edge) + add_edge!(c.underlying_graph, edge) + return c +end + +function NamedGraphs.rem_edge!(c::MessageCache, edge) + delete!(c.messages, to_graph_index(c, edge)) + rem_edge!(c.underlying_graph, edge) + return c +end + +# ================================= DataGraphs interface ================================= # + +DataGraphs.underlying_graph(cache::MessageCache) = cache.underlying_graph + +DataGraphs.is_vertex_assigned(::MessageCache, _) = false +DataGraphs.is_edge_assigned(c::MessageCache, edge) = haskey(c.messages, edge) + +function DataGraphs.get_edge_data(c::MessageCache, edge::AbstractEdge) + return c.messages[edge] +end +function DataGraphs.set_edge_data!(c::MessageCache, val, edge) + return set!(c.messages, edge, val) +end + +Base.copy(cache::MessageCache) = MessageCache(copy(cache.messages)) + +function Base.:(==)(cache1::MessageCache, cache2::MessageCache) + ug1 = cache1.underlying_graph + ug2 = cache2.underlying_graph + + ms1 = cache1.messages + ms2 = cache2.messages + + return (ug1 == ug2 && ms1 == ms2) +end + +function NamedGraphs.induced_subgraph_from_vertices(cache::MessageCache, subvertices) + # TODO: once we have `subgraph_edges` in `NamedGraphs`, simplify this. 
+ underlying_subgraph, vlist = + Graphs.induced_subgraph(cache.underlying_graph, subvertices) + + assigned = v -> isassigned(cache, v) + + assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) + + messages = getindices(cache.messages, Indices(assigned_subedges)) + + return MessageCache(messages), vlist +end + +# see: copyto!(dest, src) for analogous behaviour to 2 argument method +# see: copyto!(dest, Rdest::CartesianIndices, src, Rsrc::CartesianIndices) +# for analogous behaviour to 3 argument method. +# TODO: these can be made generic for `AbtractDataGraph` in `DataGraphs.jl` +function copyto!_messagecache( + cache_dst::MessageCache, + cache_src, + inds = nothing + ) + inds = isnothing(inds) ? Indices(keys(cache_src)) : Indices(inds) + view(edge_data(cache_dst), inds) .= view(cache_src, inds) + return cache_dst +end + +function Base.copyto!( + cache_dst::MessageCache, + cache_src::AbstractDataGraph, + inds = nothing + ) + copyto!_messagecache(cache_dst, edge_data(cache_src), inds) + return cache_dst +end + +function Base.copyto!( + cache_dst::MessageCache, + dictionary_src::Dictionary, + inds = nothing + ) + copyto!_messagecache(cache_dst, dictionary_src, inds) + return cache_dst +end + +function Base.copyto!( + cache_dst::MessageCache, + dict_src::Dict, + inds = keys(dict_src) + ) + for key in inds + cache_dst[key] = dict_src[key] + end + return cache_dst +end + +# ===================================== contraction ====================================== # + +function incoming_messages(cache::AbstractGraph, pair::Pair) + edge = to_graph_index(cache, pair) + return incoming_messages(cache, edge) +end +function incoming_messages(cache::AbstractGraph, edge::AbstractEdge) + inds = Indices(in_incident_edges(cache, src(edge))) + return getindices(cache, filter(e -> e != reverse(edge), inds)) +end + +# TODO: maybe this should be defined in `DataGraphs`. 
+function incoming_edge_data(cache::AbstractGraph, vertices) + inds = Indices(boundary_edges(cache, vertices; dir = :in)) + return getindices(cache, inds) +end + +function vertex_scalar(factors, messages, vertex; kwargs...) + in_messages = incoming_edge_data(messages, [vertex]) + tensors = vcat([factors[vertex]], collect(in_messages)) + return contract_network(tensors; kwargs...)[] +end + +vertex_scalars(factors, messages) = vertex_scalars(factors, messages, keys(factors)) +function vertex_scalars(factors::AbstractGraph, messages) + return vertex_scalars(factors, messages, vertices(factors)) +end +function vertex_scalars(factors, messages, vertices) + return map(v -> vertex_scalar(factors, messages, v), vertices) +end + +function edge_scalar(cache, edge; kwargs...) + m1 = cache[edge] + m2 = cache[reverse(edge)] + return contract_network([m1, m2]; kwargs...)[] +end + +edge_scalars(cache) = edge_scalars(cache, keys(cache)) + +function edge_scalars(cache, edges) + processed = Set{eltype(edges)}() + + T = Base.promote_op(edge_scalar, typeof(cache), eltype(edges)) + + scalars = T[] + + # Ignore repeated edges and their reverses. + for e in edges + if e in processed || reverse(e) in processed + continue + end + push!(processed, e) + push!(scalars, edge_scalar(cache, e)) + end + + return scalars +end + +function region_scalar(factors, messages, region) + return mapreduce(vertex -> vertex_scalar(factors, messages, vertex), *, region) +end + +# We need a graph structure here, so assume `factors` is a graph. 
+function bethe_free_energy(factors, messages) + numerator_terms = vertex_scalars(factors, messages) + denominator_terms = edge_scalars(messages) + + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + if any(iszero, denominator_terms) + return -Inf + end + + return sum(log.(numerator_terms)) - sum(log.(denominator_terms)) +end + +# TODO: This needs to go in NamedGraphs.GraphsExtensions +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = similar_graph(NamedGraph, vertices(gi)) + add_edges!(g, edges(gi)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return rv +end + +# ======================================= printing ======================================= # + +# TODO: This is the definition for the proposed `DataGraphs.AbstractEdgeDataGraph`. 
+function Base.show(io::IO, mime::MIME"text/plain", graph::MessageCache) + println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") + show(io, mime, vertices(graph)) + println(io, "\n") + println(io, "and $(ne(graph)) edge(s):") + for e in edges(graph) + show(io, mime, e) + println(io) + end + println(io) + println(io, "with edge data:") + show(io, mime, edge_data(graph)) + return nothing +end + +Base.show(io::IO, graph::MessageCache) = show(io, MIME"text/plain"(), graph) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 9c03adf..5357b5f 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,19 @@ +using .LazyNamedDimsArrays: Mul, lazy using Combinatorics: combinations -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph -using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using DataGraphs.DataGraphsPartitionedGraphsExt +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, edge_data, get_vertices_data, + vertex_data, vertex_data_type +using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! +using Graphs: AbstractSimpleGraph, rem_edge!, rem_vertex! 
using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype +using NamedGraphs.GraphsExtensions: + GraphsExtensions, arrange_edge, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, + QuotientVertex, QuotientVertexVertices, QuotientVertices, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, + quotientvertices +using NamedGraphs: + NamedGraphs, NamedEdge, NamedGraph, Vertices, parent_graph_indices, vertextype function _TensorNetwork end @@ -20,12 +29,38 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. -function _TensorNetwork(graph::AbstractGraph, tensors) - return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +function TensorNetwork(graph::AbstractGraph, tensors) + return TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +end +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) + tn = _TensorNetwork(graph, tensors) + fix_links!(tn) + return tn +end + +function TensorNetwork{V, VD, UG, Tensors}( + graph::UG + ) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) end -DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) -DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +function Graphs.rem_vertex!(tn::TensorNetwork, v) + delete!(tn.tensors, v) + rem_vertex!(tn.underlying_graph, v) + return tn +end + +# DataGraphs interface + +DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph + +DataGraphs.is_vertex_assigned(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.is_edge_assigned(tn::TensorNetwork, e) = false + 
+DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] + +DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) + function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -49,13 +84,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) -end -function TensorNetwork(graph::AbstractGraph, tensors) - tn = _TensorNetwork(graph, tensors) - fix_links!(tn) - return tn + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end # Insert trivial links for missing edges, and also check @@ -87,9 +116,115 @@ TensorNetwork(tn::TensorNetwork) = copy(tn) TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn) function TensorNetwork{V}(tn::TensorNetwork) where {V} g = convert_vertextype(V, underlying_graph(tn)) - d = dictionary(V(k) => tn[k] for k in keys(d)) + d = dictionary(V(k) => tn[k] for k in vertices(tn)) return TensorNetwork(g, d) end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw( + ArgumentError( + "cannot remove edge $e due to tensor indices existing on this edge." 
+            )
+        )
+    end
+    rem_edge!(underlying_graph(tn), e)
+    return true
+end
+
+function NamedGraphs.similar_graph(
+        type::Type{<:TensorNetwork},
+        vertices = vertextype(type)[]
+    )
+    DT = fieldtype(type, :tensors)
+    empty_dict = DT()
+
+    underlying_graph = similar_graph(underlying_graph_type(type), vertices)
+
+    return _TensorNetwork(underlying_graph, empty_dict)
+end
+function NamedGraphs.similar_graph(
+        graph::TensorNetwork,
+        VD::Type,
+        ::Type{<:Nothing},
+        vertices
+    )
+    V = eltype(vertices)
+    empty_dict = Dictionary{V, VD}()
+
+    new_underlying_graph = similar_graph(underlying_graph(graph), vertices)
+
+    return _TensorNetwork(new_underlying_graph, empty_dict)
+end
+
+function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices)
+    return induced_subgraph_tensornetwork(graph, subvertices)
+end
+
+function induced_subgraph_tensornetwork(graph, subvertices)
+    underlying_subgraph, vlist =
+        Graphs.induced_subgraph(underlying_graph(graph), subvertices)
+
+    subgraph = TensorNetwork(underlying_subgraph) do vertex
+        return graph[vertex]
+    end
+
+    return subgraph, vlist
+end
+
+## PartitionedGraphs
+function PartitionedGraphs.partitioned_vertices(tn::TensorNetwork)
+    return partitioned_vertices(tn.underlying_graph)
+end
+
+function PartitionedGraphs.quotient_graph(tn::TensorNetwork)
+    ug = quotient_graph(underlying_graph(tn))
+
+    inds = Indices(parent_graph_indices(QuotientVertices(tn)))
+    data = map(v -> tn[QuotientVertex(v)], inds)
+
+    return TensorNetwork(ug, data)
+end
+# TODO: This method should not be required with a better
+# DataGraphsPartitionedGraphsExt interface.
+function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork})
+    UG = quotient_graph_type(underlying_graph_type(type))
+    VD = Vector{vertex_data_type(type)}
+    V = vertextype(UG)
+    return TensorNetwork{V, VD, UG, Dictionary{V, VD}}
+end
+
+# Partition the underlying graph of the tensor network; does not affect the data.
+function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, copy(vertex_data(tn))) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +NamedGraphs.to_graph_index(::TensorNetwork, vertex::QuotientVertex) = vertex +# When getting data according the quotient vertices, take a lazy contraction. +function DataGraphs.get_index_data(tn::TensorNetwork, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) + return mapreduce(lazy, *, data) +end +function DataGraphs.is_graph_index_assigned(tn::TensorNetwork, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = map(qv -> vertex_data(tn)[Indices(qv)], Indices(quotientvertices(tn))) + return TensorNetwork(qview, tensors) +end diff --git a/test/Project.toml b/test/Project.toml index 6231ad4..4bcd159 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,12 +2,15 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" +DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" ITensorPkgSkeleton = "3d388ab1-018a-49f4-ae50-18094d5f71ea" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = 
"60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" @@ -25,14 +28,15 @@ path = ".." AbstractTrees = "0.4.5" AlgorithmsInterface = "0.1" Aqua = "0.8.14" +DataGraphs = "0.4" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" -ITensorNetworksNext = "0.3" +ITensorNetworksNext = "0.4" ITensorPkgSkeleton = "0.3.42" NamedDimsArrays = "0.14" -NamedGraphs = "0.6.8, 0.7, 0.8" +NamedGraphs = "0.11" QuadGK = "2.11.2" SafeTestsets = "0.1" Suppressor = "0.2.8" diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 44e6a09..6f80527 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -16,16 +16,14 @@ end end function AI.step!( - problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState; - logging_context_prefix = Symbol() + problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState ) state.iterate .+= 1 # Simple increment step return state end function AI.step!( - problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; - kwargs... 
+ problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState ) state.iterate .+= 2 # Different increment step return state @@ -101,22 +99,22 @@ end # Solve with custom initial iterate initial_iterate = [5.0, 10.0] - final_state = AI.solve!( + final_iterate = AI.solve!( problem, algorithm, state; iterate = copy(initial_iterate) ) - @test final_state.iteration == 3 + @test state.iteration == 3 + @test final_iterate == state.iterate # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] - @test final_state.iterate ≈ [8.0, 13.0] + @test state.iterate ≈ [8.0, 13.0] # Test solve without exclamation problem2 = TestProblem([1.0, 2.0]) algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) initial_iterate2 = [5.0, 10.0] - final_state2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) - @test final_state2.iteration == 2 - @test final_state2.iterate ≈ [7.0, 12.0] + final_iterate2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + @test final_iterate2 ≈ [7.0, 12.0] end @testset "DefaultAlgorithmIterator" begin @@ -146,31 +144,6 @@ end @test AI.is_finished!(iterator) end - @testset "with_algorithmlogger" begin - # Test with_algorithmlogger with functions - results = [] - function callback1(problem, algorithm, state) - push!(results, :callback1) - return nothing - end - function callback2(problem, algorithm, state) - push!(results, :callback2) - return nothing - end - - problem = TestProblem([1.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - - # Test with CallbackAction (wrapped functions) - state = AIE.with_algorithmlogger( - :TestProblem_TestAlgorithm_PreStep => callback1, - :TestProblem_TestAlgorithm_PostStep => callback2 - ) do - return AI.solve(problem, algorithm; iterate = [0.0]) - end - @test results == [:callback1, :callback2] - end - @testset "DefaultNestedAlgorithm" begin # Test creating nested algorithm with function nested_alg = 
AIE.nested_algorithm(3) do i @@ -270,21 +243,6 @@ end @test state.iterate ≈ [100.0, 200.0] end - @testset "basetypenameof and default_logging_context_prefix" begin - # Test basetypenameof utility - problem = TestProblem([1.0]) - algorithm = TestAlgorithm() - - prefix_problem = AIE.default_logging_context_prefix(problem) - prefix_algorithm = AIE.default_logging_context_prefix(algorithm) - prefix_combined = AIE.default_logging_context_prefix(problem, algorithm) - - @test prefix_problem isa Symbol - @test prefix_algorithm isa Symbol - @test prefix_combined isa Symbol - @test contains(String(prefix_combined), String(prefix_problem)) - end - @testset "DefaultFlattenedAlgorithm" begin # Create nested algorithms that support max_iterations nested_algs = map(1:3) do i diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..01ca6e7 --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,217 @@ +import AlgorithmsInterface as AI +using DataGraphs: edge_data, edge_data_type +using DiagonalArrays: δ +using Dictionaries: Dictionary, dictionary, set! 
+using Graphs: AbstractGraph, dst, edges, has_edge, src, vertices +using ITensorBase: ITensor, Index, noprime, prime +using ITensorNetworksNext: ITensorNetworksNext, MessageCache, StopWhenConverged, + TensorNetwork, bethe_free_energy, edge_scalar, incoming_messages, linkinds, + messagecache, region_scalar, subgraph, vertex_scalar, vertex_scalars +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: inds, name +using NamedGraphs.GraphsExtensions: all_edges, arranged_edges, incident_edges, vertextype +using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid, named_path_graph +using NamedGraphs: NamedEdge +using Test: @test, @testset + +function spin_ice_tensornetwork(g) + links = Dictionary( + edges(g), + [Index(2) for e in edges(g)] + # [Index(2; tags = "edge " => "e$(src(e))_$(dst(e))") for e in edges(g)] + ) + links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) + + ts = Dictionary{vertextype(g), ITensor}() + for v in vertices(g) + es = incident_edges(g, v; dir = :in) + t_data = zeros(Int, 2, 2, 2, 2) + for (i, j, k, l) in Iterators.product(0:1, 0:1, 0:1, 0:1) + if i + j + k + l == 2 + t_data[i + 1, j + 1, k + 1, l + 1] = 1 + end + end + linkinds = [links[e] for e in es] + t = t_data[linkinds...] 
+ set!(ts, v, t) + end + return TensorNetwork(g, ts) +end + +@testset "Belief propagation" begin + @testset "`MessageCache`" begin + @testset "Basics" begin + dims = (3, 3) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = messagecache(edge -> "$(src(edge)) => $(dst(edge))", all_edges(g)) + + @test valtype(bpc) <: String + @test edge_data_type(bpc) <: String + @test valtype(bpc) === edge_data_type(bpc) + @test length(edge_data(bpc)) == 2 * length(edges(g)) + @test bpc[(1, 1) => (1, 2)] == "(1, 1) => (1, 2)" + + # set message + bpc[(1, 1) => (1, 2)] = "new message" + @test bpc[(1, 1) => (1, 2)] == "new message" + + pairs = [((1, 2) => (2, 2), "m1"), ((2, 2) => (2, 3), "m2")] + + new_bpc = copyto!(deepcopy(bpc), Dict(pairs)) + @test new_bpc[(1, 1) => (1, 2)] == "new message" + @test new_bpc[(1, 2) => (2, 2)] == "m1" + @test new_bpc[(2, 2) => (2, 3)] == "m2" + + new_bpc = copyto!(deepcopy(bpc), dictionary(pairs)) + @test new_bpc[(1, 1) => (1, 2)] == "new message" + @test new_bpc[(1, 2) => (2, 2)] == "m1" + @test new_bpc[(2, 2) => (2, 3)] == "m2" + + bpc_dst = messagecache(edge -> "", all_edges(g)) + + copyto!(bpc_dst, bpc, [(1, 2) => (2, 2), (2, 2) => (2, 3)]) + @test bpc_dst[(1, 1) => (1, 2)] == "" + @test bpc_dst[(1, 2) => (2, 2)] == "(1, 2) => (2, 2)" + @test bpc_dst[(2, 2) => (2, 3)] == "(2, 2) => (2, 3)" + end + @testset "Vertex/region scalars" begin + g = named_path_graph(3) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(ComplexF32, Tuple(is)) + end + + bpc = messagecache(all_edges(g)) do edge + return ones(Float64, Tuple(linkinds(tn, edge))) + end + + # Vertex/edge/region scalars. 
+ @test vertex_scalar(tn, bpc, 2) isa ComplexF64 + @test edge_scalar(bpc, 1 => 2) isa Float64 + + @test region_scalar(tn, bpc, [1]) == vertex_scalar(tn, bpc, 1) + @test region_scalar(tn, bpc, [2, 3]) == prod(vertex_scalars(tn, bpc, [2, 3])) + + # `incoming_messages` excludes the reverse of the passed edge + in_msgs = incoming_messages(bpc, 2 => 3) + @test length(in_msgs) == 1 + @test only(in_msgs) == bpc[1 => 2] + + in_msgs = incoming_messages(bpc, NamedEdge(1 => 2)) + @test length(in_msgs) == 0 + + in_msgs = incoming_messages(bpc, NamedEdge(2 => 1)) + @test length(in_msgs) == 1 + @test only(in_msgs) == bpc[3 => 2] + end + + @testset "subgraph" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = messagecache(edge -> ones(Tuple(linkinds(tn, edge))), all_edges(g)) + + sub_vs = [(1,), (2,)] + subbpc = subgraph(bpc, sub_vs) + @test subbpc isa MessageCache + @test issetequal(vertices(subbpc), sub_vs) + @test has_edge(subbpc, (1,) => (2,)) + end + @testset "diff" begin + g = named_grid((2,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc1 = messagecache(edge -> ones(Tuple(linkinds(tn, edge))), all_edges(g)) + + bpc2 = copy(bpc1) + + # Identical caches: diff should be ~0. 
+ @test ITensorNetworksNext.iterate_diff(bpc1, bpc2) ≈ 0.0 atol = 10 * eps() + end + end + + @testset "Algorithm" begin + @testset "$T" for T in (Float32, Float64, ComplexF64, BigFloat) + onet = (tn, edge) -> ones(T, Tuple(linkinds(tn, edge))) + randt = (tn, edge) -> rand(T, Tuple(linkinds(tn, edge))) + + #Chain of tensors + dims = (2, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(T, Tuple(is)) + end + + messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) + + cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + z_bp = exp(bethe_free_energy(tn, cache)) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(T, Tuple(is)) + end + + messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) + + cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + z_bp = exp(bethe_free_energy(tn, cache)) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact + + #Spin Ice Model (has analytical bp solution given by 1.5^(n^2)) + @testset "Spin Ice Model (analytical)" begin + for n in (3, 4, 5) + dims = (n, n) + g = named_grid(dims; periodic = true) + tn = spin_ice_tensornetwork(g) + + messages = Dict(edge => randt(tn, edge) for edge in all_edges(g)) + + stopping_criterion = StopWhenConverged(tol = 1.0e-10) + + cache = ITensorNetworksNext.beliefpropagation( + tn, + messages; + maxiter = 10, + stopping_criterion + ) + + z_bp = exp(bethe_free_energy(tn, cache)) + + @test z_bp ≈ 1.5^(n^2) + end + end + end + end +end diff --git a/test/test_contract_network.jl 
b/test/test_contract_network.jl index 7dda0c6..b453e76 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -1,3 +1,4 @@ +using BackendSelection: @Algorithm_str, Algorithm using Graphs: edges using ITensorBase: Index using ITensorNetworksNext: TensorNetwork, contract_network, linkinds, siteinds @@ -7,6 +8,8 @@ using TensorOperations: TensorOperations using Test: @test, @testset @testset "contract_network" begin + orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) + @testset "Contract Vectors of ITensors" begin i, j, k = Index(2), Index(2), Index(5) A = [1.0 1.0; 0.5 1.0][i, j] @@ -14,10 +17,9 @@ using Test: @test, @testset C = [5.0, 1.0][j] D = [-2.0, 3.0, 4.0, 5.0, 1.0][k] - ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal") - + ABCD_1 = contract_network([A, B, C, D]; alg = orderalg("left_associative")) + ABCD_2 = contract_network([A, B, C, D]; alg = orderalg("eager")) + ABCD_3 = contract_network([A, B, C, D]; alg = orderalg("optimal")) @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +33,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; order_alg = "left_associative")[] - z2 = contract_network(tn; order_alg = "eager")[] - z3 = contract_network(tn; order_alg = "optimal")[] + z1 = contract_network(tn; alg = orderalg("left_associative"))[] + z2 = contract_network(tn; alg = orderalg("eager"))[] + z3 = contract_network(tn; alg = orderalg("optimal"))[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl index 215a8b8..01881d9 100644 --- a/test/test_sweeping.jl +++ b/test/test_sweeping.jl @@ -11,7 +11,7 @@ struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm end TestRegion(region; kwargs...) 
= TestRegion(region, (; kwargs...)) -function AI.solve!(problem::TestProblem, algorithm::TestRegion, state::AIE.State; kwargs...) +function AI.solve_loop!(problem::TestProblem, algorithm::TestRegion, state::AIE.State) new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) state.iterate = [state.iterate; [new_iterate]] return state @@ -28,8 +28,8 @@ end problem = TestProblem() iterate = [] - state = AI.solve(problem, algorithm; iterate) - @test state.iterate == [(; region = "region", foo = 1, bar = 2)] + iterate = AI.solve(problem, algorithm; iterate) + @test iterate == [(; region = "region", foo = 1, bar = 2)] end @testset "Sweep" begin algorithm = AIE.nested_algorithm(3) do i @@ -37,8 +37,8 @@ end end problem = TestProblem() iterate = [] - state = AI.solve(problem, algorithm; iterate) - @test state.iterate == [ + iterate = AI.solve(problem, algorithm; iterate) + @test iterate == [ (; region = "region1", foo = 1, bar = 2), (; region = "region2", foo = 2, bar = 4), (; region = "region3", foo = 3, bar = 6), @@ -52,8 +52,8 @@ end end problem = TestProblem() iterate = [] - state = AI.solve(problem, algorithm; iterate) - @test state.iterate == [ + iterate = AI.solve(problem, algorithm; iterate) + @test iterate == [ (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), diff --git a/test/test_tensornetwork.jl b/test/test_tensornetwork.jl new file mode 100644 index 0000000..3b4211b --- /dev/null +++ b/test/test_tensornetwork.jl @@ -0,0 +1,236 @@ +using DataGraphs: assigned_edge_data, assigned_vertex_data, underlying_graph, vertex_data +using Graphs: add_edge!, add_vertex!, dst, edges, edgetype, has_edge, has_vertex, + is_directed, ne, nv, rem_vertex!, src, vertices +using ITensorBase: Index +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray +using ITensorNetworksNext: + TensorNetwork, fix_edges!, linkaxes, 
linkinds, linknames, siteaxes, siteinds, sitenames +using NamedGraphs.GraphsExtensions: incident_edges, subgraph, vertextype +using NamedGraphs.NamedGraphGenerators: named_grid, named_path_graph +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, QuotientVertex, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, + quotientvertices +using NamedGraphs: convert_vertextype, similar_graph +using Test: @test, @test_throws, @testset + +@testset "`TensorNetwork`" begin + @testset "Basics" begin + g = named_grid((2, 2)) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + # `iterate` works (delegates to `vertex_data`). + @test !isempty(collect(tn)) + # `keys` returns vertices. + @test issetequal(keys(tn), vertices(tn)) + # `eltype` matches the eltype of the vertex data. + @test eltype(tn) === eltype(vertex_data(tn)) + # `is_directed` is `false` for AbstractTensorNetwork. + @test !is_directed(typeof(tn)) + + # `show` MIME and default both succeed and mention vertices/edges. + s_plain = sprint(show, MIME"text/plain"(), tn) + @test occursin("vertices", s_plain) + @test occursin("edge", s_plain) + s_default = sprint(show, tn) + @test occursin("vertices", s_default) + + # `setindex!` for edges is intentionally unimplemented. 
+ e = first(edges(tn)) + @test_throws ErrorException tn[e] = randn(2, 2) + @test_throws ErrorException tn[src(e) => dst(e)] = randn(2, 2) + + rem_vertex!(tn, (2, 2)) + @test !has_vertex(tn, (2, 2)) + add_vertex!(tn, (2, 2)) + @test has_vertex(tn, (2, 2)) + @test !isassigned(tn, (2, 2)) + + # Test `fix_edges!` removes edges where there is no link index + t = randn(s[(2, 2)]) + tn[(2, 2)] = t + add_edge!(tn.underlying_graph, (1, 2) => (2, 2)) + fix_edges!(tn, (2, 2)) + @test !has_edge(tn, (1, 2) => (2, 2)) + end + + @testset "link and site functions" begin + g = named_path_graph(3) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn((s[v], is...)) + end + + E = edgetype(tn) + @test linkinds(tn, 1 => 2) == [l[E(1 => 2)]] + @test linkinds(tn, E(1 => 2)) == [l[E(1 => 2)]] + + @test linkaxes(tn, 1 => 2) == [l[E(1 => 2)]] + @test linkaxes(tn, E(1 => 2)) == [l[E(1 => 2)]] + + @test linknames(tn, 1 => 2) == [l[E(1 => 2)].name] + @test linknames(tn, E(1 => 2)) == [l[E(1 => 2)].name] + + @test siteinds(tn, 1) == [s[1]] + @test siteaxes(tn, 2) == [s[2]] + @test sitenames(tn, 3) == [s[3].name] + end + + @testset "`subgraph`" begin + g = named_grid((3,)) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + sub_vs = [(1,), (2,)] + subtn = subgraph(tn, sub_vs) + @test subtn isa TensorNetwork + @test issetequal(vertices(subtn), sub_vs) + @test has_edge(subtn, (1,) => (2,)) + end + + @testset "DataGraphs/NamedGraphs interface" begin + dims = (3, 3) + g = named_grid(dims) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + stn = similar_graph(tn) + @test stn isa TensorNetwork + @test 
vertices(stn) == vertices(tn) + @test edges(stn) == edges(tn) + @test isempty(assigned_vertex_data(stn)) + @test isempty(assigned_edge_data(stn)) + + stn = similar_graph(tn, vertices(tn)) + @test vertices(stn) == vertices(tn) + @test ne(stn) == 0 + @test isempty(assigned_vertex_data(stn)) + @test isempty(assigned_edge_data(stn)) + + stn = similar_graph(typeof(tn)) + @test nv(stn) == 0 + @test stn isa typeof(tn) + + stn = similar_graph(typeof(tn), vertices(tn)) + @test nv(stn) == nv(tn) + @test ne(stn) == 0 + @test stn isa typeof(tn) + + ctn = convert_vertextype(Tuple{Float64, Float64}, tn) + @test ctn isa TensorNetwork + @test vertextype(ctn) == Tuple{Float64, Float64} + @test collect(vertex_data(ctn)) == collect(vertex_data(tn)) + end + + @testset "`PartitionedGraphs`" begin + dims = (3, 3) + g = named_grid(dims) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + + # Row partition: each partition is one row of the grid. + row_parts = [[(i, j) for i in 1:dims[1]] for j in 1:dims[2]] + + @testset "default `partitioned_vertices`" begin + # By default the entire underlying graph is one partition. 
+ pvs = partitioned_vertices(tn) + @test length(pvs) == 1 + @test issetequal(only(pvs), vertices(tn)) + end + + @testset "default `quotientvertices`" begin + qvs = collect(quotientvertices(tn)) + @test length(qvs) == 1 + @test only(qvs) isa QuotientVertex + end + + @testset "`tn[QuotientVertex(...)]` (default)" begin + qv = only(collect(quotientvertices(tn))) + data = tn[qv] + @test data isa LazyNamedDimsArray + end + + @testset "`quotient_graph` (default partitioning)" begin + qtn = quotient_graph(tn) + @test qtn isa TensorNetwork + @test nv(qtn) == 1 + @test ne(qtn) == 0 + v = only(collect(vertices(qtn))) + @test qtn[v] isa LazyNamedDimsArray + end + + @testset "`quotient_graph_type`" begin + QT = quotient_graph_type(typeof(tn)) + @test QT <: TensorNetwork + qtn = quotient_graph(tn) + @test vertextype(qtn) === vertextype(QT) + end + + @testset "`partitionedgraph(tn, parts)`" begin + ptn = partitionedgraph(tn, row_parts) + @test ptn isa TensorNetwork + # The set of underlying vertices/edges is preserved. + @test issetequal(vertices(ptn), vertices(tn)) + @test issetequal(edges(ptn), edges(tn)) + @test nv(ptn) == nv(tn) + @test ne(ptn) == ne(tn) + # Vertex data is copied, not aliased. 
+ @test collect(vertex_data(ptn)) == collect(vertex_data(tn)) + @test vertex_data(ptn) !== vertex_data(tn) + end + + @testset "`partitioned_vertices` of partitioned tn" begin + ptn = partitionedgraph(tn, row_parts) + pvs = partitioned_vertices(ptn) + @test length(pvs) == dims[2] + for part in pvs + @test length(part) == dims[1] + end + @test issetequal(reduce(vcat, pvs), vertices(tn)) + end + + @testset "`tn[QuotientVertex(...)]` (partitioned)" begin + ptn = partitionedgraph(tn, row_parts) + for qv in quotientvertices(ptn) + @test ptn[qv] isa LazyNamedDimsArray + end + end + + @testset "`quotient_graph` of partitioned tn" begin + ptn = partitionedgraph(tn, row_parts) + qtn = quotient_graph(ptn) + @test qtn isa TensorNetwork + @test nv(qtn) == dims[2] + # The row-partitioned grid quotients to a path graph of length `dims[2]`. + @test ne(qtn) == dims[2] - 1 + for v in vertices(qtn) + @test qtn[v] isa LazyNamedDimsArray + end + end + + @testset "`departition`" begin + # `departition` on a non-partitioned tn returns itself. + @test departition(tn) === tn + + # `departition` on a partitioned tn unwraps one layer of partitioning. + ptn = partitionedgraph(tn, row_parts) + dtn = departition(ptn) + @test dtn isa TensorNetwork + @test issetequal(vertices(dtn), vertices(tn)) + @test issetequal(edges(dtn), edges(tn)) + end + end +end