diff --git a/CHANGELOG.md b/CHANGELOG.md index 092bf1a67..12fe5c58e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - #507: Static method `PauliString.from_measured_node` is subsumed by the function `extraction_ps_from_corrected_node`. +- #512: Method `Circuit.simulate_statevector` accepts a `backend: DenseStateBackend[_DenseStateT] | Literal["statevector", "densitymatrix"]` parameter. + ## [0.3.5] - 2026-03-26 ### Added diff --git a/graphix/transpiler.py b/graphix/transpiler.py index c009b62b0..7859d3f0f 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, Generic, SupportsFloat, TypeVar, overload # assert_never introduced in Python 3.11 # override introduced in Python 3.12 @@ -21,10 +21,13 @@ from graphix.measurements import Measurement, PauliMeasurement from graphix.ops import Ops from graphix.pattern import Pattern -from graphix.sim.statevec import Statevec, StatevectorBackend +from graphix.sim.base_backend import DenseStateBackend +from graphix.sim.density_matrix import DensityMatrixBackend +from graphix.sim.statevec import StatevectorBackend if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence + from typing import Literal from numpy.random import Generator @@ -33,10 +36,18 @@ from graphix.instruction import InstructionType, InstructionTypeWithoutRZZ from graphix.parameter import ExpressionOrFloat, Parameter from graphix.sim import Data - from graphix.sim.base_backend import Matrix + from graphix.sim.base_backend import DenseState, Matrix + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import Statevec + _BuiltinDenseStateBackend = DensityMatrixBackend | StatevectorBackend + _DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] -@dataclass +_DenseStateT_co = TypeVar("_DenseStateT_co", bound="DenseState", covariant=True) +_DenseStateT = TypeVar("_DenseStateT", bound="DenseState") + + +@dataclass(frozen=True) class TranspileResult: """ The result of a transpilation. @@ -49,16 +60,18 @@ class TranspileResult: classical_outputs: tuple[int, ...] -@dataclass -class SimulateResult: +@dataclass(frozen=True) +class SimulateResult(Generic[_DenseStateT]): """ - The result of a simulation. + Result of a circuit simulation. - statevec : :class:`graphix.sim.statevec.Statevec` object - classical_measures : tuple[int,...], classical measures + statevec : _DenseStateT + State representation of the simulation output. + classical_measures : tuple[int,...] + Results of classical measurements. """ - statevec: Statevec + statevec: _DenseStateT # mypy rejects covariant types as dataclass parameters as of Python 3.13 classical_measures: tuple[int, ...] @@ -942,21 +955,60 @@ def _ccx_command( ) return ancilla[17], ancilla[15], ancilla[13], seq + @overload + def simulate_statevector( + self, + backend: StatevectorBackend | Literal["statevector"] = ..., + input_state: Data | None = None, + branch_selector: BranchSelector | None = None, + rng: Generator | None = None, + *, + stacklevel: int = 1, + ) -> SimulateResult[Statevec]: ... + + @overload + def simulate_statevector( + self, + backend: DensityMatrixBackend | Literal["densitymatrix"], + input_state: Data | None = None, + branch_selector: BranchSelector | None = None, + rng: Generator | None = None, + *, + stacklevel: int = 1, + ) -> SimulateResult[DensityMatrix]: ... + + @overload def simulate_statevector( self, + backend: DenseStateBackend[_DenseStateT], input_state: Data | None = None, branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, stacklevel: int = 1, - ) -> SimulateResult: - """Run statevector simulation of the gate sequence. + ) -> SimulateResult[_DenseStateT]: ... + + def simulate_statevector( + self, + backend: DenseStateBackend[_DenseStateT] | _DenseStateBackendLiteral = "statevector", + input_state: Data | None = None, + branch_selector: BranchSelector | None = None, + rng: Generator | None = None, + *, + stacklevel: int = 1, + ) -> SimulateResult[_DenseStateT] | SimulateResult[_DenseStateT | Statevec | DensityMatrix]: + # `SimulateResult` is not covariant in `_DenseStateT` so `SimulateResult[_DenseStateT]` is not a subtype of `SimulateResult[_DenseStateT | Statevec | DensityMatrix]` + r"""Simulate the gate sequence with a backend and input state of choice. + + By default, this method uses the statevector backend and initializes the register to :math:`|+\rangle^{\otimes n}`. Parameters ---------- input_state : Data + backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT]`, 'statevector', or 'densitymatrix' + Simulator backend to use. Optional, defaults to "statevector". branch_selector: :class:`graphix.branch_selector.BranchSelector` - branch selector for measures (default: :class:`RandomBranchSelector`). + branch selector for measures (default: :class:`RandomBranchSelector`). It cannot be specified if ``backend`` is already instantiated. rng: Generator, optional Random-number generator for measurements. This generator is used only in case of random branch selection @@ -967,14 +1019,12 @@ def simulate_statevector( result : :class:`SimulateResult` output state of the statevector simulation and results of classical measures. """ - if branch_selector is None: - branch_selector = RandomBranchSelector() + _backend = _initialize_backend(backend, branch_selector) - backend = StatevectorBackend(branch_selector=branch_selector) if input_state is None: - backend.add_nodes(range(self.width)) + _backend.add_nodes(range(self.width)) else: - backend.add_nodes(range(self.width), input_state) + _backend.add_nodes(range(self.width), input_state) classical_measures = [] @@ -982,22 +1032,20 @@ def simulate_statevector( instr = self.instruction[i] def evolve_single(op: Matrix, target: int) -> None: - backend.state.evolve_single(op, backend.node_index.index(target)) + _backend.state.evolve_single(op, _backend.node_index.index(target)) def evolve(op: Matrix, qargs: Iterable[int]) -> None: - backend.state.evolve(op, [backend.node_index.index(qarg) for qarg in qargs]) + _backend.state.evolve(op, [_backend.node_index.index(qarg) for qarg in qargs]) match instr.kind: case instruction.InstructionKind.CNOT: - backend.state.cnot( - (backend.node_index.index(instr.control), backend.node_index.index(instr.target)) - ) + evolve(Ops.CNOT, [instr.control, instr.target]) case instruction.InstructionKind.SWAP: u, v = instr.targets - backend.state.swap((backend.node_index.index(u), backend.node_index.index(v))) + _backend.state.swap((_backend.node_index.index(u), _backend.node_index.index(v))) case instruction.InstructionKind.CZ: u, v = instr.targets - backend.state.entangle((backend.node_index.index(u), backend.node_index.index(v))) + _backend.state.entangle((_backend.node_index.index(u), _backend.node_index.index(v))) case instruction.InstructionKind.I: pass case instruction.InstructionKind.S: @@ -1021,13 +1069,13 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None: case instruction.InstructionKind.CCX: evolve(Ops.CCX, [instr.controls[0], instr.controls[1], instr.target]) case instruction.InstructionKind.M: - result = backend.measure( + result = _backend.measure( instr.target, PauliMeasurement(instr.axis), rng=rng, stacklevel=stacklevel + 1 ) classical_measures.append(result) case _: raise ValueError(f"Unknown instruction: {instr}") - return SimulateResult(backend.state, tuple(classical_measures)) + return SimulateResult(_backend.state, tuple(classical_measures)) def visit(self, visitor: InstructionVisitor) -> Circuit: """Apply `visitor` to all instructions in the circuit.""" @@ -1153,3 +1201,59 @@ def transpile_swaps(circuit: Circuit) -> TranspileSwapsResult: if instr.kind == InstructionKind.M: visitor.qubits[instr.target] = None return TranspileSwapsResult(new_circuit, tuple(visitor.qubits)) + + +@overload +def _initialize_backend( + backend: StatevectorBackend | Literal["statevector"], + branch_selector: BranchSelector | None, +) -> StatevectorBackend: ... + + +@overload +def _initialize_backend( + backend: DensityMatrixBackend | Literal["densitymatrix"], + branch_selector: BranchSelector | None, +) -> DensityMatrixBackend: ... + + +@overload +def _initialize_backend( + backend: DenseStateBackend[_DenseStateT_co], + branch_selector: BranchSelector | None, +) -> DenseStateBackend[_DenseStateT_co]: ... + + +def _initialize_backend( + backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral, + branch_selector: BranchSelector | None, +) -> _BuiltinDenseStateBackend | DenseStateBackend[_DenseStateT_co]: + """Initialize backend for circuit simulation. + + Parameters + ---------- + backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT_co]`, 'statevector', or 'densitymatrix' + Simulation backend + branch_selector: :class:`BranchSelector` + Branch selector used for measurements. Can only be specified if ``backend`` is not an already instantiated :class:`Backend` object. If ``None``, it defaults to :class:`RandomBranchSelector`. + + Returns + ------- + :class:`DenseStateBackend` + matching the appropriate backend + """ + if isinstance(backend, DenseStateBackend): + if branch_selector is not None: + raise ValueError("`branch_selector` cannot be specified if `backend` is already instantiated.") + return backend + + if branch_selector is None: + branch_selector = RandomBranchSelector() + + match backend: + case "statevector": + return StatevectorBackend(branch_selector=branch_selector) + case "densitymatrix": + return DensityMatrixBackend(branch_selector=branch_selector) + case _: + raise ValueError(f"Unknown backend {backend}.") diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index d73f9133d..f848a5ed2 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -11,18 +11,22 @@ from graphix.fundamentals import ANGLE_PI, Axis, Sign from graphix.instruction import I, InstructionKind from graphix.random_objects import rand_circuit, rand_gate, rand_state_vector +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec from graphix.states import BasicStates from graphix.transpiler import Circuit, transpile_swaps from tests.test_branch_selector import CheckedBranchSelector if TYPE_CHECKING: from collections.abc import Callable - from typing import TypeAlias + from typing import Literal, TypeAlias from graphix.instruction import InstructionType from graphix.measurements import Outcome InstructionTestCase: TypeAlias = Callable[[Generator], InstructionType] + _DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] + INSTRUCTION_TEST_CASES: list[InstructionTestCase] = [ lambda _rng: instruction.CCX(0, (1, 2)), @@ -156,20 +160,30 @@ def test_transpiled(self, fx_rng: Generator) -> None: state_mbqc = pattern.simulate_pattern(rng=fx_rng) assert state_mbqc.isclose(state) + @pytest.mark.parametrize("backend", ["statevector", "densitymatrix"]) @pytest.mark.parametrize("jumps", range(1, 11)) @pytest.mark.parametrize("axis", [Axis.X, Axis.Y, Axis.Z]) @pytest.mark.parametrize("outcome", [0, 1]) - def test_measure(self, fx_bg: PCG64, jumps: int, axis: Axis, outcome: Outcome) -> None: + def test_measure( + self, fx_bg: PCG64, jumps: int, axis: Axis, outcome: Outcome, backend: _DenseStateBackendLiteral + ) -> None: rng = Generator(fx_bg.jumped(jumps)) circuit = Circuit(2) circuit.cnot(0, 1) circuit.m(0, axis) input_state = rand_state_vector(2, rng=rng) branch_selector = ConstBranchSelector(outcome) - state = circuit.simulate_statevector(rng=rng, input_state=input_state, branch_selector=branch_selector).statevec + state = circuit.simulate_statevector( + rng=rng, input_state=input_state, branch_selector=branch_selector, backend=backend + ).statevec pattern = circuit.transpile().pattern - state_mbqc = pattern.simulate_pattern(rng=rng, input_state=input_state, branch_selector=branch_selector) - assert state_mbqc.isclose(state) + state_mbqc = pattern.simulate_pattern( + rng=rng, input_state=input_state, branch_selector=branch_selector, backend=backend + ) + if isinstance(state_mbqc, Statevec) and isinstance(state, Statevec): + assert state_mbqc.isclose(state) + elif isinstance(state_mbqc, DensityMatrix) and isinstance(state, DensityMatrix): + assert np.allclose(state_mbqc.rho, state.rho) @pytest.mark.parametrize("input_axis", [Axis.X, Axis.Y, Axis.Z]) @pytest.mark.parametrize("input_sign", [Sign.PLUS, Sign.MINUS]) @@ -261,6 +275,18 @@ def test_simple(self) -> None: state_mbqc = pattern.simulate_pattern(input_state=input_state, rng=rng) assert state_mbqc.isclose(state) + @pytest.mark.parametrize("jumps", range(1, 3)) + def test_dm_backend(self, fx_bg: PCG64, jumps: int) -> None: + nqubits = 2 + rng = Generator(fx_bg.jumped(jumps)) + circuit = rand_circuit(nqubits, 3, rng) + pattern = circuit.transpile().pattern + pattern.minimize_space() + input_state = rand_state_vector(nqubits, rng=rng) + state = circuit.simulate_statevector(input_state=input_state, backend="densitymatrix").statevec + state_mbqc = pattern.simulate_pattern(input_state=input_state, backend="densitymatrix", rng=rng) + assert np.allclose(state_mbqc.rho, state.rho) + @pytest.mark.parametrize("jumps", range(1, 11)) def test_transpile_swaps(fx_bg: PCG64, jumps: int) -> None: