diff --git a/graphix/pattern.py b/graphix/pattern.py index b636c97ec..edb212d8f 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -43,12 +43,7 @@ from graphix.flow.core import CausalFlow, GFlow, XZCorrections from graphix.parameter import ExpressionOrSupportsComplex, ExpressionOrSupportsFloat, Parameter - from graphix.sim import ( - Backend, - Data, - DensityMatrixBackend, - StatevectorBackend, - ) + from graphix.sim import Backend, Data, DensityMatrixBackend, StatevectorBackend from graphix.sim.base_backend import _StateT_co from graphix.sim.tensornet import TensorNetworkBackend from graphix.simulator import _BackendLiteral @@ -113,7 +108,6 @@ def __init__( else: self.__input_nodes = list(input_nodes) # input nodes (list() makes our own copy of the list) self.__n_node = len(self.__input_nodes) # total number of nodes in the graph state - self._pauli_preprocessed = False # flag for `measure_pauli` preprocessing completion self.__seq = [] # output nodes are initially a copy input nodes, since none are measured yet @@ -135,7 +129,6 @@ def add(self, cmd: Command) -> None: cmd : :class:`graphix.command.Command` MBQC command. """ - self._pauli_preprocessed = False if cmd.kind == CommandKind.N: self.__n_node += 1 self.__output_nodes.append(cmd.node) @@ -149,7 +142,6 @@ def extend(self, *cmds: Command | Iterable[Command]) -> None: :param cmds: sequences of commands """ - self._pauli_preprocessed = False for item in cmds: if isinstance(item, Iterable): for cmd in item: @@ -159,7 +151,6 @@ def extend(self, *cmds: Command | Iterable[Command]) -> None: def clear(self) -> None: """Clear the sequence of pattern commands.""" - self._pauli_preprocessed = False self.__n_node = len(self.__input_nodes) self.__seq = [] self.__output_nodes = list(self.__input_nodes) @@ -171,7 +162,6 @@ def replace(self, cmds: list[Command], input_nodes: list[int] | None = None) -> :param input_nodes: optional, list of input qubits (by default, keep the same input nodes as before) """ - self._pauli_preprocessed = False if input_nodes is not None: self.__input_nodes = list(input_nodes) self.clear() @@ -217,7 +207,6 @@ def compose( - Input (and, respectively, output) nodes in the returned pattern have the order of the pattern `self` followed by those of the pattern `other`. Merged nodes are removed. - If `preserve_mapping = True` and :math:`|M_1| = |I_2| = |O_2|`, then the outputs of the returned pattern are the outputs of pattern `self`, where the nth merged output is replaced by the output of pattern `other` corresponding to its nth input instead. """ - self._pauli_preprocessed = False nodes_p1 = self.extract_nodes() | self.results.keys() # Results contain preprocessed Pauli nodes nodes_p2 = other.extract_nodes() | other.results.keys() @@ -1315,13 +1304,12 @@ def minimize_space(self) -> None: if not self.is_standard(): self.standardize() meas_order = None - if not self._pauli_preprocessed: - try: - cf = self.extract_causal_flow() - except FlowError: - meas_order = None - else: - meas_order = list(itertools.chain(*reversed(cf.partial_order_layers[1:]))) + try: + cf = self.extract_causal_flow() + except FlowError: + meas_order = None + else: + meas_order = list(itertools.chain(*reversed(cf.partial_order_layers[1:]))) if meas_order is None: meas_order = self._measurement_order_space() self._reorder_pattern(self.sort_measurement_commands(meas_order)) @@ -1483,7 +1471,7 @@ def perform_pauli_measurements(self, ignore_pauli_with_deps: bool = False) -> No """ if self.input_nodes: raise ValueError("Remove inputs with `self.remove_input_nodes()` before performing Pauli presimulation.") - measure_pauli(self, copy=False, ignore_pauli_with_deps=ignore_pauli_with_deps) + self.__dict__.update(measure_pauli(self, ignore_pauli_with_deps=ignore_pauli_with_deps).__dict__) def draw_graph( self, @@ -1600,7 +1588,6 @@ def copy(self) -> Pattern: result.__input_nodes = self.__input_nodes.copy() result.__output_nodes = self.__output_nodes.copy() result.__n_node = self.__n_node - result._pauli_preprocessed = self._pauli_preprocessed result.results = self.results.copy() return result @@ -1709,7 +1696,7 @@ def __str__(self) -> str: assert_never(self.reason) -def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_deps: bool = False) -> Pattern: +def measure_pauli(pattern: Pattern, *, ignore_pauli_with_deps: bool = False) -> Pattern: """Perform Pauli measurement of a pattern by fast graph state simulator. Uses the decorated-graph method implemented in graphix.graphsim to perform the measurements in Pauli bases, and then sort remaining nodes back into @@ -1720,9 +1707,6 @@ def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_dep Parameters ---------- pattern : graphix.pattern.Pattern object - copy : bool - True: changes will be applied to new copied object and will be returned - False: changes will be applied to the supplied Pattern object ignore_pauli_with_deps : bool Optional (*False* by default). If *True*, Pauli measurements with domains depending on other measures are preserved as-is in the pattern. @@ -1738,14 +1722,14 @@ def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_dep .. seealso:: :class:`graphix.pattern.Pattern.remove_input_nodes` .. seealso:: :class:`graphix.graphsim.GraphState` """ - pat = Pattern() if copy else pattern + pat = Pattern() standardized_pattern = optimization.StandardizedPattern.from_pattern(pattern) if not ignore_pauli_with_deps: standardized_pattern = standardized_pattern.perform_pauli_pushing() output_nodes = set(pattern.output_nodes) graph = standardized_pattern.extract_graph() graph_state = GraphState(nodes=graph.nodes, edges=graph.edges, vops=standardized_pattern.c_dict) - results: dict[int, Outcome] = pat.results + results: dict[int, Outcome] = pattern.results to_measure, non_pauli_meas = pauli_nodes(standardized_pattern) if not to_measure: return pattern @@ -1813,7 +1797,6 @@ def measure_pauli(pattern: Pattern, *, copy: bool = False, ignore_pauli_with_dep pat.reorder_output_nodes(standardized_pattern.output_nodes) assert pat.n_node == len(graph_state.nodes) pat.results = results - pat._pauli_preprocessed = True return pat