diff --git a/statemachine/state.py b/statemachine/state.py index 1d77bc9b..c4ba2c04 100644 --- a/statemachine/state.py +++ b/statemachine/state.py @@ -71,11 +71,18 @@ def __new__( # type: ignore [misc] inherited_kwargs.update(getattr(base, "_factory_kwargs", {})) inherited_kwargs.update(kwargs) + # Lazy import to avoid circular dependency (states.py imports state.py) + from .states import States + states = [] history = [] callbacks = {} for key, value in attrs.items(): - if isinstance(value, HistoryState): + if isinstance(value, States): + for state_id, state in value.items(): + state._set_id(state_id) + states.append(state) + elif isinstance(value, HistoryState): value._set_id(key) history.append(value) elif isinstance(value, State): diff --git a/tests/test_statechart_compound.py b/tests/test_statechart_compound.py index 334b4f28..4570c08d 100644 --- a/tests/test_statechart_compound.py +++ b/tests/test_statechart_compound.py @@ -7,7 +7,11 @@ Theme: Fellowship journey through Middle-earth. """ +from enum import Enum +from enum import auto + import pytest +from statemachine.states import States from statemachine import State from statemachine import StateChart @@ -213,3 +217,38 @@ class shire(State.Compound, name="The Shire"): sm = NamedCompound() assert sm.shire.name == "The Shire" + + async def test_from_enum_inside_compound(self, sm_runner): + """States.from_enum() works inside compound states (#606).""" + + class OuterStates(Enum): + FOO = auto() + BAR = auto() + + class InnerStates(Enum): + FIZZ = auto() + BUZZ = auto() + + class SC(StateChart): + baz = States.from_enum(OuterStates, initial=OuterStates.FOO, final=OuterStates.BAR) + + class inner(State.Compound): + qux = States.from_enum( + InnerStates, initial=InnerStates.FIZZ, final=InnerStates.BUZZ + ) + fizz_to_buzz = qux.FIZZ.to(qux.BUZZ) + + baz_foo_to_inner = baz.FOO.to(inner) + inner_to_baz_bar = inner.to(baz.BAR) + + sm = await sm_runner.start(SC) + assert {OuterStates.FOO} == set(sm.configuration_values) + + await sm_runner.send(sm, "baz_foo_to_inner") + assert {"inner", InnerStates.FIZZ} == set(sm.configuration_values) + + await sm_runner.send(sm, "fizz_to_buzz") + assert {"inner", InnerStates.BUZZ} == set(sm.configuration_values) + + await sm_runner.send(sm, "inner_to_baz_bar") + assert {OuterStates.BAR} == set(sm.configuration_values) diff --git a/tests/test_statechart_parallel.py b/tests/test_statechart_parallel.py index 619d1b36..297589e3 100644 --- a/tests/test_statechart_parallel.py +++ b/tests/test_statechart_parallel.py @@ -7,8 +7,14 @@ Theme: War of the Ring — multiple simultaneous fronts. """ +from enum import Enum +from enum import auto + import pytest +from statemachine.states import States +from statemachine import State +from statemachine import StateChart from tests.machines.parallel.session import Session from tests.machines.parallel.session_with_done_state import SessionWithDoneState from tests.machines.parallel.two_towers import TwoTowers @@ -139,6 +145,53 @@ async def test_top_level_parallel_done_state_fires_before_termination(self, sm_r assert {"finished"} == set(sm.configuration_values) assert sm.is_terminated is True + async def test_from_enum_inside_parallel(self, sm_runner): + """States.from_enum() works inside parallel states (#606).""" + + class RegionA(Enum): + IDLE = auto() + ACTIVE = auto() + + class RegionB(Enum): + OFF = auto() + ON = auto() + + class SC(StateChart): + start = State(initial=True) + done = State(final=True) + + class work(State.Parallel): + class region_a(State.Compound): + a = States.from_enum(RegionA, initial=RegionA.IDLE, final=RegionA.ACTIVE) + go_a = a.IDLE.to(a.ACTIVE) + + class region_b(State.Compound): + b = States.from_enum(RegionB, initial=RegionB.OFF, final=RegionB.ON) + go_b = b.OFF.to(b.ON) + + begin = start.to(work) + finish = work.to(done) + + sm = await sm_runner.start(SC) + assert {"start"} == set(sm.configuration_values) + + await sm_runner.send(sm, "begin") + vals = set(sm.configuration_values) + assert "work" in vals + assert RegionA.IDLE in vals + assert RegionB.OFF in vals + + await sm_runner.send(sm, "go_a") + vals = set(sm.configuration_values) + assert RegionA.ACTIVE in vals + assert RegionB.OFF in vals # region_b unchanged + + await sm_runner.send(sm, "go_b") + # Both regions final -> done.state.work fires + assert {RegionA.ACTIVE, RegionB.ON} <= set(sm.configuration_values) or {"done"} == set( + sm.configuration_values + ) + async def test_top_level_parallel_not_terminated_when_one_region_pending(self, sm_runner): """Machine keeps running when only one region reaches final.""" sm = await sm_runner.start(Session)