From e045675d556507e4944a5a5831cf77f951398982 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 2 Feb 2026 16:22:06 +0100 Subject: [PATCH] Fix for graphix#220: rename `get_`/`set_`-prefixed functions This commit makes the repository work with https://github.com/TeamGraphix/graphix/pull/418 --- graphix_symbolic/sympy_parameter.py | 48 +++++++++++++++++------------ pyproject.toml | 3 ++ requirements.txt | 2 +- tests/test_sympy_parameter.py | 7 ++++- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/graphix_symbolic/sympy_parameter.py b/graphix_symbolic/sympy_parameter.py index 757c18d..3b32cf5 100644 --- a/graphix_symbolic/sympy_parameter.py +++ b/graphix_symbolic/sympy_parameter.py @@ -7,11 +7,20 @@ from __future__ import annotations import numbers -from typing import Mapping +from typing import TYPE_CHECKING import numpy as np -import sympy as sp -from graphix.parameter import ExpressionOrComplex, ExpressionOrFloat, ExpressionWithTrigonometry, Parameter +import sympy as sp # type: ignore[import-untyped] +from graphix.parameter import ( + ExpressionOrComplex, + ExpressionOrFloat, + ExpressionOrSupportsFloat, + ExpressionWithTrigonometry, + Parameter, +) + +if TYPE_CHECKING: + from collections.abc import Mapping class SympyExpression(ExpressionWithTrigonometry): @@ -99,13 +108,13 @@ def __mod__(self, other) -> float: """ return np.nan - def sin(self) -> ExpressionOrFloat: + def sin(self) -> ExpressionWithTrigonometry: return SympyExpression(sp.sin(self._expression)) - def cos(self) -> ExpressionOrFloat: + def cos(self) -> ExpressionWithTrigonometry: return SympyExpression(sp.cos(self._expression)) - def exp(self) -> ExpressionOrFloat: + def exp(self) -> ExpressionWithTrigonometry: return SympyExpression(sp.exp(self._expression)) def conjugate(self) -> ExpressionOrFloat: @@ -124,25 +133,18 @@ def __repr__(self) -> str: def __str__(self) -> str: return str(self._expression) - @staticmethod - def __check_sympy_parameter(variable: Parameter) -> None: - if not isinstance(variable, SympyParameter): - raise ValueError( - f"Sympy expressions can only be substituted with sympy parameters, not {variable.__class__}." - ) - - def subs(self, variable: Parameter, value: ExpressionOrFloat) -> ExpressionOrComplex: - self.__check_sympy_parameter(variable) - result = sp.N(self._expression.subs(variable._expression, value)) + def subs(self, variable: Parameter, value: ExpressionOrSupportsFloat) -> ExpressionOrComplex: + parameter = _check_sympy_parameter(variable) + result = sp.N(self._expression.subs(parameter._expression, value)) if isinstance(result, numbers.Number) or not result.free_symbols: return complex(result) else: return SympyExpression(result) - def xreplace(self, assignment: Mapping[Parameter, ExpressionOrFloat]) -> ExpressionOrComplex: - for variable in assignment: - self.__check_sympy_parameter(variable) - sympy_assignment = {variable._expression: value for variable, value in assignment.items()} + def xreplace(self, assignment: Mapping[Parameter, ExpressionOrSupportsFloat]) -> ExpressionOrComplex: + sympy_assignment = { + _check_sympy_parameter(variable)._expression: value for variable, value in assignment.items() + } result = sp.N(self._expression.xreplace(sympy_assignment)) if isinstance(result, numbers.Number) or not result.free_symbols: return complex(result) @@ -150,6 +152,12 @@ def xreplace(self, assignment: Mapping[Parameter, ExpressionOrFloat]) -> Express return SympyExpression(result) +def _check_sympy_parameter(variable: Parameter) -> SympyParameter: + if not isinstance(variable, SympyParameter): + raise ValueError(f"Sympy expressions can only be substituted with sympy parameters, not {variable.__class__}.") + return variable + + class SympyParameter(Parameter, SympyExpression): """Placeholder for measurement angles, which allows the pattern optimizations without specifying measurement angles for measurement commands. diff --git a/pyproject.toml b/pyproject.toml index 5f3a6a5..4ffc549 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,3 +78,6 @@ docstring-code-format = true [tool.ruff.lint.pydocstyle] convention = "numpy" + +[tool.mypy] +files = ["graphix_symbolic", "tests"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d12aec9..839a4a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -graphix @ git+https://github.com/thierry-martinez/graphix@branch_selector +graphix @ git+https://github.com/thierry-martinez/graphix@fix/220_avoid_get_set sympy>=1.9 diff --git a/tests/test_sympy_parameter.py b/tests/test_sympy_parameter.py index 949c7b4..dd2204b 100644 --- a/tests/test_sympy_parameter.py +++ b/tests/test_sympy_parameter.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import numpy as np import pytest from graphix import Circuit @@ -6,6 +8,9 @@ from graphix_symbolic import SympyParameter +if TYPE_CHECKING: + from graphix.parameter import Parameter + def test_parameter_circuit_simulation(fx_rng: Generator) -> None: alpha = SympyParameter("alpha") @@ -23,7 +28,7 @@ def test_parameter_parallel_substitution(fx_rng: Generator) -> None: circuit = Circuit(2) circuit.rz(0, alpha) circuit.rz(1, beta) - mapping = {alpha: 0.5, beta: 0.4} + mapping: dict[Parameter, float] = {alpha: 0.5, beta: 0.4} result_subs_then_simulate = circuit.xreplace(mapping).simulate_statevector().statevec result_simulate_then_subs = circuit.simulate_statevector().statevec.xreplace(mapping) assert np.allclose(result_subs_then_simulate.psi, result_simulate_then_subs.psi)