From e00b6b0f25112bbdfb57e0357789c0bad84e6d89 Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 14:19:23 +0200 Subject: [PATCH 01/10] Add parameter to Circuit.simulate_statevector --- graphix/transpiler.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index c009b62b0..3f447a7b6 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -21,7 +21,7 @@ from graphix.measurements import Measurement, PauliMeasurement from graphix.ops import Ops from graphix.pattern import Pattern -from graphix.sim.statevec import Statevec, StatevectorBackend +from graphix.sim.statevec import StatevectorBackend if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence @@ -33,7 +33,8 @@ from graphix.instruction import InstructionType, InstructionTypeWithoutRZZ from graphix.parameter import ExpressionOrFloat, Parameter from graphix.sim import Data - from graphix.sim.base_backend import Matrix + from graphix.sim.base_backend import DenseStateBackend, Matrix + from graphix.sim.statevec import Statevec @dataclass @@ -945,6 +946,7 @@ def _ccx_command( def simulate_statevector( self, input_state: Data | None = None, + backend: DenseStateBackend[Statevec] | None = None, branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, @@ -955,8 +957,10 @@ def simulate_statevector( Parameters ---------- input_state : Data + backend : :class:`graphix.sim.base_backend.DenseStateBackend[Statevec]` | None, optional + The simulator backend to use. If ``None`` (default), "statevector" backend is used. branch_selector: :class:`graphix.branch_selector.BranchSelector` - branch selector for measures (default: :class:`RandomBranchSelector`). + branch selector for measures (default: :class:`RandomBranchSelector`). It cannot be specified if ``backend`` is already instantiated. rng: Generator, optional Random-number generator for measurements. This generator is used only in case of random branch selection @@ -967,10 +971,13 @@ def simulate_statevector( result : :class:`SimulateResult` output state of the statevector simulation and results of classical measures. """ - if branch_selector is None: - branch_selector = RandomBranchSelector() - - backend = StatevectorBackend(branch_selector=branch_selector) + if backend is not None: + if branch_selector is not None: + raise ValueError("`branch_selector` cannot be specified if `backend` is already instantiated.") + else: + if branch_selector is None: + branch_selector = RandomBranchSelector() + backend = StatevectorBackend(branch_selector=branch_selector) if input_state is None: backend.add_nodes(range(self.width)) else: @@ -992,6 +999,7 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None: backend.state.cnot( (backend.node_index.index(instr.control), backend.node_index.index(instr.target)) ) + # evolve(Ops.CNOT, [instr.control, instr.target]) case instruction.InstructionKind.SWAP: u, v = instr.targets backend.state.swap((backend.node_index.index(u), backend.node_index.index(v))) From a4541ba65d906f018d05e9b35ad67d98a68a2c9a Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 15:23:56 +0200 Subject: [PATCH 02/10] Add parameter to Circuit.simulate_statevector --- graphix/transpiler.py | 133 +++++++++++++++++++++++++++++++++--------- 1 file changed, 105 insertions(+), 28 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index 3f447a7b6..b00baa033 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, Generic, Literal, SupportsFloat, TypeVar, cast, overload # assert_never introduced in Python 3.11 # override introduced in Python 3.12 @@ -21,7 +21,9 @@ from graphix.measurements import Measurement, PauliMeasurement from graphix.ops import Ops from graphix.pattern import Pattern -from graphix.sim.statevec import StatevectorBackend +from graphix.sim.base_backend import DenseStateBackend +from graphix.sim.density_matrix import DensityMatrix, DensityMatrixBackend +from graphix.sim.statevec import Statevec, StatevectorBackend if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence @@ -33,11 +35,16 @@ from graphix.instruction import InstructionType, InstructionTypeWithoutRZZ from graphix.parameter import ExpressionOrFloat, Parameter from graphix.sim import Data - from graphix.sim.base_backend import DenseStateBackend, Matrix - from graphix.sim.statevec import Statevec + from graphix.sim.base_backend import DenseState, Matrix -@dataclass +_BuiltinDenseStateBackend = DensityMatrixBackend | StatevectorBackend +_DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] + +_DenseStateT_co = TypeVar("_DenseStateT_co", bound="DenseState", covariant=True) + + +@dataclass(frozen=True) class TranspileResult: """ The result of a transpilation. @@ -51,7 +58,7 @@ class TranspileResult: @dataclass -class SimulateResult: +class SimulateResult(Generic[_DenseStateT_co]): """ The result of a simulation. @@ -59,7 +66,7 @@ class SimulateResult: classical_measures : tuple[int,...], classical measures """ - statevec: Statevec + statevec: _DenseStateT_co classical_measures: tuple[int, ...] @@ -943,15 +950,48 @@ def _ccx_command( ) return ancilla[17], ancilla[15], ancilla[13], seq + @overload + def simulate_statevector( + self, + input_state: Data | None = None, + backend: StatevectorBackend | Literal["statevector"] = ..., + branch_selector: BranchSelector | None = None, + rng: Generator | None = None, + *, + stacklevel: int = 1, + ) -> SimulateResult[Statevec]: ... + + @overload + def simulate_statevector( + self, + input_state: Data | None = None, + backend: DensityMatrixBackend | Literal["densitymatrix"] = ..., + branch_selector: BranchSelector | None = None, + rng: Generator | None = None, + *, + stacklevel: int = 1, + ) -> SimulateResult[DensityMatrix]: ... + + @overload def simulate_statevector( self, input_state: Data | None = None, - backend: DenseStateBackend[Statevec] | None = None, + backend: DenseStateBackend[_DenseStateT_co] = ..., branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, stacklevel: int = 1, - ) -> SimulateResult: + ) -> SimulateResult[_DenseStateT_co]: ... + + def simulate_statevector( + self, + input_state: Data | None = None, + backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral = "statevector", + branch_selector: BranchSelector | None = None, + rng: Generator | None = None, + *, + stacklevel: int = 1, + ) -> SimulateResult[_DenseStateT_co]: """Run statevector simulation of the gate sequence. Parameters @@ -971,17 +1011,12 @@ def simulate_statevector( result : :class:`SimulateResult` output state of the statevector simulation and results of classical measures. """ - if backend is not None: - if branch_selector is not None: - raise ValueError("`branch_selector` cannot be specified if `backend` is already instantiated.") - else: - if branch_selector is None: - branch_selector = RandomBranchSelector() - backend = StatevectorBackend(branch_selector=branch_selector) + backend_ = _initialize_backend(backend, branch_selector) + if input_state is None: - backend.add_nodes(range(self.width)) + backend_.add_nodes(range(self.width)) else: - backend.add_nodes(range(self.width), input_state) + backend_.add_nodes(range(self.width), input_state) classical_measures = [] @@ -989,23 +1024,23 @@ def simulate_statevector( instr = self.instruction[i] def evolve_single(op: Matrix, target: int) -> None: - backend.state.evolve_single(op, backend.node_index.index(target)) + backend_.state.evolve_single(op, backend_.node_index.index(target)) def evolve(op: Matrix, qargs: Iterable[int]) -> None: - backend.state.evolve(op, [backend.node_index.index(qarg) for qarg in qargs]) + backend_.state.evolve(op, [backend_.node_index.index(qarg) for qarg in qargs]) match instr.kind: case instruction.InstructionKind.CNOT: - backend.state.cnot( - (backend.node_index.index(instr.control), backend.node_index.index(instr.target)) - ) - # evolve(Ops.CNOT, [instr.control, instr.target]) + evolve(Ops.CNOT, [instr.control, instr.target]) + # backend.state.cnot( + # (backend.node_index.index(instr.control), backend.node_index.index(instr.target)) + # ) case instruction.InstructionKind.SWAP: u, v = instr.targets - backend.state.swap((backend.node_index.index(u), backend.node_index.index(v))) + backend_.state.swap((backend_.node_index.index(u), backend_.node_index.index(v))) case instruction.InstructionKind.CZ: u, v = instr.targets - backend.state.entangle((backend.node_index.index(u), backend.node_index.index(v))) + backend_.state.entangle((backend_.node_index.index(u), backend_.node_index.index(v))) case instruction.InstructionKind.I: pass case instruction.InstructionKind.S: @@ -1029,13 +1064,13 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None: case instruction.InstructionKind.CCX: evolve(Ops.CCX, [instr.controls[0], instr.controls[1], instr.target]) case instruction.InstructionKind.M: - result = backend.measure( + result = backend_.measure( instr.target, PauliMeasurement(instr.axis), rng=rng, stacklevel=stacklevel + 1 ) classical_measures.append(result) case _: raise ValueError(f"Unknown instruction: {instr}") - return SimulateResult(backend.state, tuple(classical_measures)) + return SimulateResult(cast("_DenseStateT_co", backend_.state), tuple(classical_measures)) def visit(self, visitor: InstructionVisitor) -> Circuit: """Apply `visitor` to all instructions in the circuit.""" @@ -1161,3 +1196,45 @@ def transpile_swaps(circuit: Circuit) -> TranspileSwapsResult: if instr.kind == InstructionKind.M: visitor.qubits[instr.target] = None return TranspileSwapsResult(new_circuit, tuple(visitor.qubits)) + + +@overload +def _initialize_backend( + backend: StatevectorBackend | Literal["statevector"], + branch_selector: BranchSelector | None, +) -> StatevectorBackend: ... + + +@overload +def _initialize_backend( + backend: DensityMatrixBackend | Literal["densitymatrix"], + branch_selector: BranchSelector | None, +) -> DensityMatrixBackend: ... + + +@overload +def _initialize_backend( + backend: DenseStateBackend[_DenseStateT_co], + branch_selector: BranchSelector | None, +) -> DenseStateBackend[_DenseStateT_co]: ... + + +def _initialize_backend( + backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral, + branch_selector: BranchSelector | None, +) -> _BuiltinDenseStateBackend | DenseStateBackend[_DenseStateT_co]: + if isinstance(backend, DenseStateBackend): + if branch_selector is not None: + raise ValueError("`branch_selector` cannot be specified if `backend` is already instantiated.") + return backend + + if branch_selector is None: + branch_selector = RandomBranchSelector() + + match backend: + case "statevector": + return StatevectorBackend(branch_selector=branch_selector) + case "densitymatrix": + return DensityMatrixBackend(branch_selector=branch_selector) + case _: + raise ValueError(f"Unknown backend {backend}.") From 250a2a413525e5d992596b9d69630cbedf0965ac Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 15:40:45 +0200 Subject: [PATCH 03/10] Fix pyright error --- graphix/transpiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index b00baa033..d883fcf14 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -57,7 +57,7 @@ class TranspileResult: classical_outputs: tuple[int, ...] -@dataclass +@dataclass(frozen=True) class SimulateResult(Generic[_DenseStateT_co]): """ The result of a simulation. @@ -953,8 +953,8 @@ def _ccx_command( @overload def simulate_statevector( self, - input_state: Data | None = None, backend: StatevectorBackend | Literal["statevector"] = ..., + input_state: Data | None = None, branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, @@ -964,8 +964,8 @@ def simulate_statevector( @overload def simulate_statevector( self, + backend: DensityMatrixBackend | Literal["densitymatrix"], input_state: Data | None = None, - backend: DensityMatrixBackend | Literal["densitymatrix"] = ..., branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, @@ -975,8 +975,8 @@ def simulate_statevector( @overload def simulate_statevector( self, + backend: DenseStateBackend[_DenseStateT_co], input_state: Data | None = None, - backend: DenseStateBackend[_DenseStateT_co] = ..., branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, @@ -985,8 +985,8 @@ def simulate_statevector( def simulate_statevector( self, - input_state: Data | None = None, backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral = "statevector", + input_state: Data | None = None, branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, From 1fba9aaeb35db2384cd487c84c2fccc040e57f82 Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 16:17:28 +0200 Subject: [PATCH 04/10] Add tests --- tests/test_transpiler.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index d73f9133d..f848a5ed2 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -11,18 +11,22 @@ from graphix.fundamentals import ANGLE_PI, Axis, Sign from graphix.instruction import I, InstructionKind from graphix.random_objects import rand_circuit, rand_gate, rand_state_vector +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec from graphix.states import BasicStates from graphix.transpiler import Circuit, transpile_swaps from tests.test_branch_selector import CheckedBranchSelector if TYPE_CHECKING: from collections.abc import Callable - from typing import TypeAlias + from typing import Literal, TypeAlias from graphix.instruction import InstructionType from graphix.measurements import Outcome InstructionTestCase: TypeAlias = Callable[[Generator], InstructionType] + _DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] + INSTRUCTION_TEST_CASES: list[InstructionTestCase] = [ lambda _rng: instruction.CCX(0, (1, 2)), @@ -156,20 +160,30 @@ def test_transpiled(self, fx_rng: Generator) -> None: state_mbqc = pattern.simulate_pattern(rng=fx_rng) assert state_mbqc.isclose(state) + @pytest.mark.parametrize("backend", ["statevector", "densitymatrix"]) @pytest.mark.parametrize("jumps", range(1, 11)) @pytest.mark.parametrize("axis", [Axis.X, Axis.Y, Axis.Z]) @pytest.mark.parametrize("outcome", [0, 1]) - def test_measure(self, fx_bg: PCG64, jumps: int, axis: Axis, outcome: Outcome) -> None: + def test_measure( + self, fx_bg: PCG64, jumps: int, axis: Axis, outcome: Outcome, backend: _DenseStateBackendLiteral + ) -> None: rng = Generator(fx_bg.jumped(jumps)) circuit = Circuit(2) circuit.cnot(0, 1) circuit.m(0, axis) input_state = rand_state_vector(2, rng=rng) branch_selector = ConstBranchSelector(outcome) - state = circuit.simulate_statevector(rng=rng, input_state=input_state, branch_selector=branch_selector).statevec + state = circuit.simulate_statevector( + rng=rng, input_state=input_state, branch_selector=branch_selector, backend=backend + ).statevec pattern = circuit.transpile().pattern - state_mbqc = pattern.simulate_pattern(rng=rng, input_state=input_state, branch_selector=branch_selector) - assert state_mbqc.isclose(state) + state_mbqc = pattern.simulate_pattern( + rng=rng, input_state=input_state, branch_selector=branch_selector, backend=backend + ) + if isinstance(state_mbqc, Statevec) and isinstance(state, Statevec): + assert state_mbqc.isclose(state) + elif isinstance(state_mbqc, DensityMatrix) and isinstance(state, DensityMatrix): + assert np.allclose(state_mbqc.rho, state.rho) @pytest.mark.parametrize("input_axis", [Axis.X, Axis.Y, Axis.Z]) @pytest.mark.parametrize("input_sign", [Sign.PLUS, Sign.MINUS]) @@ -261,6 +275,18 @@ def test_simple(self) -> None: state_mbqc = pattern.simulate_pattern(input_state=input_state, rng=rng) assert state_mbqc.isclose(state) + @pytest.mark.parametrize("jumps", range(1, 3)) + def test_dm_backend(self, fx_bg: PCG64, jumps: int) -> None: + nqubits = 2 + rng = Generator(fx_bg.jumped(jumps)) + circuit = rand_circuit(nqubits, 3, rng) + pattern = circuit.transpile().pattern + pattern.minimize_space() + input_state = rand_state_vector(nqubits, rng=rng) + state = circuit.simulate_statevector(input_state=input_state, backend="densitymatrix").statevec + state_mbqc = pattern.simulate_pattern(input_state=input_state, backend="densitymatrix", rng=rng) + assert np.allclose(state_mbqc.rho, state.rho) + @pytest.mark.parametrize("jumps", range(1, 11)) def test_transpile_swaps(fx_bg: PCG64, jumps: int) -> None: From 4b390f5823175b0525b2d82b08f3c2c3657d6fb5 Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 16:30:15 +0200 Subject: [PATCH 05/10] Up docs --- graphix/transpiler.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index d883fcf14..8865e7dd5 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -997,8 +997,8 @@ def simulate_statevector( Parameters ---------- input_state : Data - backend : :class:`graphix.sim.base_backend.DenseStateBackend[Statevec]` | None, optional - The simulator backend to use. If ``None`` (default), "statevector" backend is used. + backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT_co]`, 'statevector', or 'densitymatrix' + Simulator backend to use. Optional, defaults to "statevector". branch_selector: :class:`graphix.branch_selector.BranchSelector` branch selector for measures (default: :class:`RandomBranchSelector`). It cannot be specified if ``backend`` is already instantiated. rng: Generator, optional @@ -1223,6 +1223,20 @@ def _initialize_backend( backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral, branch_selector: BranchSelector | None, ) -> _BuiltinDenseStateBackend | DenseStateBackend[_DenseStateT_co]: + """Initialize backend for circuit simulation. + + Parameters + ---------- + backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT_co]`, 'statevector', or 'densitymatrix' + Simulation backend + branch_selector: :class:`BranchSelector` + Branch selector used for measurements. Can only be specified if ``backend`` is not an already instantiated :class:`Backend` object. If ``None``, it defaults to :class:`RandomBranchSelector`. + + Returns + ------- + :class:`DenseStateBackend` + matching the appropriate backend + """ if isinstance(backend, DenseStateBackend): if branch_selector is not None: raise ValueError("`branch_selector` cannot be specified if `backend` is already instantiated.") From e69ec49a73c4843f356768bd0ff2bcae39ec75e5 Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 16:46:35 +0200 Subject: [PATCH 06/10] Move imports to typechecking block --- graphix/transpiler.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index 8865e7dd5..b896c0045 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Literal, SupportsFloat, TypeVar, cast, overload +from typing import TYPE_CHECKING, Generic, SupportsFloat, TypeVar, cast, overload # assert_never introduced in Python 3.11 # override introduced in Python 3.12 @@ -22,11 +22,12 @@ from graphix.ops import Ops from graphix.pattern import Pattern from graphix.sim.base_backend import DenseStateBackend -from graphix.sim.density_matrix import DensityMatrix, DensityMatrixBackend -from graphix.sim.statevec import Statevec, StatevectorBackend +from graphix.sim.density_matrix import DensityMatrixBackend +from graphix.sim.statevec import StatevectorBackend if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence + from typing import Literal from numpy.random import Generator @@ -36,10 +37,11 @@ from graphix.parameter import ExpressionOrFloat, Parameter from graphix.sim import Data from graphix.sim.base_backend import DenseState, Matrix + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import Statevec - -_BuiltinDenseStateBackend = DensityMatrixBackend | StatevectorBackend -_DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] + _BuiltinDenseStateBackend = DensityMatrixBackend | StatevectorBackend + _DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] _DenseStateT_co = TypeVar("_DenseStateT_co", bound="DenseState", covariant=True) From cfcfc968073616604de887158a1565c9bd7778fc Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 20 May 2026 16:49:24 +0200 Subject: [PATCH 07/10] Remove comment --- graphix/transpiler.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index b896c0045..b403375c6 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -1013,12 +1013,12 @@ def simulate_statevector( result : :class:`SimulateResult` output state of the statevector simulation and results of classical measures. """ - backend_ = _initialize_backend(backend, branch_selector) + _backend = _initialize_backend(backend, branch_selector) if input_state is None: - backend_.add_nodes(range(self.width)) + _backend.add_nodes(range(self.width)) else: - backend_.add_nodes(range(self.width), input_state) + _backend.add_nodes(range(self.width), input_state) classical_measures = [] @@ -1026,23 +1026,20 @@ def simulate_statevector( instr = self.instruction[i] def evolve_single(op: Matrix, target: int) -> None: - backend_.state.evolve_single(op, backend_.node_index.index(target)) + _backend.state.evolve_single(op, _backend.node_index.index(target)) def evolve(op: Matrix, qargs: Iterable[int]) -> None: - backend_.state.evolve(op, [backend_.node_index.index(qarg) for qarg in qargs]) + _backend.state.evolve(op, [_backend.node_index.index(qarg) for qarg in qargs]) match instr.kind: case instruction.InstructionKind.CNOT: evolve(Ops.CNOT, [instr.control, instr.target]) - # backend.state.cnot( - # (backend.node_index.index(instr.control), backend.node_index.index(instr.target)) - # ) case instruction.InstructionKind.SWAP: u, v = instr.targets - backend_.state.swap((backend_.node_index.index(u), backend_.node_index.index(v))) + _backend.state.swap((_backend.node_index.index(u), _backend.node_index.index(v))) case instruction.InstructionKind.CZ: u, v = instr.targets - backend_.state.entangle((backend_.node_index.index(u), backend_.node_index.index(v))) + _backend.state.entangle((_backend.node_index.index(u), _backend.node_index.index(v))) case instruction.InstructionKind.I: pass case instruction.InstructionKind.S: @@ -1066,13 +1063,13 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None: case instruction.InstructionKind.CCX: evolve(Ops.CCX, [instr.controls[0], instr.controls[1], instr.target]) case instruction.InstructionKind.M: - result = backend_.measure( + result = _backend.measure( instr.target, PauliMeasurement(instr.axis), rng=rng, stacklevel=stacklevel + 1 ) classical_measures.append(result) case _: raise ValueError(f"Unknown instruction: {instr}") - return SimulateResult(cast("_DenseStateT_co", backend_.state), tuple(classical_measures)) + return SimulateResult(cast("_DenseStateT_co", _backend.state), tuple(classical_measures)) def visit(self, visitor: InstructionVisitor) -> Circuit: """Apply `visitor` to all instructions in the circuit.""" From 173783074499d0ad278a886aff28bd68eabb5b71 Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 27 May 2026 09:49:26 +0200 Subject: [PATCH 08/10] Add typing suggestions by Thierry --- graphix/transpiler.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index b403375c6..7859d3f0f 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, SupportsFloat, TypeVar, cast, overload +from typing import TYPE_CHECKING, Generic, SupportsFloat, TypeVar, overload # assert_never introduced in Python 3.11 # override introduced in Python 3.12 @@ -44,6 +44,7 @@ _DenseStateBackendLiteral = Literal["statevector", "densitymatrix"] _DenseStateT_co = TypeVar("_DenseStateT_co", bound="DenseState", covariant=True) +_DenseStateT = TypeVar("_DenseStateT", bound="DenseState") @dataclass(frozen=True) @@ -60,15 +61,17 @@ class TranspileResult: @dataclass(frozen=True) -class SimulateResult(Generic[_DenseStateT_co]): +class SimulateResult(Generic[_DenseStateT]): """ - The result of a simulation. + Result of a circuit simulation. - statevec : :class:`graphix.sim.statevec.Statevec` object - classical_measures : tuple[int,...], classical measures + statevec : _DenseStateT + State representation of the simulation output. + classical_measures : tuple[int,...] + Results of classical measurements. """ - statevec: _DenseStateT_co + statevec: _DenseStateT # mypy rejects covariant types as dataclass parameters as of Python 3.13 classical_measures: tuple[int, ...] @@ -977,29 +980,32 @@ def simulate_statevector( @overload def simulate_statevector( self, - backend: DenseStateBackend[_DenseStateT_co], + backend: DenseStateBackend[_DenseStateT], input_state: Data | None = None, branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, stacklevel: int = 1, - ) -> SimulateResult[_DenseStateT_co]: ... + ) -> SimulateResult[_DenseStateT]: ... def simulate_statevector( self, - backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral = "statevector", + backend: DenseStateBackend[_DenseStateT] | _DenseStateBackendLiteral = "statevector", input_state: Data | None = None, branch_selector: BranchSelector | None = None, rng: Generator | None = None, *, stacklevel: int = 1, - ) -> SimulateResult[_DenseStateT_co]: - """Run statevector simulation of the gate sequence. + ) -> SimulateResult[_DenseStateT] | SimulateResult[_DenseStateT | Statevec | DensityMatrix]: + # `SimulateResult` is not covariant in `_DenseStateT` so `SimulateResult[_DenseStateT]` is not a subtype of `SimulateResult[_DenseStateT | Statevec | DensityMatrix]` + r"""Simulate the gate sequence with a backend and input state of choice. + + By default, this method uses the statevector backend and initializes the register to :math:`|+\rangle^{\otimes n}`. Parameters ---------- input_state : Data - backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT_co]`, 'statevector', or 'densitymatrix' + backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT]`, 'statevector', or 'densitymatrix' Simulator backend to use. Optional, defaults to "statevector". branch_selector: :class:`graphix.branch_selector.BranchSelector` branch selector for measures (default: :class:`RandomBranchSelector`). It cannot be specified if ``backend`` is already instantiated. @@ -1069,7 +1075,7 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None: classical_measures.append(result) case _: raise ValueError(f"Unknown instruction: {instr}") - return SimulateResult(cast("_DenseStateT_co", _backend.state), tuple(classical_measures)) + return SimulateResult(_backend.state, tuple(classical_measures)) def visit(self, visitor: InstructionVisitor) -> Circuit: """Apply `visitor` to all instructions in the circuit.""" From 14e1299f4b3ede78c19dfd879081832cef97b00f Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 27 May 2026 09:54:27 +0200 Subject: [PATCH 09/10] Up CHANGELOG --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 092bf1a67..8e2946ae5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - #507: Static method `PauliString.from_measured_node` is subsumed by the function `extraction_ps_from_corrected_node`. +- #512: Method `Circuit.simulate_statevector` accepts a `backend: DenseStateBackend[_DenseStateT]|"statevector" | "densitymatrix"` parameter. + ## [0.3.5] - 2026-03-26 ### Added From 4833f5720844c3f847ee16b64666bb58185fe450 Mon Sep 17 00:00:00 2001 From: matulni Date: Thu, 28 May 2026 17:46:50 +0200 Subject: [PATCH 10/10] Update CHANGELOG.md Co-authored-by: thierry-martinez --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e2946ae5..12fe5c58e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,7 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - #507: Static method `PauliString.from_measured_node` is subsumed by the function `extraction_ps_from_corrected_node`. -- #512: Method `Circuit.simulate_statevector` accepts a `backend: DenseStateBackend[_DenseStateT]|"statevector" | "densitymatrix"` parameter. +- #512: Method `Circuit.simulate_statevector` accepts a `backend: DenseStateBackend[_DenseStateT] | Literal["statevector", "densitymatrix"]` parameter. ## [0.3.5] - 2026-03-26