diff --git a/examples/mbqc_vqe.py b/examples/mbqc_vqe.py index ae75a7f7f..fce62723a 100644 --- a/examples/mbqc_vqe.py +++ b/examples/mbqc_vqe.py @@ -37,8 +37,11 @@ if TYPE_CHECKING: from collections.abc import Iterable + from scipy.optimize import OptimizeResult + + from graphix.fundamentals import ParameterizedAngle from graphix.pattern import Pattern - from graphix.transpiler import Angle + from graphix.sim.tensornet import MBQCTensorNet Z = np.array([[1, 0], [0, -1]]) X = np.array([[0, 1], [1, 0]]) @@ -46,7 +49,7 @@ # %% # Define the Hamiltonian for the VQE problem (Example: H = Z0Z1 + X0 + X1) -def create_hamiltonian() -> npt.NDArray[np.complex128]: +def create_hamiltonian() -> npt.NDArray[np.float64]: return np.kron(Z, Z) + np.kron(X, np.eye(2)) + np.kron(np.eye(2), X) @@ -65,7 +68,7 @@ def batched(iterable, n): # %% # Function to build the VQE circuit -def build_vqe_circuit(n_qubits: int, params: Iterable[Angle]) -> Circuit: +def build_vqe_circuit(n_qubits: int, params: Iterable[ParameterizedAngle]) -> Circuit: circuit = Circuit(n_qubits) for i, (x, y, z) in enumerate(batched(params, n=3)): circuit.rx(i, x) @@ -78,13 +81,13 @@ def build_vqe_circuit(n_qubits: int, params: Iterable[Angle]) -> Circuit: # %% class MBQCVQE: - def __init__(self, n_qubits: int, hamiltonian: npt.NDArray): + def __init__(self, n_qubits: int, hamiltonian: npt.NDArray[np.float64]): self.n_qubits = n_qubits self.hamiltonian = hamiltonian # %% # Function to build the MBQC pattern - def build_mbqc_pattern(self, params: Iterable[Angle]) -> Pattern: + def build_mbqc_pattern(self, params: Iterable[ParameterizedAngle]) -> Pattern: circuit = build_vqe_circuit(self.n_qubits, params) pattern = circuit.transpile().pattern pattern.standardize() @@ -95,34 +98,36 @@ def build_mbqc_pattern(self, params: Iterable[Angle]) -> Pattern: # %% # Function to simulate the MBQC circuit - def simulate_mbqc(self, params: Iterable[float], backend="tensornetwork"): + def simulate_mbqc( + self, + params: Iterable[float], + ) -> MBQCTensorNet: pattern = self.build_mbqc_pattern(params) - simulator = PatternSimulator(pattern, backend=backend) - if backend == "tensornetwork": - simulator.run() # Simulate the MBQC circuit using tensor network - tn = simulator.backend.state - tn.default_output_nodes = pattern.output_nodes # Set the default_output_nodes attribute - if tn.default_output_nodes is None: - raise ValueError("Output nodes are not set for tensor network simulation.") - return tn - return simulator.run() # Simulate the MBQC circuit using other backends + simulator = PatternSimulator(pattern, backend="tensornetwork") + state = simulator.backend.state + simulator.run() # Simulate the MBQC circuit using tensor network + tn = simulator.backend.state + tn.default_output_nodes = pattern.output_nodes # Set the default_output_nodes attribute + if tn.default_output_nodes is None: + raise ValueError("Output nodes are not set for tensor network simulation.") + return state # %% # Function to compute the energy - def compute_energy(self, params: Iterable[float]): + def compute_energy(self, params: Iterable[float]) -> float: # Simulate the MBQC circuit using tensor network backend - tn = self.simulate_mbqc(params, backend="tensornetwork") + tn = self.simulate_mbqc(params) # Compute the expectation value using MBQCTensornet.expectation_value - return tn.expectation_value(self.hamiltonian, qubit_indices=range(self.n_qubits)) + return tn.expectation_value(self.hamiltonian.astype(np.complex128), qubit_indices=range(self.n_qubits)) class MBQCVQEWithPlaceholders(MBQCVQE): - def __init__(self, n_qubits: int, hamiltonian) -> None: + def __init__(self, n_qubits: int, hamiltonian: npt.NDArray[np.float64]) -> None: super().__init__(n_qubits, hamiltonian) self.placeholders = tuple(Placeholder(f"{r}[{q}]") for q in range(n_qubits) for r in ("X", "Y", "Z")) self.pattern = super().build_mbqc_pattern(self.placeholders) - def build_mbqc_pattern(self, params): + def build_mbqc_pattern(self, params: Iterable[ParameterizedAngle]) -> Pattern: return self.pattern.xreplace(dict(zip(self.placeholders, params, strict=True))) @@ -133,12 +138,12 @@ def build_mbqc_pattern(self, params): # %% # Instantiate the MBQCVQE class -mbqc_vqe = MBQCVQEWithPlaceholders(n_qubits, hamiltonian) +mbqc_vqe: MBQCVQE = MBQCVQEWithPlaceholders(n_qubits, hamiltonian) # %% # Define the cost function -def cost_function(params): +def cost_function(params: Iterable[float]) -> float: return mbqc_vqe.compute_energy(params) @@ -150,7 +155,7 @@ def cost_function(params): # %% # Perform the optimization using COBYLA -def compute(): +def compute() -> OptimizeResult: return minimize(cost_function, initial_params, method="COBYLA", options={"maxiter": 100}) diff --git a/examples/qaoa.py b/examples/qaoa.py index 57b2cafc9..db43af46b 100644 --- a/examples/qaoa.py +++ b/examples/qaoa.py @@ -17,9 +17,9 @@ rng = np.random.default_rng() n = 4 -xi = rng.random(6) -theta = rng.random(4) -g = nx.complete_graph(n) +xi: np.ndarray[tuple[int], np.dtype[np.float64]] = rng.random(6) +theta: np.ndarray[tuple[int], np.dtype[np.float64]] = rng.random(4) +g: nx.Graph[int] = nx.complete_graph(n) circuit = Circuit(n) for i, (u, v) in enumerate(g.edges): circuit.cnot(u, v) diff --git a/examples/qft_with_tn.py b/examples/qft_with_tn.py index 4ec0bffa7..eb92a6b3e 100644 --- a/examples/qft_with_tn.py +++ b/examples/qft_with_tn.py @@ -10,11 +10,16 @@ # %% from __future__ import annotations +from typing import TYPE_CHECKING + from graphix import Circuit from graphix.fundamentals import ANGLE_PI +if TYPE_CHECKING: + from graphix.fundamentals import Angle + -def cp(circuit, theta, control, target): +def cp(circuit: Circuit, theta: Angle, control: int, target: int) -> None: circuit.rz(control, theta / 2) circuit.rz(target, theta / 2) circuit.cnot(control, target) @@ -22,19 +27,18 @@ def cp(circuit, theta, control, target): circuit.cnot(control, target) -def qft_rotations(circuit, n): +def qft_rotations(circuit: Circuit, n: int) -> None: circuit.h(n) for qubit in range(n + 1, circuit.width): cp(circuit, ANGLE_PI / 2 ** (qubit - n), qubit, n) -def swap_registers(circuit, n): +def swap_registers(circuit: Circuit, n: int) -> None: for qubit in range(n // 2): circuit.swap(qubit, n - qubit - 1) - return circuit -def qft(circuit, n): +def qft(circuit: Circuit, n: int) -> None: for i in range(n): qft_rotations(circuit, i) swap_registers(circuit, n) diff --git a/examples/tn_simulation.py b/examples/tn_simulation.py index e96ea542a..8de62c998 100644 --- a/examples/tn_simulation.py +++ b/examples/tn_simulation.py @@ -11,18 +11,26 @@ from __future__ import annotations from functools import reduce +from typing import TYPE_CHECKING import cotengra as ctg import matplotlib.pyplot as plt import networkx as nx import numpy as np -import opt_einsum as oe import quimb.tensor as qtn +from cotengra.oe import PathOptimizer from scipy.optimize import minimize from graphix import Circuit from graphix.fundamentals import ANGLE_PI +if TYPE_CHECKING: + from collections.abc import Collection, Sequence + + import numpy.typing as npt + + from graphix.fundamentals import Angle + rng = np.random.default_rng() # %% @@ -34,7 +42,13 @@ # Let's start with defining a helper function for buidling the circuit. -def ansatz(circuit, n, gamma, beta, iterations): +def ansatz( + circuit: Circuit, + n: int, + gamma: npt.NDArray[np.float64] | Sequence[Angle], + beta: npt.NDArray[np.float64] | Sequence[Angle], + iterations: int, +) -> None: for j in range(iterations): for i in range(1, n): circuit.cnot(i, 0) @@ -109,8 +123,8 @@ def ansatz(circuit, n, gamma, beta, iterations): t2 = qtn.rand_tensor([1, 2], ["b", "c"]) t1.add_tag("T1") t2.add_tag("T2") -t = qtn.TensorNetwork([t1, t2]) -t.draw(ax=ax[1], title="MPS", legend=False, color=["T1", "T2"]) +tn = qtn.TensorNetwork([t1, t2]) +tn.draw(ax=ax[1], title="MPS", legend=False, color=["T1", "T2"]) plt.show() # %% @@ -149,7 +163,7 @@ def ansatz(circuit, n, gamma, beta, iterations): # It is also possible to change the path contraction algorithm. # Let's explore that too and define a custom optimizer for contraction, that we can use later. -opt = ctg.HyperOptimizer( +opt: PathOptimizer = ctg.HyperOptimizer( minimize="combo", reconf_opts={}, progbar=True, @@ -172,7 +186,14 @@ def ansatz(circuit, n, gamma, beta, iterations): # Create a cost function using the elements of graphix, which were already discussed above. -def cost(params, n, ham, quantum_iter, slice_index, opt=None): +def cost( + params: npt.NDArray[np.float64], + n: int, + ham: npt.NDArray[np.float64], + quantum_iter: int, + slice_index: int, + opt: str | PathOptimizer | None = None, +) -> float: circuit = Circuit(n) gamma = params[:slice_index] beta = params[slice_index:] @@ -184,7 +205,7 @@ def cost(params, n, ham, quantum_iter, slice_index, opt=None): pattern.remove_input_nodes() pattern.perform_pauli_measurements() mbqc_tn = pattern.simulate_pattern(backend="tensornetwork", graph_prep="parallel") - exp_val = 0 + exp_val: float = 0 for op in ham: exp_val += np.real(mbqc_tn.expectation_value(op, range(n), optimize=opt)) return exp_val @@ -195,7 +216,7 @@ def cost(params, n, ham, quantum_iter, slice_index, opt=None): ham = [reduce(np.kron, [pauli_z] * n)] for i in range(1, n): - op = [identity] * n + op = np.array([identity] * n) op[0] = pauli_z op[i] = pauli_z op = reduce(np.kron, op) @@ -203,8 +224,10 @@ def cost(params, n, ham, quantum_iter, slice_index, opt=None): # Use yet again another optimizer for path contraction. -class MyOptimizer(oe.paths.PathOptimizer): - def __call__(self, inputs, output, size_dict, memory_limit=None): +class MyOptimizer(PathOptimizer): + def __call__( + self, inputs: Collection[object], output: object, size_dict: object, memory_limit: object = None + ) -> list[tuple[int, int]]: return [(0, 1)] * (len(inputs) - 1) @@ -258,7 +281,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None): fig, ax = plt.subplots(ncols=2, figsize=(8, 6)) ax = ax.flatten() -g = nx.Graph() +g: nx.Graph[int] = nx.Graph() for i in range(1, n): g.add_edge(0, i) color = ["blue"] * n diff --git a/graphix/instruction.py b/graphix/instruction.py index c097b6b3b..51cf3bfc8 100644 --- a/graphix/instruction.py +++ b/graphix/instruction.py @@ -3,9 +3,13 @@ from __future__ import annotations import enum +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, Literal, SupportsFloat +from typing import ClassVar, Literal, Self, SupportsFloat + +# override introduced in Python 3.12 +from typing_extensions import override from graphix import utils from graphix.fundamentals import ( @@ -60,9 +64,39 @@ def __init_subclass__(cls) -> None: utils.check_kind(cls, {"InstructionKind": InstructionKind, "Plane": Plane}) -class BaseInstruction(DataclassReprMixin): +class InstructionVisitor: + """Visitor for instruction. + + This base class can be subclassed to rewrite instructions by + overriding some of the following functions: + + - ``visit_qubit``: rewrite qubit indices. + + - ``visit_angle``: rewrite angles. + + - ``visit_axis``: rewrite axes. + """ + + def visit_qubit(self, qubit: int) -> int: # noqa: PLR6301 + """Rewrite a qubit index.""" + return qubit + + def visit_angle(self, angle: ParameterizedAngle) -> ParameterizedAngle: # noqa: PLR6301 + """Rewrite an angle.""" + return angle + + def visit_axis(self, axis: Axis) -> Axis: # noqa: PLR6301 + """Rewrite an axis.""" + return axis + + +class BaseInstruction(ABC, DataclassReprMixin): """Base class for circuit instruction.""" + @abstractmethod + def visit(self, visitor: InstructionVisitor) -> Self: + """Rewrite the instruction according to the given visitor.""" + @dataclass(repr=False) class CCX(_KindChecker, BaseInstruction): @@ -72,6 +106,11 @@ class CCX(_KindChecker, BaseInstruction): controls: tuple[int, int] kind: ClassVar[Literal[InstructionKind.CCX]] = field(default=InstructionKind.CCX, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> CCX: + u, v = self.controls + return CCX(visitor.visit_qubit(self.target), (visitor.visit_qubit(u), visitor.visit_qubit(v))) + @dataclass(repr=False) class RZZ(_KindChecker, BaseInstruction): @@ -82,6 +121,10 @@ class RZZ(_KindChecker, BaseInstruction): angle: ParameterizedAngle = field(metadata={"repr": repr_angle}) kind: ClassVar[Literal[InstructionKind.RZZ]] = field(default=InstructionKind.RZZ, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> RZZ: + return RZZ(visitor.visit_qubit(self.target), visitor.visit_qubit(self.control), visitor.visit_angle(self.angle)) + @dataclass(repr=False) class CNOT(_KindChecker, BaseInstruction): @@ -91,6 +134,10 @@ class CNOT(_KindChecker, BaseInstruction): control: int kind: ClassVar[Literal[InstructionKind.CNOT]] = field(default=InstructionKind.CNOT, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> CNOT: + return CNOT(visitor.visit_qubit(self.target), visitor.visit_qubit(self.control)) + @dataclass(repr=False) class CZ(_KindChecker, BaseInstruction): @@ -99,6 +146,11 @@ class CZ(_KindChecker, BaseInstruction): targets: tuple[int, int] kind: ClassVar[Literal[InstructionKind.CZ]] = field(default=InstructionKind.CZ, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> CZ: + u, v = self.targets + return CZ((visitor.visit_qubit(u), visitor.visit_qubit(v))) + @dataclass(repr=False) class SWAP(_KindChecker, BaseInstruction): @@ -107,6 +159,11 @@ class SWAP(_KindChecker, BaseInstruction): targets: tuple[int, int] kind: ClassVar[Literal[InstructionKind.SWAP]] = field(default=InstructionKind.SWAP, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> SWAP: + u, v = self.targets + return SWAP((visitor.visit_qubit(u), visitor.visit_qubit(v))) + @dataclass(repr=False) class H(_KindChecker, BaseInstruction): @@ -115,6 +172,10 @@ class H(_KindChecker, BaseInstruction): target: int kind: ClassVar[Literal[InstructionKind.H]] = field(default=InstructionKind.H, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> H: + return H(visitor.visit_qubit(self.target)) + @dataclass(repr=False) class S(_KindChecker, BaseInstruction): @@ -123,6 +184,10 @@ class S(_KindChecker, BaseInstruction): target: int kind: ClassVar[Literal[InstructionKind.S]] = field(default=InstructionKind.S, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> S: + return S(visitor.visit_qubit(self.target)) + @dataclass(repr=False) class X(_KindChecker, BaseInstruction): @@ -131,6 +196,10 @@ class X(_KindChecker, BaseInstruction): target: int kind: ClassVar[Literal[InstructionKind.X]] = field(default=InstructionKind.X, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> X: + return X(visitor.visit_qubit(self.target)) + @dataclass(repr=False) class Y(_KindChecker, BaseInstruction): @@ -139,6 +208,10 @@ class Y(_KindChecker, BaseInstruction): target: int kind: ClassVar[Literal[InstructionKind.Y]] = field(default=InstructionKind.Y, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> Y: + return Y(visitor.visit_qubit(self.target)) + @dataclass(repr=False) class Z(_KindChecker, BaseInstruction): @@ -147,6 +220,10 @@ class Z(_KindChecker, BaseInstruction): target: int kind: ClassVar[Literal[InstructionKind.Z]] = field(default=InstructionKind.Z, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> Z: + return Z(visitor.visit_qubit(self.target)) + @dataclass(repr=False) class I(_KindChecker, BaseInstruction): @@ -155,6 +232,10 @@ class I(_KindChecker, BaseInstruction): target: int kind: ClassVar[Literal[InstructionKind.I]] = field(default=InstructionKind.I, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> I: + return I(visitor.visit_qubit(self.target)) + @dataclass(repr=False) class M(_KindChecker, BaseInstruction): @@ -164,6 +245,10 @@ class M(_KindChecker, BaseInstruction): axis: Axis kind: ClassVar[Literal[InstructionKind.M]] = field(default=InstructionKind.M, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> M: + return M(visitor.visit_qubit(self.target), visitor.visit_axis(self.axis)) + @dataclass(repr=False) class RX(_KindChecker, BaseInstruction): @@ -173,6 +258,10 @@ class RX(_KindChecker, BaseInstruction): angle: ParameterizedAngle = field(metadata={"repr": repr_angle}) kind: ClassVar[Literal[InstructionKind.RX]] = field(default=InstructionKind.RX, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> RX: + return RX(visitor.visit_qubit(self.target), visitor.visit_angle(self.angle)) + @dataclass(repr=False) class RY(_KindChecker, BaseInstruction): @@ -182,6 +271,10 @@ class RY(_KindChecker, BaseInstruction): angle: ParameterizedAngle = field(metadata={"repr": repr_angle}) kind: ClassVar[Literal[InstructionKind.RY]] = field(default=InstructionKind.RY, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> RY: + return RY(visitor.visit_qubit(self.target), visitor.visit_angle(self.angle)) + @dataclass(repr=False) class RZ(_KindChecker, BaseInstruction): @@ -191,5 +284,9 @@ class RZ(_KindChecker, BaseInstruction): angle: ParameterizedAngle = field(metadata={"repr": repr_angle}) kind: ClassVar[Literal[InstructionKind.RZ]] = field(default=InstructionKind.RZ, init=False) + @override + def visit(self, visitor: InstructionVisitor) -> RZ: + return RZ(visitor.visit_qubit(self.target), visitor.visit_angle(self.angle)) + Instruction = CCX | RZZ | CNOT | SWAP | CZ | H | S | X | Y | Z | I | M | RX | RY | RZ diff --git a/graphix/pattern.py b/graphix/pattern.py index 939da4c5e..5e26933da 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -1399,7 +1399,7 @@ def simulate_pattern( .. seealso:: :class:`graphix.simulator.PatternSimulator` """ - sim: PatternSimulator[_StateT_co] = PatternSimulator(self, backend=backend, **kwargs) + sim = PatternSimulator(self, backend=backend, **kwargs) sim.run(input_state, rng=rng) return sim.backend.state diff --git a/graphix/sim/base_backend.py b/graphix/sim/base_backend.py index 6cef7e33e..4512fa9bd 100644 --- a/graphix/sim/base_backend.py +++ b/graphix/sim/base_backend.py @@ -11,7 +11,6 @@ import numpy as np import numpy.typing as npt -# TypeAlias introduced in Python 3.10 # override introduced in Python 3.12 from typing_extensions import override diff --git a/graphix/simulator.py b/graphix/simulator.py index 8aa03f4fa..19e58eff2 100644 --- a/graphix/simulator.py +++ b/graphix/simulator.py @@ -9,7 +9,7 @@ import abc import logging import warnings -from typing import TYPE_CHECKING, Generic, Literal, overload +from typing import TYPE_CHECKING, Generic, Literal, TypeVar, overload # assert_never introduced in Python 3.11 # override introduced in Python 3.12 @@ -25,7 +25,6 @@ StatevectorBackend, TensorNetworkBackend, ) -from graphix.sim.base_backend import _StateT_co from graphix.states import BasicStates if TYPE_CHECKING: @@ -37,13 +36,22 @@ from graphix.measurements import Measurement, Outcome from graphix.noise_models.noise_model import CommandOrNoise, NoiseModel from graphix.pattern import Pattern - from graphix.sim import Data + from graphix.sim import Data, DensityMatrix, MBQCTensorNet, Statevec logger = logging.getLogger(__name__) _BuiltinBackend = DensityMatrixBackend | StatevectorBackend | TensorNetworkBackend _BackendLiteral = Literal["statevector", "densitymatrix", "tensornetwork", "mps"] +if TYPE_CHECKING: + _BuiltinBackendState = DensityMatrix | MBQCTensorNet | Statevec + + _StateT = TypeVar("_StateT") + +# This type variable should be defined outside TYPE_CHECKING block +# because it appears in the parameters of `PatternSimulator`. +_StateT_co = TypeVar("_StateT_co", covariant=True) + class PrepareMethod(abc.ABC): """Prepare method used by the simulator. @@ -235,12 +243,64 @@ class PatternSimulator(Generic[_StateT_co]): """ noise_model: NoiseModel | None - backend: Backend[_StateT_co] | _BuiltinBackend + backend: Backend[_StateT_co] + @overload def __init__( - self, + self: PatternSimulator[Statevec], + pattern: Pattern, + backend: Literal["statevector"] = ..., + prepare_method: PrepareMethod | None = None, + measure_method: MeasureMethod | None = None, + noise_model: NoiseModel | None = None, + branch_selector: BranchSelector | None = None, + graph_prep: str | None = None, + symbolic: bool = False, + ) -> None: ... + + @overload + def __init__( + self: PatternSimulator[DensityMatrix], + pattern: Pattern, + backend: Literal["densitymatrix"] = ..., + prepare_method: PrepareMethod | None = None, + measure_method: MeasureMethod | None = None, + noise_model: NoiseModel | None = None, + branch_selector: BranchSelector | None = None, + graph_prep: str | None = None, + symbolic: bool = False, + ) -> None: ... + + @overload + def __init__( + self: PatternSimulator[MBQCTensorNet], + pattern: Pattern, + backend: Literal["tensornetwork", "mps"] = ..., + prepare_method: PrepareMethod | None = None, + measure_method: MeasureMethod | None = None, + noise_model: NoiseModel | None = None, + branch_selector: BranchSelector | None = None, + graph_prep: str | None = None, + symbolic: bool = False, + ) -> None: ... + + @overload + def __init__( + self: PatternSimulator[_StateT], + pattern: Pattern, + backend: Backend[_StateT] = ..., + prepare_method: PrepareMethod | None = None, + measure_method: MeasureMethod | None = None, + noise_model: NoiseModel | None = None, + branch_selector: BranchSelector | None = None, + graph_prep: str | None = None, + symbolic: bool = False, + ) -> None: ... + + def __init__( + self: PatternSimulator[_StateT | _BuiltinBackendState], pattern: Pattern, - backend: Backend[_StateT_co] | _BackendLiteral = "statevector", + backend: Backend[_StateT] | _BackendLiteral = "statevector", prepare_method: PrepareMethod | None = None, measure_method: MeasureMethod | None = None, noise_model: NoiseModel | None = None, diff --git a/graphix/transpiler.py b/graphix/transpiler.py index df7bed7bc..019219eb4 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -6,16 +6,18 @@ from __future__ import annotations -import dataclasses +from dataclasses import dataclass from typing import TYPE_CHECKING, SupportsFloat -from typing_extensions import assert_never +# assert_never introduced in Python 3.11 +# override introduced in Python 3.12 +from typing_extensions import assert_never, override from graphix import command, instruction, parameter from graphix.branch_selector import BranchSelector, RandomBranchSelector from graphix.command import E, M, N, X, Z from graphix.fundamentals import ANGLE_PI, Axis -from graphix.instruction import Instruction, InstructionKind +from graphix.instruction import Instruction, InstructionKind, InstructionVisitor from graphix.measurements import Measurement, PauliMeasurement from graphix.ops import Ops from graphix.pattern import Pattern @@ -33,7 +35,7 @@ from graphix.sim.base_backend import Matrix -@dataclasses.dataclass +@dataclass class TranspileResult: """ The result of a transpilation. @@ -46,7 +48,7 @@ class TranspileResult: classical_outputs: tuple[int, ...] -@dataclasses.dataclass +@dataclass class SimulateResult: """ The result of a simulation. @@ -67,6 +69,15 @@ def _check_target(out: Sequence[int | None], index: int) -> int: return target +@dataclass +class _MapAngleVisitor(InstructionVisitor): + f: Callable[[ParameterizedAngle], ParameterizedAngle] + + @override + def visit_angle(self, angle: ParameterizedAngle) -> ParameterizedAngle: + return self.f(angle) + + class Circuit: """Gate-to-MBQC transpiler. @@ -357,6 +368,36 @@ def m(self, qubit: int, axis: Axis) -> None: def transpile(self) -> TranspileResult: """Transpile the circuit to a pattern. + Returns + ------- + result : :class:`TranspileResult` object + """ + from graphix_jcz_transpiler import transpile_jcz # noqa: PLC0415 + + transpiled_swaps = transpile_swaps(self) + result = transpile_jcz(transpiled_swaps.circuit) + # result = transpiled_swaps.circuit.old_transpile() + index_outputs: list[int | None] = [] + index = 0 + for qubit in transpiled_swaps.qubits: + if qubit is None: + index_outputs.append(None) + else: + index_outputs.append(index) + index += 1 + output_nodes = [] + for qubit in transpiled_swaps.qubits: + if qubit is not None: + index_output = index_outputs[qubit] + assert index_output is not None + output_nodes.append(result.pattern.output_nodes[index_output]) + result.pattern.reorder_output_nodes(output_nodes) + result.pattern = result.pattern.infer_pauli_measurements() + return result + + def old_transpile(self) -> TranspileResult: + """Transpile the circuit to a pattern. + Returns ------- result : :class:`TranspileResult` object @@ -995,18 +1036,17 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None: raise ValueError(f"Unknown instruction: {instr}") return SimulateResult(backend.state, tuple(classical_measures)) - def map_angle(self, f: Callable[[ParameterizedAngle], ParameterizedAngle]) -> Circuit: - """Apply `f` to all angles that occur in the circuit.""" + def visit(self, visitor: InstructionVisitor) -> Circuit: + """Apply `visitor` to all instructions in the circuit.""" result = Circuit(self.width) for instr in self.instruction: - match instr.kind: - case InstructionKind.RZZ | InstructionKind.RX | InstructionKind.RY | InstructionKind.RZ: - new_instr = dataclasses.replace(instr, angle=f(instr.angle)) - result.instruction.append(new_instr) - case _: - result.instruction.append(instr) + result.instruction.append(instr.visit(visitor)) return result + def map_angle(self, f: Callable[[ParameterizedAngle], ParameterizedAngle]) -> Circuit: + """Apply `f` to all angles that occur in the circuit.""" + return self.visit(_MapAngleVisitor(f)) + def is_parameterized(self) -> bool: """ Return `True` if there is at least one measurement angle that is not just an instance of `SupportsFloat`. @@ -1061,3 +1101,62 @@ def _transpile_rzz(instructions: Iterable[Instruction]) -> Iterator[Instruction] yield instruction.CNOT(control=instr.control, target=instr.target) else: yield instr + + +@dataclass(frozen=True) +class TranspileSwapsResult: + """The result returned by :func:`transpile_swaps`.""" + + circuit: Circuit + """Circuit without SWAP gates.""" + + qubits: tuple[int | None, ...] + """ + Tuple which has the same width as the circuit and which for + every qubit of the original circuit provides the index of the + corresponding qubit in the output of the returned + circuit. Measured qubits are mapped to ``None`` in this tuple. + """ + + +class _TranspileSwapVisitor(InstructionVisitor): + qubits: list[int | None] + + def __init__(self, width: int) -> None: + self.qubits = list(range(width)) + + @override + def visit_qubit(self, qubit: int) -> int: + return _check_target(self.qubits, qubit) + + +def transpile_swaps(circuit: Circuit) -> TranspileSwapsResult: + """Return a new circuit equivalent to the original one but without SWAP gates. + + Parameters + ---------- + circuit : Circuit + The original circuit + + Returns + ------- + TranspileSwapsResult + The field ``circuit`` contains an equivalent circuit without + SWAP gates. + The field ``qubits`` contains a tuple which has the same width + as the circuit and which for every qubit of the original + circuit provides the index of the corresponding qubit in the + output of the returned circuit. Measured qubits are mapped to + ``None`` in this tuple. + """ + new_circuit = Circuit(circuit.width) + visitor = _TranspileSwapVisitor(circuit.width) + for instr in circuit.instruction: + if instr.kind == InstructionKind.SWAP: + u, v = instr.targets + visitor.qubits[u], visitor.qubits[v] = visitor.qubits[v], visitor.qubits[u] + else: + new_circuit.add(instr.visit(visitor)) + if instr.kind == InstructionKind.M: + visitor.qubits[instr.target] = None + return TranspileSwapsResult(new_circuit, tuple(visitor.qubits)) diff --git a/pyproject.toml b/pyproject.toml index 674cf5acd..359dbe5b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Topic :: Scientific/Engineering :: Physics", ] @@ -134,16 +135,6 @@ filterwarnings = ["ignore:Couldn't import `kahypar`"] [tool.mypy] # Keep in sync with pyright files = ["*.py", "examples", "graphix", "tests"] -exclude = [ - '^examples/deutsch_jozsa\.py$', - '^examples/ghz_with_tn\.py$', - '^examples/mbqc_vqe\.py$', - '^examples/qaoa\.py$', - '^examples/qft_with_tn\.py$', - '^examples/qnn\.py$', - '^examples/rotation\.py$', - '^examples/tn_simulation\.py$', -] follow_imports = "silent" follow_untyped_imports = true # required for qiskit, requires mypy >=1.14 strict = true @@ -152,21 +143,13 @@ mypy_path = "./stubs" [tool.pyright] # Keep in sync with mypy include = ["*.py", "examples", "graphix", "tests"] -exclude = [ - "examples/deutsch_jozsa.py", - "examples/ghz_with_tn.py", - "examples/mbqc_vqe.py", - "examples/qaoa.py", - "examples/qft_with_tn.py", - "examples/qnn.py", - "examples/rotation.py", - "examples/tn_simulation.py", -] reportUnknownArgumentType = "information" reportUnknownLambdaType = "information" reportUnknownMemberType = "information" reportUnknownParameterType = "information" reportUnknownVariableType = "information" +# reportInvalidTypeVarUse is a warning by default +reportInvalidTypeVarUse = "error" extraPaths = ["./stubs"] [tool.coverage.report] diff --git a/requirements.txt b/requirements.txt index 99c592fdd..a028709d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ quimb scipy sympy typing_extensions +graphix_jcz_transpiler@git+https://github.com/thierry-martinez/graphix-jcz-transpiler.git@issues/2 diff --git a/stubs/cotengra/__init__.pyi b/stubs/cotengra/__init__.pyi new file mode 100644 index 000000000..f1795d2d2 --- /dev/null +++ b/stubs/cotengra/__init__.pyi @@ -0,0 +1,6 @@ +from cotengra.oe import PathOptimizer + +class HyperOptimizer(PathOptimizer): + def __init__( + self, minimize: str = "flops", reconf_opts: dict[None, None] | None = None, progbar: bool = False + ) -> None: ... diff --git a/stubs/cotengra/oe.pyi b/stubs/cotengra/oe.pyi new file mode 100644 index 000000000..a429ea966 --- /dev/null +++ b/stubs/cotengra/oe.pyi @@ -0,0 +1 @@ +class PathOptimizer: ... diff --git a/stubs/quimb/tensor/__init__.pyi b/stubs/quimb/tensor/__init__.pyi index 8366b5182..2fe95b21d 100644 --- a/stubs/quimb/tensor/__init__.pyi +++ b/stubs/quimb/tensor/__init__.pyi @@ -4,6 +4,7 @@ from typing import Literal import numpy as np import numpy.typing as npt from cotengra.oe import PathOptimizer +from matplotlib.axes import Axes from quimb import oset from typing_extensions import Self @@ -17,6 +18,8 @@ class Tensor: tags: Sequence[str] | None = None, left_inds: Sequence[str] | None = None, ) -> None: ... + def add_tag(self, tag: str) -> None: ... + def draw(self, legend: bool | str = "auto", title: str | None = None, ax: Axes | None = None) -> None: ... def reindex(self, retag_map: Mapping[str, str], inplace: bool = False) -> Tensor: ... def retag(self, retag_map: Mapping[str, str], inplace: bool = False) -> Tensor: ... def split( @@ -63,6 +66,18 @@ class TensorNetwork: optimize: str | PathOptimizer | None = None, ) -> float: ... def copy(self, virtual: bool = False, deep: bool = False) -> TensorNetwork: ... + def draw( + self, + color: list[str] | None = None, + *, + show_inds: bool | None = None, + show_tags: bool | None = None, + legend: bool | str = "auto", + iterations: int | Literal["auto"] = "auto", + k: float | None = None, + title: str | None = None, + ax: Axes | None = None, + ) -> None: ... def full_simplify( self, seq: str = "ADCR", @@ -83,3 +98,4 @@ class TensorNetwork: def tensors(self) -> tuple[Tensor, ...]: ... def rand_uuid(base: str = "") -> str: ... +def rand_tensor(shape: Sequence[int], inds: Sequence[str] | str) -> Tensor: ... diff --git a/tests/test_noisy_density_matrix.py b/tests/test_noisy_density_matrix.py index dc207931b..8e57b0f16 100644 --- a/tests/test_noisy_density_matrix.py +++ b/tests/test_noisy_density_matrix.py @@ -37,7 +37,9 @@ def hpat() -> Pattern: def rzpat(alpha: Angle) -> Pattern: circ = Circuit(1) circ.rz(0, alpha) - return circ.transpile().pattern + pattern = circ.transpile().pattern + pattern.standardize() + return pattern class TestNoisyDensityMatrixBackend: diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 44375afd0..816dd01b5 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -167,7 +167,7 @@ def test_empty_output_nodes(self, backend_type: _BackendLiteral) -> None: pattern.add(M(0, Measurement.Y)) def simulate_and_measure() -> int: - sim: PatternSimulator[Statevec | DensityMatrix | MBQCTensorNet] = PatternSimulator(pattern, backend_type) + sim = PatternSimulator(pattern, backend_type) sim.run() state = sim.backend.state if isinstance(state, Statevec): diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index 23ef5ef74..37ca8f612 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -12,7 +12,7 @@ from graphix.instruction import InstructionKind from graphix.random_objects import rand_circuit, rand_gate, rand_state_vector from graphix.states import BasicStates -from graphix.transpiler import Circuit +from graphix.transpiler import Circuit, transpile_swaps from tests.test_branch_selector import CheckedBranchSelector if TYPE_CHECKING: @@ -140,6 +140,7 @@ def test_ccx(self, fx_bg: PCG64, jumps: int) -> None: nqubits = 4 depth = 6 circuit = rand_circuit(nqubits, depth, rng, use_ccx=True) + circuit = transpile_swaps(circuit).circuit pattern = circuit.transpile().pattern pattern.minimize_space() state = circuit.simulate_statevector(rng=rng).statevec @@ -257,3 +258,23 @@ def test_simple(self) -> None: state = circuit.simulate_statevector(input_state=input_state).statevec state_mbqc = pattern.simulate_pattern(input_state=input_state, rng=rng) assert state_mbqc.isclose(state) + + +@pytest.mark.parametrize("jumps", range(1, 11)) +def test_transple_swaps(fx_bg: PCG64, jumps: int) -> None: + rng = Generator(fx_bg.jumped(jumps)) + nqubits = 4 + depth = 6 + circuit = rand_circuit(nqubits, depth, rng, use_ccx=True) + assert any(instr.kind == InstructionKind.SWAP for instr in circuit.instruction) + transpiled_swaps = transpile_swaps(circuit) + circuit2 = transpiled_swaps.circuit + assert not any(instr.kind == InstructionKind.SWAP for instr in circuit2.instruction) + state = circuit.simulate_statevector(rng=rng).statevec + state2 = circuit2.simulate_statevector(rng=rng).statevec + qubits: list[int] = [] + for qubit in transpiled_swaps.qubits: + assert qubit is not None + qubits.append(qubit) + state2.psi = np.transpose(state2.psi, qubits) + assert state.isclose(state2)