Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/redsun/containers/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class AppContainer:
"_virtual_container",
"_is_built",
"_built_devices",
"_devices_connected",
)

_device_components: ClassVar[dict[str, _DeviceComponent]] = {}
Expand Down Expand Up @@ -316,6 +317,7 @@ def __init__(self, *, session: str = "Redsun", frontend: str = "pyqt") -> None:
self._virtual_container: VirtualContainer | None = None
self._is_built: bool = False
self._built_devices: dict[str, Device] = {}
self._devices_connected: bool = False

# In the declarative subclass path (class MyApp(QtAppContainer, config=...))
# the metaclass loads the YAML only to resolve component kwargs and never
Expand Down Expand Up @@ -477,6 +479,7 @@ async def _connect_all() -> asyncio.Future[list[None]]:

future = asyncio.run_coroutine_threadsafe(_connect_all(), get_shared_loop())
future.result()
self._devices_connected = True

def shutdown(self) -> None:
"""Shutdown all presenters that implement ``HasShutdown``."""
Expand All @@ -494,9 +497,11 @@ def shutdown(self) -> None:
logger.info("Container shutdown complete")

def run(self) -> None:
"""Build if needed and start the application."""
"""Build and connect devices if needed, then start the application."""
if not self._is_built:
self.build()
if not self._devices_connected:
self.connect_devices()

frontend = self._config.get("frontend", "pyqt")
logger.info(f"Starting application with frontend: {frontend}")
Expand Down
77 changes: 77 additions & 0 deletions tests/container/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,80 @@ async def test_oa_device_descriptor_contains_units(self) -> None:
desc = await m.x.describe()
assert "stage-x" in desc
assert desc["stage-x"]["units"] == "mm"


class TestConnectDevices:
"""Smoke tests for the connect_devices / run lifecycle."""

def test_connect_devices_requires_build(self) -> None:
"""connect_devices() raises RuntimeError when called before build()."""

class EmptyApp(AppContainer):
pass

app = EmptyApp()
with pytest.raises(RuntimeError, match="build()"):
app.connect_devices(mock=True)

def test_connect_devices_sets_connected_flag(self) -> None:
"""After connect_devices(mock=True), _devices_connected is True."""

class TestApp(AppContainer):
motor = declare_device(MockOAMotor, units="mm")

app = TestApp()
assert not app._devices_connected
app.build()
assert not app._devices_connected
app.connect_devices(mock=True)
assert app._devices_connected

def test_run_connects_devices_automatically(self) -> None:
"""run() calls connect_devices() so callers need not do it explicitly."""

class TestApp(AppContainer):
motor = declare_device(MockOAMotor, units="mm")

app = TestApp()
# Patch run() to stop after connect_devices so we don't need a frontend.
original_run = AppContainer.run

connected_before_frontend: list[bool] = []

def patched_run(self: AppContainer) -> None: # type: ignore[override]
# call the real run up to (but not past) frontend startup
if not self._is_built:
self.build()
if not self._devices_connected:
self.connect_devices(mock=True)
connected_before_frontend.append(self._devices_connected)

AppContainer.run = patched_run # type: ignore[method-assign]
try:
app.run()
finally:
AppContainer.run = original_run # type: ignore[method-assign]

assert connected_before_frontend == [True]

def test_run_skips_connect_when_already_connected(self) -> None:
"""Make sure that run() does not reconnect devices that were already connected."""
connect_calls: list[str] = []

class TrackingApp(AppContainer):
motor = declare_device(MockOAMotor, units="mm")

def connect_devices(self, mock: bool = False) -> None: # type: ignore[override]
connect_calls.append("called")
super().connect_devices(mock=mock)

app = TrackingApp()
app.build()
app.connect_devices(mock=True)
assert connect_calls == ["called"]

# Simulate run() when already connected — connect_devices must not fire again.
if not app._devices_connected:
app.connect_devices(mock=True)

assert connect_calls == ["called"] # still only one call
Loading