Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions graphix/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@

from graphix.flow.core import CausalFlow, GFlow, XZCorrections
from graphix.parameter import ExpressionOrSupportsComplex, ExpressionOrSupportsFloat, Parameter
from graphix.sim import (
Backend,
Data,
DensityMatrixBackend,
StatevectorBackend,
)
from graphix.sim import Backend, Data, DensityMatrixBackend, StatevectorBackend
from graphix.sim.base_backend import _StateT_co
from graphix.sim.tensornet import TensorNetworkBackend
from graphix.simulator import _BackendLiteral
Expand Down Expand Up @@ -113,7 +108,6 @@ def __init__(
else:
self.__input_nodes = list(input_nodes) # input nodes (list() makes our own copy of the list)
self.__n_node = len(self.__input_nodes) # total number of nodes in the graph state
self._pauli_preprocessed = False # flag for `measure_pauli` preprocessing completion

self.__seq = []
# output nodes are initially a copy input nodes, since none are measured yet
Expand All @@ -135,7 +129,6 @@ def add(self, cmd: Command) -> None:
cmd : :class:`graphix.command.Command`
MBQC command.
"""
self._pauli_preprocessed = False
if cmd.kind == CommandKind.N:
self.__n_node += 1
self.__output_nodes.append(cmd.node)
Expand All @@ -149,7 +142,6 @@ def extend(self, *cmds: Command | Iterable[Command]) -> None:

:param cmds: sequences of commands
"""
self._pauli_preprocessed = False
for item in cmds:
if isinstance(item, Iterable):
for cmd in item:
Expand All @@ -159,7 +151,6 @@ def extend(self, *cmds: Command | Iterable[Command]) -> None:

def clear(self) -> None:
"""Clear the sequence of pattern commands."""
self._pauli_preprocessed = False
self.__n_node = len(self.__input_nodes)
self.__seq = []
self.__output_nodes = list(self.__input_nodes)
Expand All @@ -171,7 +162,6 @@ def replace(self, cmds: list[Command], input_nodes: list[int] | None = None) ->

:param input_nodes: optional, list of input qubits (by default, keep the same input nodes as before)
"""
self._pauli_preprocessed = False
if input_nodes is not None:
self.__input_nodes = list(input_nodes)
self.clear()
Expand Down Expand Up @@ -217,7 +207,6 @@ def compose(
- Input (and, respectively, output) nodes in the returned pattern have the order of the pattern `self` followed by those of the pattern `other`. Merged nodes are removed.
- If `preserve_mapping = True` and :math:`|M_1| = |I_2| = |O_2|`, then the outputs of the returned pattern are the outputs of pattern `self`, where the nth merged output is replaced by the output of pattern `other` corresponding to its nth input instead.
"""
self._pauli_preprocessed = False
nodes_p1 = self.extract_nodes() | self.results.keys() # Results contain preprocessed Pauli nodes
nodes_p2 = other.extract_nodes() | other.results.keys()

Expand Down Expand Up @@ -1315,13 +1304,12 @@ def minimize_space(self) -> None:
if not self.is_standard():
self.standardize()
meas_order = None
if not self._pauli_preprocessed:
try:
cf = self.extract_causal_flow()
except FlowError:
meas_order = None
else:
meas_order = list(itertools.chain(*reversed(cf.partial_order_layers[1:])))
try:
cf = self.extract_causal_flow()
except FlowError:
meas_order = None
else:
meas_order = list(itertools.chain(*reversed(cf.partial_order_layers[1:])))
if meas_order is None:
meas_order = self._measurement_order_space()
self._reorder_pattern(self.sort_measurement_commands(meas_order))
Expand Down Expand Up @@ -1483,7 +1471,7 @@ def perform_pauli_measurements(self, ignore_pauli_with_deps: bool = False) -> No
"""
if self.input_nodes:
raise ValueError("Remove inputs with `self.remove_input_nodes()` before performing Pauli presimulation.")
measure_pauli(self, copy=False, ignore_pauli_with_deps=ignore_pauli_with_deps)
self.__dict__.update(measure_pauli(self, ignore_pauli_with_deps=ignore_pauli_with_deps).__dict__)

def draw_graph(
self,
Expand Down Expand Up @@ -1600,7 +1588,6 @@ def copy(self) -> Pattern:
result.__input_nodes = self.__input_nodes.copy()
result.__output_nodes = self.__output_nodes.copy()
result.__n_node = self.__n_node
result._pauli_preprocessed = self._pauli_preprocessed
result.results = self.results.copy()
return result

Expand Down Expand Up @@ -1709,7 +1696,7 @@ def __str__(self) -> str:
assert_never(self.reason)


def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_deps: bool = False) -> Pattern:
def measure_pauli(pattern: Pattern, *, ignore_pauli_with_deps: bool = False) -> Pattern:
"""Perform Pauli measurement of a pattern by fast graph state simulator.

Uses the decorated-graph method implemented in graphix.graphsim to perform the measurements in Pauli bases, and then sort remaining nodes back into
Expand All @@ -1720,9 +1707,6 @@ def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_dep
Parameters
----------
pattern : graphix.pattern.Pattern object
copy : bool
True: changes will be applied to new copied object and will be returned
False: changes will be applied to the supplied Pattern object
ignore_pauli_with_deps : bool
Optional (*False* by default).
If *True*, Pauli measurements with domains depending on other measures are preserved as-is in the pattern.
Expand All @@ -1738,14 +1722,14 @@ def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_dep
.. seealso:: :class:`graphix.pattern.Pattern.remove_input_nodes`
.. seealso:: :class:`graphix.graphsim.GraphState`
"""
pat = Pattern() if copy else pattern
pat = Pattern()
standardized_pattern = optimization.StandardizedPattern.from_pattern(pattern)
if not ignore_pauli_with_deps:
standardized_pattern = standardized_pattern.perform_pauli_pushing()
output_nodes = set(pattern.output_nodes)
graph = standardized_pattern.extract_graph()
graph_state = GraphState(nodes=graph.nodes, edges=graph.edges, vops=standardized_pattern.c_dict)
results: dict[int, Outcome] = pat.results
results: dict[int, Outcome] = pattern.results
to_measure, non_pauli_meas = pauli_nodes(standardized_pattern)
if not to_measure:
return pattern
Expand Down Expand Up @@ -1813,7 +1797,6 @@ def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_dep
pat.reorder_output_nodes(standardized_pattern.output_nodes)
assert pat.n_node == len(graph_state.nodes)
pat.results = results
pat._pauli_preprocessed = True
return pat


Expand Down