Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- #507: Static method `PauliString.from_measured_node` is subsumed by the function `extraction_ps_from_corrected_node`.

- #512: Method `Circuit.simulate_statevector` accepts a `backend: DenseStateBackend[_DenseStateT] | Literal["statevector", "densitymatrix"]` parameter.

## [0.3.5] - 2026-03-26

### Added
Expand Down
158 changes: 131 additions & 27 deletions graphix/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, SupportsFloat
from typing import TYPE_CHECKING, Generic, SupportsFloat, TypeVar, overload

# assert_never introduced in Python 3.11
# override introduced in Python 3.12
Expand All @@ -21,10 +21,13 @@
from graphix.measurements import Measurement, PauliMeasurement
from graphix.ops import Ops
from graphix.pattern import Pattern
from graphix.sim.statevec import Statevec, StatevectorBackend
from graphix.sim.base_backend import DenseStateBackend
from graphix.sim.density_matrix import DensityMatrixBackend
from graphix.sim.statevec import StatevectorBackend

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from typing import Literal

from numpy.random import Generator

Expand All @@ -33,10 +36,18 @@
from graphix.instruction import InstructionType, InstructionTypeWithoutRZZ
from graphix.parameter import ExpressionOrFloat, Parameter
from graphix.sim import Data
from graphix.sim.base_backend import Matrix
from graphix.sim.base_backend import DenseState, Matrix
from graphix.sim.density_matrix import DensityMatrix
from graphix.sim.statevec import Statevec

_BuiltinDenseStateBackend = DensityMatrixBackend | StatevectorBackend
_DenseStateBackendLiteral = Literal["statevector", "densitymatrix"]

@dataclass
_DenseStateT_co = TypeVar("_DenseStateT_co", bound="DenseState", covariant=True)
_DenseStateT = TypeVar("_DenseStateT", bound="DenseState")


@dataclass(frozen=True)
class TranspileResult:
"""
The result of a transpilation.
Expand All @@ -49,16 +60,18 @@ class TranspileResult:
classical_outputs: tuple[int, ...]


@dataclass
class SimulateResult:
@dataclass(frozen=True)
class SimulateResult(Generic[_DenseStateT]):
"""
The result of a simulation.
Result of a circuit simulation.

statevec : :class:`graphix.sim.statevec.Statevec` object
classical_measures : tuple[int,...], classical measures
statevec : _DenseStateT
State representation of the simulation output.
classical_measures : tuple[int,...]
Results of classical measurements.
"""

statevec: Statevec
statevec: _DenseStateT # mypy rejects covariant types as dataclass parameters as of Python 3.13
classical_measures: tuple[int, ...]


Expand Down Expand Up @@ -942,21 +955,60 @@ def _ccx_command(
)
return ancilla[17], ancilla[15], ancilla[13], seq

@overload
def simulate_statevector(
self,
backend: StatevectorBackend | Literal["statevector"] = ...,
input_state: Data | None = None,
branch_selector: BranchSelector | None = None,
rng: Generator | None = None,
*,
stacklevel: int = 1,
) -> SimulateResult[Statevec]: ...

@overload
def simulate_statevector(
self,
backend: DensityMatrixBackend | Literal["densitymatrix"],
input_state: Data | None = None,
branch_selector: BranchSelector | None = None,
rng: Generator | None = None,
*,
stacklevel: int = 1,
) -> SimulateResult[DensityMatrix]: ...

@overload
def simulate_statevector(
self,
backend: DenseStateBackend[_DenseStateT],
input_state: Data | None = None,
branch_selector: BranchSelector | None = None,
rng: Generator | None = None,
*,
stacklevel: int = 1,
) -> SimulateResult:
"""Run statevector simulation of the gate sequence.
) -> SimulateResult[_DenseStateT]: ...

def simulate_statevector(
self,
backend: DenseStateBackend[_DenseStateT] | _DenseStateBackendLiteral = "statevector",
input_state: Data | None = None,
branch_selector: BranchSelector | None = None,
rng: Generator | None = None,
*,
stacklevel: int = 1,
) -> SimulateResult[_DenseStateT] | SimulateResult[_DenseStateT | Statevec | DensityMatrix]:
# `SimulateResult` is not covariant in `_DenseStateT` so `SimulateResult[_DenseStateT]` is not a subtype of `SimulateResult[_DenseStateT | Statevec | DensityMatrix]`
r"""Simulate the gate sequence with a backend and input state of choice.

By default, this method uses the statevector backend and initializes the register to :math:`|+\rangle^{\otimes n}`.

Parameters
----------
input_state : Data
backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT]`, 'statevector', or 'densitymatrix'
Simulator backend to use. Optional, defaults to "statevector".
branch_selector: :class:`graphix.branch_selector.BranchSelector`
branch selector for measures (default: :class:`RandomBranchSelector`).
branch selector for measures (default: :class:`RandomBranchSelector`). It cannot be specified if ``backend`` is already instantiated.
rng: Generator, optional
Random-number generator for measurements.
This generator is used only in case of random branch selection
Expand All @@ -967,37 +1019,33 @@ def simulate_statevector(
result : :class:`SimulateResult`
output state of the statevector simulation and results of classical measures.
"""
if branch_selector is None:
branch_selector = RandomBranchSelector()
_backend = _initialize_backend(backend, branch_selector)

backend = StatevectorBackend(branch_selector=branch_selector)
if input_state is None:
backend.add_nodes(range(self.width))
_backend.add_nodes(range(self.width))
else:
backend.add_nodes(range(self.width), input_state)
_backend.add_nodes(range(self.width), input_state)

classical_measures = []

for i in range(len(self.instruction)):
instr = self.instruction[i]

def evolve_single(op: Matrix, target: int) -> None:
backend.state.evolve_single(op, backend.node_index.index(target))
_backend.state.evolve_single(op, _backend.node_index.index(target))

def evolve(op: Matrix, qargs: Iterable[int]) -> None:
backend.state.evolve(op, [backend.node_index.index(qarg) for qarg in qargs])
_backend.state.evolve(op, [_backend.node_index.index(qarg) for qarg in qargs])

match instr.kind:
case instruction.InstructionKind.CNOT:
backend.state.cnot(
(backend.node_index.index(instr.control), backend.node_index.index(instr.target))
)
evolve(Ops.CNOT, [instr.control, instr.target])
case instruction.InstructionKind.SWAP:
u, v = instr.targets
backend.state.swap((backend.node_index.index(u), backend.node_index.index(v)))
_backend.state.swap((_backend.node_index.index(u), _backend.node_index.index(v)))
case instruction.InstructionKind.CZ:
u, v = instr.targets
backend.state.entangle((backend.node_index.index(u), backend.node_index.index(v)))
_backend.state.entangle((_backend.node_index.index(u), _backend.node_index.index(v)))
case instruction.InstructionKind.I:
pass
case instruction.InstructionKind.S:
Expand All @@ -1021,13 +1069,13 @@ def evolve(op: Matrix, qargs: Iterable[int]) -> None:
case instruction.InstructionKind.CCX:
evolve(Ops.CCX, [instr.controls[0], instr.controls[1], instr.target])
case instruction.InstructionKind.M:
result = backend.measure(
result = _backend.measure(
instr.target, PauliMeasurement(instr.axis), rng=rng, stacklevel=stacklevel + 1
)
classical_measures.append(result)
case _:
raise ValueError(f"Unknown instruction: {instr}")
return SimulateResult(backend.state, tuple(classical_measures))
return SimulateResult(_backend.state, tuple(classical_measures))

def visit(self, visitor: InstructionVisitor) -> Circuit:
"""Apply `visitor` to all instructions in the circuit."""
Expand Down Expand Up @@ -1153,3 +1201,59 @@ def transpile_swaps(circuit: Circuit) -> TranspileSwapsResult:
if instr.kind == InstructionKind.M:
visitor.qubits[instr.target] = None
return TranspileSwapsResult(new_circuit, tuple(visitor.qubits))


@overload
def _initialize_backend(
backend: StatevectorBackend | Literal["statevector"],
branch_selector: BranchSelector | None,
) -> StatevectorBackend: ...


@overload
def _initialize_backend(
backend: DensityMatrixBackend | Literal["densitymatrix"],
branch_selector: BranchSelector | None,
) -> DensityMatrixBackend: ...


@overload
def _initialize_backend(
backend: DenseStateBackend[_DenseStateT_co],
branch_selector: BranchSelector | None,
) -> DenseStateBackend[_DenseStateT_co]: ...


def _initialize_backend(
backend: DenseStateBackend[_DenseStateT_co] | _DenseStateBackendLiteral,
branch_selector: BranchSelector | None,
) -> _BuiltinDenseStateBackend | DenseStateBackend[_DenseStateT_co]:
"""Initialize backend for circuit simulation.

Parameters
----------
backend: :class:`graphix.sim.base_backend.DenseStateBackend[_DenseStateT_co]`, 'statevector', or 'densitymatrix'
Simulation backend
branch_selector: :class:`BranchSelector`
Branch selector used for measurements. Can only be specified if ``backend`` is not an already instantiated :class:`Backend` object. If ``None``, it defaults to :class:`RandomBranchSelector`.

Returns
-------
:class:`DenseStateBackend`
matching the appropriate backend
"""
if isinstance(backend, DenseStateBackend):
if branch_selector is not None:
raise ValueError("`branch_selector` cannot be specified if `backend` is already instantiated.")
return backend

if branch_selector is None:
branch_selector = RandomBranchSelector()

match backend:
case "statevector":
return StatevectorBackend(branch_selector=branch_selector)
case "densitymatrix":
return DensityMatrixBackend(branch_selector=branch_selector)
case _:
raise ValueError(f"Unknown backend {backend}.")
36 changes: 31 additions & 5 deletions tests/test_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@
from graphix.fundamentals import ANGLE_PI, Axis, Sign
from graphix.instruction import I, InstructionKind
from graphix.random_objects import rand_circuit, rand_gate, rand_state_vector
from graphix.sim.density_matrix import DensityMatrix
from graphix.sim.statevec import Statevec
from graphix.states import BasicStates
from graphix.transpiler import Circuit, transpile_swaps
from tests.test_branch_selector import CheckedBranchSelector

if TYPE_CHECKING:
from collections.abc import Callable
from typing import TypeAlias
from typing import Literal, TypeAlias

from graphix.instruction import InstructionType
from graphix.measurements import Outcome

InstructionTestCase: TypeAlias = Callable[[Generator], InstructionType]
_DenseStateBackendLiteral = Literal["statevector", "densitymatrix"]


INSTRUCTION_TEST_CASES: list[InstructionTestCase] = [
lambda _rng: instruction.CCX(0, (1, 2)),
Expand Down Expand Up @@ -156,20 +160,30 @@ def test_transpiled(self, fx_rng: Generator) -> None:
state_mbqc = pattern.simulate_pattern(rng=fx_rng)
assert state_mbqc.isclose(state)

@pytest.mark.parametrize("backend", ["statevector", "densitymatrix"])
@pytest.mark.parametrize("jumps", range(1, 11))
@pytest.mark.parametrize("axis", [Axis.X, Axis.Y, Axis.Z])
@pytest.mark.parametrize("outcome", [0, 1])
def test_measure(self, fx_bg: PCG64, jumps: int, axis: Axis, outcome: Outcome) -> None:
def test_measure(
self, fx_bg: PCG64, jumps: int, axis: Axis, outcome: Outcome, backend: _DenseStateBackendLiteral
) -> None:
rng = Generator(fx_bg.jumped(jumps))
circuit = Circuit(2)
circuit.cnot(0, 1)
circuit.m(0, axis)
input_state = rand_state_vector(2, rng=rng)
branch_selector = ConstBranchSelector(outcome)
state = circuit.simulate_statevector(rng=rng, input_state=input_state, branch_selector=branch_selector).statevec
state = circuit.simulate_statevector(
rng=rng, input_state=input_state, branch_selector=branch_selector, backend=backend
).statevec
pattern = circuit.transpile().pattern
state_mbqc = pattern.simulate_pattern(rng=rng, input_state=input_state, branch_selector=branch_selector)
assert state_mbqc.isclose(state)
state_mbqc = pattern.simulate_pattern(
rng=rng, input_state=input_state, branch_selector=branch_selector, backend=backend
)
if isinstance(state_mbqc, Statevec) and isinstance(state, Statevec):
assert state_mbqc.isclose(state)
elif isinstance(state_mbqc, DensityMatrix) and isinstance(state, DensityMatrix):
assert np.allclose(state_mbqc.rho, state.rho)

@pytest.mark.parametrize("input_axis", [Axis.X, Axis.Y, Axis.Z])
@pytest.mark.parametrize("input_sign", [Sign.PLUS, Sign.MINUS])
Expand Down Expand Up @@ -261,6 +275,18 @@ def test_simple(self) -> None:
state_mbqc = pattern.simulate_pattern(input_state=input_state, rng=rng)
assert state_mbqc.isclose(state)

@pytest.mark.parametrize("jumps", range(1, 3))
def test_dm_backend(self, fx_bg: PCG64, jumps: int) -> None:
nqubits = 2
rng = Generator(fx_bg.jumped(jumps))
circuit = rand_circuit(nqubits, 3, rng)
pattern = circuit.transpile().pattern
pattern.minimize_space()
input_state = rand_state_vector(nqubits, rng=rng)
state = circuit.simulate_statevector(input_state=input_state, backend="densitymatrix").statevec
state_mbqc = pattern.simulate_pattern(input_state=input_state, backend="densitymatrix", rng=rng)
assert np.allclose(state_mbqc.rho, state.rho)


@pytest.mark.parametrize("jumps", range(1, 11))
def test_transpile_swaps(fx_bg: PCG64, jumps: int) -> None:
Expand Down
Loading