diff --git a/docs/how-to/automation-daemon.md b/docs/how-to/automation-daemon.md index 113872b..ec8a3f0 100644 --- a/docs/how-to/automation-daemon.md +++ b/docs/how-to/automation-daemon.md @@ -56,19 +56,24 @@ _snapshot: SystemSnapshot | None = None async def on_space_update(space: Space) -> None: global _snapshot if _snapshot is not None: - _snapshot.spaces[space.id] = space + space = _snapshot.apply_space(space) temp = ( - f"{space.state.current_temp_c:.1f}°C" - if space.state.current_temp_c is not None + f"{space.state.ambient_temperature_c:.1f}°C" + if space.state.ambient_temperature_c is not None else "unknown" ) - LOG.info("[space] %s — mode=%s temp=%s", space.name, space.controls.mode.value, temp) + LOG.info( + "[space] %s — mode=%s temp=%s", + space.name, + space.controls.hvac_mode.value, + temp, + ) if ( - space.state.current_temp_c is not None - and space.state.current_temp_c > 27.0 - and space.controls.mode.value in ("auto", "cool") + space.state.ambient_temperature_c is not None + and space.state.ambient_temperature_c > 27.0 + and space.controls.hvac_mode.value in ("auto", "cool") ): LOG.warning("[space] %s is above 27°C — check cooling", space.name) @@ -76,8 +81,8 @@ async def on_space_update(space: Space) -> None: async def on_idu_update(idu: IndoorUnit) -> None: global _snapshot if _snapshot is not None: - _snapshot.indoor_units[idu.id] = idu - LOG.debug("[idu] %s — fan=%s online=%s", idu.id, idu.controls.fan_speed.value, idu.state.is_online) + idu = _snapshot.apply_indoor_unit(idu) + LOG.debug("[idu] %s — fan=%s online=%s", idu.id, idu.controls.fan_speed.value, idu.is_online) async def run() -> None: @@ -91,7 +96,7 @@ async def run() -> None: _snapshot = await client.get_snapshot() LOG.info( "Snapshot loaded: system=%s rooms=%d idus=%d", - _snapshot.system_id, + client.system_name, len(_snapshot.rooms), len(_snapshot.indoor_units), ) @@ -100,8 +105,7 @@ async def run() -> None: stream = client.stream(topics, max_reconnects=-1, reconnect_delay_s=2.0) stream.on_space_update(on_space_update) stream.on_indoor_unit_update(on_idu_update) - stream.on_connected(lambda: LOG.info("Stream connected")) - stream.on_disconnected(lambda: LOG.warning("Stream disconnected; will reconnect automatically")) + stream.on_error(lambda exc: LOG.error("Stream stopped: %s", exc)) async with stream: LOG.info("Daemon running. Send SIGINT or SIGTERM to stop.") diff --git a/docs/how-to/configure-comfort-settings.md b/docs/how-to/configure-comfort-settings.md index 046a07b..a3e8f65 100644 --- a/docs/how-to/configure-comfort-settings.md +++ b/docs/how-to/configure-comfort-settings.md @@ -11,14 +11,14 @@ To retrieve all comfort presets: ```python settings = await client.list_comfort_settings() for s in settings: - print(f"{s.name}: mode={s.hvac_mode}, heat={s.heat_setpoint_c}°C, cool={s.cool_setpoint_c}°C, fan={s.fan_speed}") + print(f"{s.name}: mode={s.hvac_mode}, heat={s.heating_setpoint_c}°C, cool={s.cooling_setpoint_c}°C, fan={s.fan_speed}") ``` Alternatively, comfort settings are embedded in `SystemSnapshot`: ```python snapshot = await client.get_snapshot() -for cs in snapshot.comfort_settings.values(): +for cs in snapshot.comfort_settings: print(f"{cs.name}: {cs.hvac_mode}") ``` @@ -40,7 +40,7 @@ updated = await client.update_comfort_setting( cool_setpoint_c=25.0, fan_speed=FanSpeed.AUTO, ) -print(f"Updated '{updated.name}': heat={updated.heat_setpoint_c}°C cool={updated.cool_setpoint_c}°C") +print(f"Updated '{updated.name}': heat={updated.heating_setpoint_c}°C cool={updated.cooling_setpoint_c}°C") ``` Omit any parameter to keep its current value. You can also update by comfort setting ID string: @@ -73,11 +73,11 @@ async def main() -> None: snapshot = await client.get_snapshot() preset = next( - (cs for cs in snapshot.comfort_settings.values() if cs.name == PRESET_NAME), + (cs for cs in snapshot.comfort_settings if cs.name == PRESET_NAME), None, ) if preset is None: - names = [cs.name for cs in snapshot.comfort_settings.values()] + names = [cs.name for cs in snapshot.comfort_settings] print(f"Preset '{PRESET_NAME}' not found. Available: {names}") return @@ -85,8 +85,8 @@ async def main() -> None: updated = await client.set_space( space, mode=preset.hvac_mode, - heat_setpoint_c=preset.heat_setpoint_c, - cool_setpoint_c=preset.cool_setpoint_c, + heat_setpoint_c=preset.heating_setpoint_c, + cool_setpoint_c=preset.cooling_setpoint_c, ) print(f" {updated.name}: mode={updated.controls.hvac_mode}") diff --git a/docs/how-to/configure-schedules.md b/docs/how-to/configure-schedules.md index 9a1b954..e3fc812 100644 --- a/docs/how-to/configure-schedules.md +++ b/docs/how-to/configure-schedules.md @@ -13,19 +13,31 @@ To create a day program with timed comfort-setting transitions: ```python from quilt_hp.models.schedule import ScheduleEvent -# Get a comfort setting ID from the snapshot +# Get comfort settings for one room from the snapshot snapshot = await client.get_snapshot() space = snapshot.space_by_name("Bedroom") -active_cs = next( - cs for cs in snapshot.comfort_settings.values() if cs.name == "Active" -) -sleep_cs = next( - cs for cs in snapshot.comfort_settings.values() if cs.name == "Sleep" -) +assert space is not None +space_settings = snapshot.comfort_settings_for_space(space) +active_cs = next(cs for cs in space_settings if cs.name == "Active") +sleep_cs = next(cs for cs in space_settings if cs.name == "Sleep") events = [ - ScheduleEvent(time_of_day_s=7 * 3600, comfort_setting_id=active_cs.id), # 07:00 → Active - ScheduleEvent(time_of_day_s=22 * 3600, comfort_setting_id=sleep_cs.id), # 22:00 → Sleep + ScheduleEvent( + start_s=7 * 3600, + comfort_setting_id=active_cs.id, + hvac_mode=active_cs.hvac_mode, + heating_setpoint_c=active_cs.heating_setpoint_c, + cooling_setpoint_c=active_cs.cooling_setpoint_c, + precondition=False, + ), + ScheduleEvent( + start_s=22 * 3600, + comfort_setting_id=sleep_cs.id, + hvac_mode=sleep_cs.hvac_mode, + heating_setpoint_c=sleep_cs.heating_setpoint_c, + cooling_setpoint_c=sleep_cs.cooling_setpoint_c, + precondition=False, + ), ] day = await client.create_schedule_day( @@ -36,7 +48,7 @@ day = await client.create_schedule_day( print(f"Created schedule day: {day.id} ({len(day.events)} events)") ``` -`time_of_day_s` is the number of seconds from midnight (e.g., `7 * 3600` = 07:00). +`start_s` is the number of seconds from midnight (e.g., `7 * 3600` = 07:00). --- @@ -47,17 +59,17 @@ To create a schedule week and assign day programs to each weekday: ```python from quilt_hp.models.schedule import ScheduleWeekDay -# day_of_week: 0 = Monday, 6 = Sunday +# weekday: 1 = Monday, 7 = Sunday week = await client.create_schedule_week( space_id=space.id, days=[ - ScheduleWeekDay(day_of_week=0, schedule_day_id=weekday_program.id), # Mon - ScheduleWeekDay(day_of_week=1, schedule_day_id=weekday_program.id), # Tue - ScheduleWeekDay(day_of_week=2, schedule_day_id=weekday_program.id), # Wed - ScheduleWeekDay(day_of_week=3, schedule_day_id=weekday_program.id), # Thu - ScheduleWeekDay(day_of_week=4, schedule_day_id=weekday_program.id), # Fri - ScheduleWeekDay(day_of_week=5, schedule_day_id=weekend_program.id), # Sat - ScheduleWeekDay(day_of_week=6, schedule_day_id=weekend_program.id), # Sun + ScheduleWeekDay(weekday=1, day_id=weekday_program.id), # Mon + ScheduleWeekDay(weekday=2, day_id=weekday_program.id), # Tue + ScheduleWeekDay(weekday=3, day_id=weekday_program.id), # Wed + ScheduleWeekDay(weekday=4, day_id=weekday_program.id), # Thu + ScheduleWeekDay(weekday=5, day_id=weekday_program.id), # Fri + ScheduleWeekDay(weekday=6, day_id=weekend_program.id), # Sat + ScheduleWeekDay(weekday=7, day_id=weekend_program.id), # Sun ], ) print(f"Created schedule week: {week.id}") @@ -74,7 +86,7 @@ updated_week = await client.update_schedule_week( schedule_week_id=week.id, space_id=space.id, days=[ - ScheduleWeekDay(day_of_week=0, schedule_day_id=new_monday_program.id), + ScheduleWeekDay(weekday=1, day_id=new_monday_program.id), # ... include all 7 days; omitted days are cleared ], ) @@ -114,4 +126,4 @@ To resume: await client.set_schedule_execution(paused=False) ``` -This is a global switch. It affects all schedule weeks across all spaces in the system. The current pause state is available as `snapshot.schedule_paused`. +This is a global switch. It affects all schedule weeks across all spaces in the system. The current pause state is available as `snapshot.primary_location.schedule_paused` when a location is present. diff --git a/docs/how-to/control-spaces.md b/docs/how-to/control-spaces.md index 3f20710..703f2b6 100644 --- a/docs/how-to/control-spaces.md +++ b/docs/how-to/control-spaces.md @@ -88,7 +88,7 @@ To set the fan speed on an indoor unit: from quilt_hp.models.enums import FanSpeed snapshot = await client.get_snapshot() -idu = snapshot.indoor_units[next(iter(snapshot.indoor_units))] # first IDU +idu = snapshot.indoor_units[0] # first IDU updated = await client.set_indoor_unit(idu, fan_speed=FanSpeed.MEDIUM) print(f"Fan speed: {updated.controls.fan_speed}") diff --git a/docs/how-to/stream-updates.md b/docs/how-to/stream-updates.md index d6d6614..4330821 100644 --- a/docs/how-to/stream-updates.md +++ b/docs/how-to/stream-updates.md @@ -23,7 +23,13 @@ def on_idu(idu: IndoorUnit) -> None: async with client.stream(snapshot.stream_topics()) as stream: stream.on_space_update(on_space) - stream.on_indoor_unit_update(on_idu) + stream.on_indoor_unit_update(lambda idu: print(snapshot.apply_indoor_unit(idu).id)) + stream.on_outdoor_unit_update(snapshot.apply_outdoor_unit) + stream.on_controller_update(snapshot.apply_controller) + stream.on_qsm_update(snapshot.apply_qsm) + stream.on_remote_sensor_update(snapshot.apply_remote_sensor) + stream.on_controller_remote_sensor_update(snapshot.apply_controller_remote_sensor) + stream.on_software_update_info(lambda info: print(f"Update info: {info.id}")) stream.on_error(lambda e: print(f"Fatal error: {e}")) await asyncio.sleep(3600) # run for 1 hour ``` @@ -54,24 +60,59 @@ For indoor units: ```python def on_idu(idu: IndoorUnit) -> None: merged = snapshot.apply_indoor_unit(idu) - print(f"{merged.id}: online={merged.state.is_online}") + print(f"{merged.id}: online={merged.is_online}") ``` For background on why sparse diffs require merging, see [Snapshot and stream data model](../explanation/snapshot-and-stream.md). --- -## Run the stream as a background task +## Callback registration methods + +`NotifierStream` accepts both synchronous and async callbacks. Register whichever entity types you care about: + +| Method | Callback argument | Typical use | +| --- | --- | --- | +| `on_space_update()` | `Space` | Merge room diffs with `snapshot.apply_space()` | +| `on_indoor_unit_update()` | `IndoorUnit` | Merge IDU diffs with `snapshot.apply_indoor_unit()` | +| `on_outdoor_unit_update()` | `OutdoorUnit` | Merge ODU diffs with `snapshot.apply_outdoor_unit()` | +| `on_controller_update()` | `Controller` | Merge Dial diffs with `snapshot.apply_controller()` | +| `on_qsm_update()` | `QuiltSmartModule` | Merge QSM diffs with `snapshot.apply_qsm()` | +| `on_remote_sensor_update()` | `RemoteSensor` | Merge standalone sensor diffs with `snapshot.apply_remote_sensor()` | +| `on_controller_remote_sensor_update()` | `ControllerRemoteSensor` | Merge Dial sensor diffs with `snapshot.apply_controller_remote_sensor()` | +| `on_software_update_info()` | `SoftwareUpdateInfo` | Observe firmware/software update records | +| `on_error()` | `Exception` | Handle fatal stream failure after reconnects are exhausted | + +--- + +## Lifecycle methods + +Use these methods to control the stream explicitly: + +| Method / property | What it does | +| --- | --- | +| `await stream.start()` | Starts the listener in the background | +| `await stream.run_forever()` | Runs inline until cancelled or a fatal error stops it | +| `await stream.stop()` | Cancels the background task and closes the stream | +| `await stream.subscribe(topics)` | Adds topic subscriptions after startup | +| `await stream.unsubscribe(topics)` | Removes topic subscriptions | +| `stream.error` | Last fatal exception, or `None` while healthy | + +### Run the stream as a background task To run the stream while doing other work concurrently: ```python -async with client.stream(snapshot.stream_topics()) as stream: - stream.on_space_update(on_space) - # Stream runs in the background — do other work here +stream = client.stream(snapshot.stream_topics()) +stream.on_space_update(on_space) +await stream.start() +try: result = await do_something_else() await asyncio.sleep(3600) -# Stream is stopped when the async with block exits +finally: + await stream.stop() + if stream.error is not None: + print(f"Stream stopped with error: {stream.error}") ``` Use this pattern in integrations (Home Assistant, automation daemons) where the stream is just one part of a larger async application. @@ -115,7 +156,7 @@ async with client.stream(snapshot.stream_topics()) as stream: ## Handle stream errors and reconnect -The stream reconnects automatically with exponential back-off (1 s, 2 s, 4 s, … up to a 60 s cap). Use these options to configure the reconnect budget: +The stream reconnects automatically with exponential back-off (1 s, 2 s, 4 s, … up to a 60 s cap). Use `on_error()` or the `error` property to observe only fatal failures after the reconnect budget is exhausted. Configure the reconnect budget like this: ```python # Unlimited reconnects (default: -1) @@ -132,14 +173,16 @@ stream = client.stream( ) ``` -To observe connection lifecycle events: +To observe fatal stream failures: ```python -stream.on_connected(lambda: print("Stream connected")) -stream.on_disconnected(lambda: print("Stream disconnected; will reconnect")) stream.on_error(lambda e: print(f"Fatal error (budget exhausted): {e}")) + +await stream.run_forever() +if stream.error is not None: + print(f"Last fatal error: {stream.error}") ``` -`on_error` is called only when the reconnect budget is exhausted. Until then, disconnects and errors trigger automatic reconnection without invoking `on_error`. +`on_error()` is called only when the reconnect budget is exhausted. Until then, disconnects and transient errors trigger automatic reconnection without surfacing a fatal error to your callback. For the full reconnect state machine, see [The streaming protocol](../explanation/streaming-protocol.md). diff --git a/docs/how-to/tui-app.md b/docs/how-to/tui-app.md index 3a16995..a0f4b17 100644 --- a/docs/how-to/tui-app.md +++ b/docs/how-to/tui-app.md @@ -64,12 +64,10 @@ class IDUUpdate(Message): self.idu = idu -class StreamConnected(Message): - pass - - -class StreamDisconnected(Message): - pass +class StreamError(Message): + def __init__(self, error: Exception) -> None: + super().__init__() + self.error = error class QuiltApp(App): @@ -99,6 +97,7 @@ class QuiltApp(App): self._snapshot = await self._client.get_snapshot() self._refresh_table() + self.query_one("#status-bar", Label).update("● Streaming") self.run_worker(self._run_stream(), exclusive=True) @@ -110,30 +109,26 @@ class QuiltApp(App): stream = self._client.stream(topics, max_reconnects=-1) stream.on_space_update(self._on_space) stream.on_indoor_unit_update(self._on_idu) - stream.on_connected(lambda: self.post_message(StreamConnected())) - stream.on_disconnected(lambda: self.post_message(StreamDisconnected())) + stream.on_error(lambda exc: self.post_message(StreamError(exc))) async with stream: await asyncio.Event().wait() def _on_space(self, space: Space) -> None: if self._snapshot is not None: - self._snapshot.spaces[space.id] = space + space = self._snapshot.apply_space(space) self.post_message(SpaceUpdate(space)) def _on_idu(self, idu: IndoorUnit) -> None: if self._snapshot is not None: - self._snapshot.indoor_units[idu.id] = idu + idu = self._snapshot.apply_indoor_unit(idu) self.post_message(IDUUpdate(idu)) def on_space_update(self, msg: SpaceUpdate) -> None: self._update_row(msg.space) - def on_stream_connected(self, msg: StreamConnected) -> None: - self.query_one("#status-bar", Label).update("● Connected") - - def on_stream_disconnected(self, msg: StreamDisconnected) -> None: - self.query_one("#status-bar", Label).update("○ Disconnected — reconnecting…") + def on_stream_error(self, msg: StreamError) -> None: + self.query_one("#status-bar", Label).update(f"✕ Stream stopped: {msg.error}") def _refresh_table(self) -> None: if self._snapshot is None: @@ -144,15 +139,15 @@ class QuiltApp(App): def _update_row(self, space: Space) -> None: table = self.query_one("#spaces-table", DataTable) temp = ( - f"{space.state.current_temp_c:.1f}°C" - if space.state.current_temp_c is not None + f"{space.state.ambient_temperature_c:.1f}°C" + if space.state.ambient_temperature_c is not None else "—" ) setpoints = ( - f"{space.controls.heat_setpoint_c:.0f} / " - f"{space.controls.cool_setpoint_c:.0f}°C" + f"{space.controls.heating_setpoint_c:.0f} / " + f"{space.controls.cooling_setpoint_c:.0f}°C" ) - row = (space.name, space.controls.mode.value, temp, setpoints) + row = (space.name, space.controls.hvac_mode.value, temp, setpoints) key = f"space-{space.id}" if key in table.rows: @@ -203,8 +198,10 @@ async def action_set_mode(self) -> None: row_key = table.cursor_row_key if row_key is None: return + if self._snapshot is None: + return space_id = row_key.removeprefix("space-") - space = self._snapshot.spaces.get(space_id) + space = next((s for s in self._snapshot.spaces if s.id == space_id), None) if space is None: return diff --git a/docs/reference/client.md b/docs/reference/client.md index d486aad..7598c8c 100644 --- a/docs/reference/client.md +++ b/docs/reference/client.md @@ -112,7 +112,7 @@ async def __aenter__(self) -> QuiltClient: ... async def __aexit__(self, *_: object) -> None: ... ``` -`__aexit__` closes the gRPC channel. Always use `QuiltClient` as an async context manager. +`__aexit__` calls `close()` to close the gRPC channel. Prefer the async context manager, or call `await close()` yourself when managing lifecycle manually. --- @@ -143,6 +143,16 @@ async def refresh_token(self, context: TokenRefreshContext | None = None) -> Non Silently refreshes the auth token using the refresh token. Does not attempt OTP. Called automatically by the transport interceptor on `UNAUTHENTICATED`; rarely needed directly. +### `get_current_token` + +```python +def get_current_token(self) -> str +``` + +Returns the current JWT access token held by the client. + +**Raises:** `QuiltAuthError` if the client is not authenticated yet. + --- ## System discovery @@ -155,7 +165,7 @@ async def list_systems(self) -> list[SystemInfo] Lists all Quilt systems the authenticated user has access to. -**Returns:** List of `SystemInfo` objects with `id`, `name`, `timezone`, `location_id`. +**Returns:** List of `SystemInfo` objects with `id`, `name`, and `timezone`. **Raises:** `QuiltError` if the gRPC call fails. @@ -205,6 +215,14 @@ def invalidate_snapshot(self) -> None Discards the cached snapshot. The next `get_snapshot()` call fetches fresh data from the server. +### `close` + +```python +async def close(self) -> None +``` + +Closes the underlying gRPC channel and clears the client's open channel reference. Safe to call multiple times. + --- ## Space control diff --git a/docs/reference/models.md b/docs/reference/models.md index 4d70df9..e441774 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -67,10 +67,16 @@ service = UserService(channel) ```python from quilt_hp.services.streaming import NotifierStream -stream = NotifierStream( - channel=channel, - topics=topics, - token_provider=client, +# metadata_provider returns gRPC call metadata (e.g. auth headers). +# Obtain a token from your QuiltClient or token store. +def get_metadata() -> list[tuple[str, str]]: + return [("authorization", f"Bearer {token}")] + +stream = NotifierStream.create( + channel, + topics, + metadata_provider=get_metadata, + authenticate=client.refresh_token, max_reconnects=-1, reconnect_delay_s=1.0, ) @@ -78,14 +84,29 @@ stream = NotifierStream( See [Streaming protocol behavior](../explanation/streaming-protocol.md) for the full state machine, event types, and reconnect behavior. -Event registration methods: +Callback registration methods: ```python -stream.on_space_update(callback) # Callable[[Space], Awaitable | None] -stream.on_indoor_unit_update(callback) # Callable[[IndoorUnit], Awaitable | None] -stream.on_comfort_setting_update(callback) -stream.on_connected(callback) # no args -stream.on_disconnected(callback) # no args +stream.on_space_update(callback) +stream.on_indoor_unit_update(callback) +stream.on_outdoor_unit_update(callback) +stream.on_controller_update(callback) +stream.on_qsm_update(callback) +stream.on_remote_sensor_update(callback) +stream.on_controller_remote_sensor_update(callback) +stream.on_software_update_info(callback) +stream.on_error(callback) +``` + +Lifecycle methods: + +```python +await stream.start() +await stream.run_forever() +await stream.subscribe(["hds/space/"]) +await stream.unsubscribe(["hds/space/"]) +await stream.stop() +stream.error ``` --- @@ -99,31 +120,45 @@ All models are `dataclass` instances populated from proto fields by `from_proto( ```python @dataclass class SystemSnapshot: - system_id: str - location_id: str - spaces: dict[str, Space] - indoor_units: dict[str, IndoorUnit] - outdoor_units: dict[str, OutdoorUnit] - controllers: dict[str, Controller] - sensors: dict[str, RemoteSensor] - comfort_settings: dict[str, ComfortSetting] - schedule_days: dict[str, ScheduleDay] - schedule_weeks: dict[str, ScheduleWeek] - schedule_paused: bool - fetched_at: float # time.time() when snapshot was constructed + spaces: list[Space] + indoor_units: list[IndoorUnit] + outdoor_units: list[OutdoorUnit] + controllers: list[Controller] + quilt_smart_modules: list[QuiltSmartModule] + comfort_settings: list[ComfortSetting] + schedule_weeks: list[ScheduleWeek] + schedule_days: list[ScheduleDay] + remote_sensors: list[RemoteSensor] + controller_remote_sensors: list[ControllerRemoteSensor] + software_update_infos: list[SoftwareUpdateInfo] + locations: list[Location] + timezone: str | None ``` -`SystemSnapshot` is the root object returned by `get_snapshot()`. All child objects are indexed by their string ID for O(1) lookup. +`SystemSnapshot` is the root object returned by `get_snapshot()`. Child collections are stored as lists, not dicts. Look up objects by iterating, with helpers like `space_by_name()`, or by merging stream diffs in place with the `apply_*()` methods. Useful helper properties and methods: ```python -snapshot.rooms # → list[Space] leaf spaces only (has parent_space_id) -snapshot.floors # → list[Space] parent spaces only -snapshot.stream_topics() # → list[str] all topics for use with stream() +snapshot.rooms # → list[Space] leaf spaces only +snapshot.primary_location # → Location | None +snapshot.space_by_name("Bedroom") # → Space | None +snapshot.comfort_settings_for_space(space) +snapshot.away_comfort_setting(space) +snapshot.stream_topics() # → list[str] ``` -The `apply_*` methods (`apply_space_update`, `apply_indoor_unit_update`, etc.) are called by `NotifierStream` to merge sparse proto3 diffs into the snapshot in-place. You rarely call these directly. +The merge helpers update the matching list entry or append a new object when needed: + +```python +snapshot.apply_space(space) +snapshot.apply_indoor_unit(idu) +snapshot.apply_outdoor_unit(odu) +snapshot.apply_controller(controller) +snapshot.apply_qsm(qsm) +snapshot.apply_remote_sensor(sensor) +snapshot.apply_controller_remote_sensor(sensor) +``` --- @@ -133,11 +168,11 @@ The `apply_*` methods (`apply_space_update`, `apply_indoor_unit_update`, etc.) a @dataclass class Space: id: str + system_id: str name: str parent_space_id: str | None - location_id: str - controls: SpaceControls settings: SpaceSettings + controls: SpaceControls state: SpaceState ``` @@ -148,37 +183,45 @@ A single room or floor zone. `parent_space_id is None` for floor-level spaces; l ```python @dataclass class SpaceControls: - mode: HVACMode - heat_setpoint_c: float - cool_setpoint_c: float - comfort_setting_id: str | None + hvac_mode: HVACMode + temperature_setpoint_c: float + cooling_setpoint_c: float + heating_setpoint_c: float + comfort_setting_id: str + comfort_setting_override: ComfortSettingOverride + boost_mode: BoostMode ``` -The writable HVAC setpoint state. `comfort_setting_id` is `None` when the space is in manual control mode. Setting `mode=STANDBY` clears `comfort_setting_id`. +The writable HVAC control state. `comfort_setting_id` uses an empty-string sentinel when the space is in manual control mode. Setting `hvac_mode=STANDBY` clears the linked comfort setting. #### `SpaceSettings` ```python @dataclass class SpaceSettings: - unoccupied_timeout_s: float + name: str + timezone: str + occupancy_mode: OccupancyMode occupied_timeout_s: float - schedules_paused: bool + unoccupied_timeout_s: float + safety_heating: SafetyHeatingMode + hvac_controller_type: HvacControllerType ``` -Automation configuration for the space. +Automation and safety configuration for the space. #### `SpaceState` ```python @dataclass class SpaceState: - current_temp_c: float | None - occupancy: OccupancyState # OCCUPIED, UNOCCUPIED, UNKNOWN - last_occupied_at: datetime | None + ambient_temperature_c: float | None + hvac_state: HVACState + setpoint_c: float | None + comfort_setting_id: str ``` -Read-only live state derived from sensor telemetry. +Read-only live state derived from sensor telemetry and current control state. --- @@ -319,12 +362,16 @@ class RemoteSensorState: @dataclass class ComfortSetting: id: str - location_id: str + system_id: str + space_id: str name: str + type: ComfortSettingType hvac_mode: HVACMode - heat_setpoint_c: float - cool_setpoint_c: float + heating_setpoint_c: float + cooling_setpoint_c: float fan_speed: FanSpeed + louver_mode: LouverMode + louver_fixed_position: float ``` A named HVAC preset. Spaces reference comfort settings by `controls.comfort_setting_id`. @@ -347,8 +394,12 @@ class ScheduleDay: ```python @dataclass class ScheduleEvent: - time_of_day_s: int # seconds from midnight + start_s: int # seconds from midnight comfort_setting_id: str + hvac_mode: HVACMode + heating_setpoint_c: float + cooling_setpoint_c: float + precondition: bool ``` --- @@ -368,8 +419,8 @@ class ScheduleWeek: ```python @dataclass class ScheduleWeekDay: - day_of_week: int # 0 = Monday, 6 = Sunday - schedule_day_id: str | None + weekday: int # 1 = Monday, 7 = Sunday + day_id: str ``` --- @@ -382,7 +433,6 @@ class SystemInfo: id: str name: str timezone: str - location_id: str ``` Returned by `list_systems()`. @@ -396,6 +446,7 @@ Returned by `list_systems()`. class Location: id: str name: str + system_id: str timezone: str schedule_paused: bool ``` @@ -404,16 +455,93 @@ Location metadata embedded in `SystemSnapshot`. --- -## Enum types +### `ControllerRemoteSensor` + +```python +@dataclass +class ControllerRemoteSensor: + id: str + controller_id: str + mac: str | None + ambient_temperature_c: float | None + humidity_percent: float | None + battery_level_percent: float | None + signal_level_dbm: int | None + control_mode: RemoteSensorControlMode +``` -All enums live in `quilt_hp.models.enums`. They are `StrEnum` values (except `LouverMode`, `LedAnimation`, and `LightPreset` which may be `IntEnum`). +Temperature, humidity, battery, and signal data exposed by a controller when its remote-sensor mode is enabled. -| Enum | Values | -|------|--------| -| `HVACMode` | `STANDBY`, `COOL`, `HEAT`, `AUTO`, `FAN` | -| `FanSpeed` | `AUTO`, `QUIET`, `LOW`, `MEDIUM`, `HIGH`, `BLAST` | -| `LouverMode` | `CLOSED`, `SWEEP`, `FIXED`, `AUTO` | -| `OccupancyState` | `OCCUPIED`, `UNOCCUPIED`, `UNKNOWN` | -| `DeclaredUserType` | `HOMEOWNER`, `PARTNER` | +--- + +### `EnergyBucket` + +```python +@dataclass +class EnergyBucket: + start_time: datetime + energy_kwh: float + status: MetricBucketStatus +``` + +One hourly energy measurement bucket. Use `has_missing_energy_value` or `energy_kwh_or_none` to handle NaN sentinel values safely (a `None` or non-float `energy_kwh` is also treated as missing). + +--- + +### `SpaceEnergyMetrics` + +```python +@dataclass +class SpaceEnergyMetrics: + space_id: str + buckets: list[EnergyBucket] +``` + +Hourly energy history for one space. Convenience properties include `total_kwh` and `missing_bucket_count`. + +--- + +### `SoftwareUpdateInfo` + +```python +@dataclass +class SoftwareUpdateInfo: + id: str + state: int + status: int + current_version: str + target_version: str + current_progress: float + total_progress: float + progress_unit: int +``` + +Firmware/software update record associated with an indoor unit, outdoor unit, controller, or QSM. + +--- + +## Enum types -`FanSpeed.to_wire()` maps to `(fan_speed_mode, fan_speed_percent)` pairs consumed by the HDS proto. This mapping is handled inside `HomeDatastoreService`; client code works with `FanSpeed` values only. +All enums live in `quilt_hp.models.enums` and subclass `IntEnum`, mirroring Quilt's wire values. + +| Enum | Purpose | Representative values | +|------|---------|-----------------------| +| `HVACMode` | Requested operating mode | `STANDBY`, `COOL`, `HEAT`, `AUTO`, `FAN`, `FALLBACK_AUTO`, `FALLBACK_OFF` | +| `HVACState` | Actual running state | `STANDBY`, `COOL`, `HEAT`, `DRIFT`, `FAN` | +| `FanSpeed` | Indoor-unit fan speed preset | `AUTO`, `QUIET`, `LOW`, `MEDIUM`, `HIGH`, `BLAST` | +| `LouverMode` | Indoor-unit louver behavior | `CLOSED`, `SWEEP`, `FIXED`, `AUTO` | +| `LouverAngle` | Fixed louver angle preset | `ANGLE1`–`ANGLE5` | +| `LightPreset` | Built-in LED color presets | `DAYLIGHT`, `WARM`, `SUNSET`, `SKY` | +| `LedAnimation` | Indoor-unit LED animation mode | `NONE`, `SPARKLE_FADE`, `TWINKLE_FADE`, `DANCE`, `CHASE` | +| `ComfortSettingType` | Named preset kind | `ACTIVE`, `SLEEP`, `AWAY`, `STANDBY`, `CUSTOM` | +| `ComfortSettingOverride` | Why the active preset differs from schedule | `NONE`, `UNTIL_NEXT_SCHEDULE`, `INDEFINITE`, `UNOCCUPIED`, `OCCUPIED` | +| `BoostMode` | Space turbo override | `OFF`, `ON` | +| `OccupancyMode` | Space auto-away/return setting | `DISABLED`, `ENABLED` | +| `OccupancyState` | Presence/occupancy detection result | `UNDETECTED`, `DETECTED` | +| `SafetyHeatingMode` | Freeze-protection setting | `DISABLED`, `ENABLED` | +| `ConditionState` | Diagnostic condition status | `INACTIVE`, `ACTIVE` | +| `HvacControllerType` | Controller algorithm variant | `PASS_THROUGH_TEMPERATURE`, `INTEGRAL_TEMPERATURE_V1`, `INTEGRAL_TEMPERATURE_V2` | +| `FallbackControlCommand` | Offline fallback command sent to an IDU | `COMPLETE`, `EXIT` | +| `RemoteSensorControlMode` | Whether a remote sensor participates in control | `DISABLED`, `ENABLED` | + +`FanSpeed.to_wire()` and `FanSpeed.from_wire()` handle the Quilt protocol's `(fan_speed_mode, fan_speed_percent)` encoding. `LouverAngle.to_wire()` and `LouverAngle.from_wire()` do the same for fixed louver positions. diff --git a/src/quilt_hp/__init__.py b/src/quilt_hp/__init__.py index 0e7c50b..c1d4892 100644 --- a/src/quilt_hp/__init__.py +++ b/src/quilt_hp/__init__.py @@ -7,6 +7,7 @@ QuiltConnectionError, QuiltError, QuiltNotFoundError, + QuiltStreamError, ) __version__ = "0.2.2" @@ -18,5 +19,6 @@ "QuiltConnectionError", "QuiltError", "QuiltNotFoundError", + "QuiltStreamError", "__version__", ] diff --git a/src/quilt_hp/auth.py b/src/quilt_hp/auth.py index d44a570..76f6317 100644 --- a/src/quilt_hp/auth.py +++ b/src/quilt_hp/auth.py @@ -8,6 +8,7 @@ import asyncio import inspect +import logging import time from collections.abc import Awaitable, Callable from functools import partial @@ -33,6 +34,8 @@ type OtpCallback = Callable[[str], str | Awaitable[str]] type CognitoAuthResult = dict[str, str | int] +logger = logging.getLogger(__name__) + class _CognitoClient(Protocol): def initiate_auth(self, **kwargs: object) -> dict[str, object]: ... @@ -55,8 +58,11 @@ def _require_str(result: CognitoAuthResult, key: str) -> str: def _expires_in_s(result: CognitoAuthResult) -> int: - value = result.get("ExpiresIn", 3600) - return value if isinstance(value, int) else 3600 + value = result.get("ExpiresIn") + if isinstance(value, int): + return value + logger.warning("Authentication response missing valid ExpiresIn; using default") + return 3600 def _make_cognito_client() -> _CognitoClient: @@ -191,10 +197,12 @@ async def authenticate( # 1. Valid cached IdToken if cached is not None and not cached.is_expired: + logger.debug("Using cached token") return cached.id_token # 2. Refresh token if cached is not None and cached.refresh_token: + logger.debug("Starting token refresh") context = refresh_context or TokenRefreshContext( reason=TokenRefreshReason.EXPIRED_CACHED_TOKEN, source="authenticate", @@ -203,7 +211,7 @@ async def authenticate( await refresh_hooks.on_refresh_start(context) try: result = await _do_refresh(cached.refresh_token) - except Exception as exc: + except (QuiltAuthError, ClientError) as exc: if refresh_hooks is not None: await refresh_hooks.on_refresh_failure(context, exc) action = ( @@ -213,6 +221,7 @@ async def authenticate( ) if action == RefreshFailureAction.RAISE or otp_callback is None: raise + logger.warning("Refresh failed; falling back to OTP") else: tokens = CachedTokens( id_token=_require_str(result, "IdToken"), @@ -223,6 +232,7 @@ async def authenticate( await _save_tokens(token_store, email, tokens) if refresh_hooks is not None: await refresh_hooks.on_refresh_success(context, tokens) + logger.info("Token refresh succeeded") return tokens.id_token # 3. Full OTP login @@ -241,4 +251,5 @@ async def authenticate( ) if token_store: await _save_tokens(token_store, email, tokens) + logger.info("OTP login succeeded") return tokens.id_token diff --git a/src/quilt_hp/cli/main.py b/src/quilt_hp/cli/main.py index c3ed00b..5295cc6 100644 --- a/src/quilt_hp/cli/main.py +++ b/src/quilt_hp/cli/main.py @@ -5,8 +5,10 @@ import asyncio import json import sys -from collections.abc import Coroutine, Sequence +from collections.abc import AsyncIterator, Callable, Coroutine, Sequence +from contextlib import asynccontextmanager from enum import StrEnum +from functools import wraps from typing import Any, Protocol, cast try: @@ -20,6 +22,7 @@ from quilt_hp.cli.settings import SettingsStore from quilt_hp.cli.store import FileStore from quilt_hp.client import QuiltClient +from quilt_hp.exceptions import QuiltAuthError, QuiltError from quilt_hp.models.enums import HVACMode from quilt_hp.models.system import SystemSnapshot @@ -59,8 +62,44 @@ def _app_callback( _ = version +def _handle_errors[T]( + func: Callable[..., Coroutine[Any, Any, T]], +) -> Callable[..., Coroutine[Any, Any, T]]: + @wraps(func) + async def _wrapped(*args: Any, **kwargs: Any) -> T: + try: + return await func(*args, **kwargs) + except QuiltAuthError as exc: + console.print(f"[red]Authentication failed: {exc}[/red]") + raise typer.Exit(1) from None + except QuiltError as exc: + console.print(f"[red]Error: {exc}[/red]") + raise typer.Exit(1) from None + + return _wrapped + + def _run[T](coro: Coroutine[Any, Any, T]) -> T: - return asyncio.run(coro) + @_handle_errors + async def _wrapped() -> T: + return await coro + + return asyncio.run(_wrapped()) + + +@asynccontextmanager +async def _logged_in_client(email: str, home: str | None) -> AsyncIterator[QuiltClient]: + async with QuiltClient(email, home=home, token_store=_store) as client: + await client.login() + yield client + + +@asynccontextmanager +async def _client_snapshot( + email: str, home: str | None +) -> AsyncIterator[tuple[QuiltClient, SystemSnapshot]]: + async with _logged_in_client(email, home) as client: + yield client, await client.get_snapshot() def _resolve(email: str | None, home: str | None) -> tuple[str, str | None]: @@ -364,8 +403,8 @@ async def _login() -> None: await client.login() console.print(f"[green]✓ Already logged in as {email}[/green]") return - except Exception: - pass + except QuiltAuthError: + pass # expected — cached tokens absent/expired, proceed to OTP # Cached tokens absent/expired — trigger OTP flow and prompt. async def _prompt_for_otp(challenge_email: str) -> str: @@ -395,9 +434,7 @@ def info( email, home = _resolve(email, home) async def _info() -> None: - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() - snap = await client.get_snapshot() + async with _client_snapshot(email, home) as (_, snap): _emit_output(output, _snapshot_payload(snap)) _run(_info()) @@ -418,9 +455,8 @@ def devices( email, home = _resolve(email, home) async def _devices() -> None: - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() - payload = _snapshot_payload(await client.get_snapshot()) + async with _client_snapshot(email, home) as (_, snapshot): + payload = _snapshot_payload(snapshot) device_payload = { "spaces": [{"id": s["id"], "name": s["name"]} for s in payload["spaces"]], "indoor_units": [ @@ -477,9 +513,8 @@ def values( email, home = _resolve(email, home) async def _values() -> None: - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() - payload = _snapshot_payload(await client.get_snapshot()) + async with _client_snapshot(email, home) as (_, snapshot): + payload = _snapshot_payload(snapshot) value_payload = { "spaces": [ { @@ -599,8 +634,7 @@ def presets( email, home = _resolve(email, home) async def _presets() -> None: - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() + async with _logged_in_client(email, home) as client: settings = await client.list_comfort_settings() if not settings: console.print("No comfort settings found.") @@ -628,10 +662,7 @@ def schedules( email, home = _resolve(email, home) async def _schedules() -> None: - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() - snapshot = await client.get_snapshot() - + async with _client_snapshot(email, home) as (_, snapshot): cs_by_id = {cs.id: cs for cs in snapshot.comfort_settings} day_by_id = {d.id: d for d in snapshot.schedule_days} @@ -675,9 +706,7 @@ async def _energy() -> None: import zoneinfo from datetime import datetime, timedelta - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() - snapshot = await client.get_snapshot() + async with _client_snapshot(email, home) as (client, snapshot): name_by_id = {s.id: s.name for s in snapshot.spaces} now = datetime.now(tz=zoneinfo.ZoneInfo(snapshot.timezone or "UTC")) @@ -721,10 +750,7 @@ def set_space( email, home = _resolve(email, home) async def _set() -> None: - async with QuiltClient(email, home=home, token_store=_store) as client: - await client.login() - snap = await client.get_snapshot() - + async with _client_snapshot(email, home) as (client, snap): space = next( (s for s in snap.rooms if s.name.lower() == space_name.lower()), None, @@ -733,7 +759,15 @@ async def _set() -> None: console.print(f"[red]Room {space_name!r} not found.[/red]") raise typer.Exit(1) - hvac_mode = HVACMode[mode.upper()] if mode else None + if mode: + try: + hvac_mode: HVACMode | None = HVACMode[mode.upper()] + except KeyError: + valid = ", ".join(m.name.lower() for m in HVACMode if m.value) + console.print(f"[red]Invalid mode {mode!r}. Valid: {valid}[/red]") + raise typer.Exit(1) from None + else: + hvac_mode = None await client.set_space( space.id, diff --git a/src/quilt_hp/cli/settings.py b/src/quilt_hp/cli/settings.py index 26910d9..cfb5658 100644 --- a/src/quilt_hp/cli/settings.py +++ b/src/quilt_hp/cli/settings.py @@ -98,10 +98,11 @@ def _coerce(self, payload: dict[str, object]) -> Settings: email = payload.get("email") home = payload.get("home") dark = payload.get("dark") + uf = payload.get("use_fahrenheit", False) return Settings( email=email if isinstance(email, str) else None, home=home if isinstance(home, str) else None, - use_fahrenheit=bool(payload.get("use_fahrenheit", False)), + use_fahrenheit=uf if isinstance(uf, bool) else False, dark=dark if isinstance(dark, bool) else None, ) diff --git a/src/quilt_hp/cli/store.py b/src/quilt_hp/cli/store.py index e553957..622f564 100644 --- a/src/quilt_hp/cli/store.py +++ b/src/quilt_hp/cli/store.py @@ -9,14 +9,23 @@ import asyncio import json +import logging import os from dataclasses import asdict from pathlib import Path +from uuid import uuid4 from quilt_hp._paths import app_config_dir from quilt_hp.exceptions import QuiltAuthError from quilt_hp.tokens import CachedTokens +logger = logging.getLogger(__name__) + + +def _warn_if_permission_error(action: str, path: Path, exc: OSError) -> None: + if isinstance(exc, PermissionError): + logger.warning("Permission denied while %s token file %s", action, path) + class FileStore: """Filesystem-backed token persistence.""" @@ -26,18 +35,40 @@ class FileStore: def _token_path(self) -> Path: return app_config_dir() / "tokens.json" + def _atomic_write(self, payload: dict[str, object]) -> None: + path = self._token_path() + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_name(f"{path.name}.{os.getpid()}.{uuid4().hex}.tmp") + try: + # Open with O_CREAT|O_WRONLY|O_TRUNC and mode 0o600 so the file + # is never world-readable, even transiently before chmod. + fd = os.open(tmp, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o600) + with os.fdopen(fd, "w") as f: + f.write(json.dumps(payload, indent=2)) + os.replace(tmp, path) + os.chmod(path, 0o600) + except OSError: + try: + tmp.unlink(missing_ok=True) + except OSError: + pass + raise + async def load(self, email: str) -> CachedTokens | None: """TokenStore.load — return cached tokens for *email* or None.""" return await asyncio.to_thread(self._load_sync, email) def _load_sync(self, email: str) -> CachedTokens | None: + path = self._token_path() + logger.debug("Loading token file %s", path) try: - data = json.loads(self._token_path().read_text()) + data = json.loads(path.read_text()) except FileNotFoundError: return None except json.JSONDecodeError as exc: raise QuiltAuthError("Token store contains invalid JSON.") from exc except OSError as exc: + _warn_if_permission_error("reading", path, exc) raise QuiltAuthError("Failed to read token store.") from exc try: @@ -58,6 +89,7 @@ async def save(self, email: str, tokens: CachedTokens) -> None: def _save_sync(self, email: str, tokens: CachedTokens) -> None: path = self._token_path() + logger.debug("Saving token file %s", path) try: data = json.loads(path.read_text()) except FileNotFoundError: @@ -65,17 +97,19 @@ def _save_sync(self, email: str, tokens: CachedTokens) -> None: except json.JSONDecodeError as exc: raise QuiltAuthError("Token store contains invalid JSON.") from exc except OSError as exc: + _warn_if_permission_error("reading", path, exc) raise QuiltAuthError("Failed to read token store.") from exc data[email] = asdict(tokens) try: - path.write_text(json.dumps(data, indent=2)) - os.chmod(path, 0o600) + self._atomic_write(data) except OSError as exc: + _warn_if_permission_error("writing", path, exc) raise QuiltAuthError("Failed to persist token store.") from exc def clear_tokens(self, email: str) -> None: """Remove cached tokens for *email*.""" path = self._token_path() + logger.debug("Loading token file %s", path) try: data = json.loads(path.read_text()) except FileNotFoundError: @@ -83,23 +117,28 @@ def clear_tokens(self, email: str) -> None: except json.JSONDecodeError as exc: raise QuiltAuthError("Token store contains invalid JSON.") from exc except OSError as exc: + _warn_if_permission_error("reading", path, exc) raise QuiltAuthError("Failed to read token store.") from exc data.pop(email, None) + logger.debug("Saving token file %s", path) try: - path.write_text(json.dumps(data, indent=2)) - os.chmod(path, 0o600) + self._atomic_write(data) except OSError as exc: + _warn_if_permission_error("writing", path, exc) raise QuiltAuthError("Failed to persist token store.") from exc def list_emails(self) -> list[str]: """All email addresses that have cached tokens.""" + path = self._token_path() + logger.debug("Loading token file %s", path) try: - data = json.loads(self._token_path().read_text()) + data = json.loads(path.read_text()) except FileNotFoundError: return [] except json.JSONDecodeError as exc: raise QuiltAuthError("Token store contains invalid JSON.") from exc except OSError as exc: + _warn_if_permission_error("reading", path, exc) raise QuiltAuthError("Failed to read token store.") from exc return [k for k in data if isinstance(k, str)] diff --git a/src/quilt_hp/client.py b/src/quilt_hp/client.py index 638ca28..f8693dd 100644 --- a/src/quilt_hp/client.py +++ b/src/quilt_hp/client.py @@ -13,8 +13,10 @@ from __future__ import annotations +import logging import time -from typing import TYPE_CHECKING, Self +from collections.abc import Callable +from typing import TYPE_CHECKING, Protocol, Self, TypeVar from quilt_hp.auth import OtpCallback, authenticate from quilt_hp.const import Environment @@ -32,6 +34,8 @@ ) from quilt_hp.transport import auth_metadata, create_channel +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from datetime import datetime @@ -46,6 +50,13 @@ from quilt_hp.models.system import SystemInfo, SystemSnapshot +class _HasId(Protocol): + id: str + + +TResolved = TypeVar("TResolved", bound=_HasId) + + class QuiltClient: """Async client for the Quilt HVAC cloud API. @@ -102,6 +113,7 @@ def get_current_token(self) -> str: def _ensure_channel(self) -> grpc.aio.Channel: if self._channel is None: + logger.debug("Creating client channel for %s", self._environment.value) self._channel = create_channel( self, self._environment, @@ -112,6 +124,45 @@ def _ensure_channel(self) -> grpc.aio.Channel: self._user_svc = UserService(self._channel) return self._channel + def _require_channel(self) -> grpc.aio.Channel: + if self._channel is None: + raise QuiltError("Client not connected. Call login() first.") + return self._channel + + def _require_hds(self) -> HomeDatastoreService: + if self._hds is None: + raise QuiltError("Client not connected. Call login() first.") + return self._hds + + def _require_sysinfo(self) -> SystemInformationService: + if self._sysinfo is None: + raise QuiltError("Client not connected. Call login() first.") + return self._sysinfo + + def _require_user_service(self) -> UserService: + if self._user_svc is None: + raise QuiltError("Client not connected. Call login() first.") + return self._user_svc + + async def _resolve_system_id(self, system_id: str | None = None) -> str: + return system_id or await self.get_system_id() + + async def _resolve_snapshot_item( + self, + item: TResolved | str, + *, + items: Callable[[SystemSnapshot], list[TResolved]], + kind: str, + ) -> TResolved: + if not isinstance(item, str): + return item + + snapshot = await self.get_snapshot() + for candidate in items(snapshot): + if candidate.id == item: + return candidate + raise QuiltError(f"{kind} {item!r} not found") + # --- Auth --- async def login(self, otp_callback: OtpCallback | None = None) -> None: @@ -131,7 +182,13 @@ async def login(self, otp_callback: OtpCallback | None = None) -> None: refresh_hooks=self._token_refresh_hooks, refresh_policy=self._token_refresh_policy, ) + # Clear cached state so stale data from a prior session is never returned. + self._system_id = None + self._system_name = None + self._snapshot_cache = None + self._snapshot_cached_at = 0.0 self._ensure_channel() + logger.info("Login succeeded") async def refresh_token(self, context: TokenRefreshContext | None = None) -> None: """Refresh the auth token without OTP when refresh token is valid.""" @@ -156,16 +213,16 @@ def system_name(self) -> str | None: async def list_systems(self) -> list[SystemInfo]: """List all systems the user has access to.""" - self._ensure_channel() - assert self._sysinfo is not None - return await self._sysinfo.list_systems() + return await self._require_sysinfo().list_systems() async def get_system_id(self, home: str | None = None) -> str: """Get primary system ID, cached after first call unless home changes.""" target_home = home or self._home + logger.debug("Resolving system for home filter %r", target_home) if self._system_id is not None: # Bypass the cache only when a different home is requested. if not home or home == self._home: + logger.debug("Using cached system id %s", self._system_id) return self._system_id systems = await self.list_systems() @@ -184,6 +241,7 @@ async def get_system_id(self, home: str | None = None) -> str: self._system_id = systems[0].id self._system_name = systems[0].name + logger.info("Selected system %s (%s)", self._system_name, self._system_id) return self._system_id async def get_snapshot(self, system_id: str | None = None) -> SystemSnapshot: @@ -194,17 +252,18 @@ async def get_snapshot(self, system_id: str | None = None) -> SystemSnapshot: Pass ``system_id`` to query a specific system (bypasses and does not populate the cache for the default system). """ - self._ensure_channel() - assert self._hds is not None - sid = system_id or await self.get_system_id() + hds = self._require_hds() + sid = await self._resolve_system_id(system_id) # Only use cache for the default (unspecified) system_id if system_id is None and self._snapshot_ttl_s > 0: age = time.monotonic() - self._snapshot_cached_at if self._snapshot_cache is not None and age < self._snapshot_ttl_s: + logger.debug("Snapshot cache hit for system %s", sid) return self._snapshot_cache + logger.debug("Snapshot cache miss for system %s", sid) - snapshot = await self._hds.get_system(sid) + snapshot = await hds.get_system(sid) if system_id is None and self._snapshot_ttl_s > 0: self._snapshot_cache = snapshot @@ -214,6 +273,7 @@ async def get_snapshot(self, system_id: str | None = None) -> SystemSnapshot: def invalidate_snapshot(self) -> None: """Discard the cached snapshot so the next call fetches fresh data.""" + logger.warning("Invalidating snapshot cache") self._snapshot_cache = None self._snapshot_cached_at = 0.0 @@ -238,15 +298,11 @@ async def set_space( space: A ``Space`` object (no snapshot lookup needed) **or** a space ID string (snapshot is fetched to resolve the object). """ - self._ensure_channel() - assert self._hds is not None - if isinstance(space, str): - snapshot = await self.get_snapshot() - resolved = next((s for s in snapshot.spaces if s.id == space), None) - if resolved is None: - raise QuiltError(f"Space {space!r} not found") - space = resolved - return await self._hds.update_space( + hds = self._require_hds() + space = await self._resolve_snapshot_item( + space, items=lambda snapshot: snapshot.spaces, kind="Space" + ) + return await hds.update_space( space, mode=mode, heat_setpoint_c=heat_setpoint_c, @@ -267,15 +323,11 @@ async def set_space_settings( unoccupied_timeout_s: Seconds of no-presence before auto-away. occupied_timeout_s: Seconds of presence before auto-return. """ - self._ensure_channel() - assert self._hds is not None - if isinstance(space, str): - snapshot = await self.get_snapshot() - resolved = next((s for s in snapshot.spaces if s.id == space), None) - if resolved is None: - raise QuiltError(f"Space {space!r} not found") - space = resolved - return await self._hds.update_space_settings( + hds = self._require_hds() + space = await self._resolve_snapshot_item( + space, items=lambda snapshot: snapshot.spaces, kind="Space" + ) + return await hds.update_space_settings( space, unoccupied_timeout_s=unoccupied_timeout_s, occupied_timeout_s=occupied_timeout_s, @@ -305,15 +357,13 @@ async def set_indoor_unit( idu: An ``IndoorUnit`` object (no snapshot lookup needed) **or** an IDU ID string (snapshot is fetched to resolve the object). """ - self._ensure_channel() - assert self._hds is not None - if isinstance(idu, str): - snapshot = await self.get_snapshot() - resolved = next((u for u in snapshot.indoor_units if u.id == idu), None) - if resolved is None: - raise QuiltError(f"Indoor unit {idu!r} not found") - idu = resolved - return await self._hds.update_indoor_unit( + hds = self._require_hds() + idu = await self._resolve_snapshot_item( + idu, + items=lambda snapshot: snapshot.indoor_units, + kind="Indoor unit", + ) + return await hds.update_indoor_unit( idu, fan_speed=fan_speed, louver_mode=louver_mode, @@ -346,15 +396,13 @@ async def set_indoor_unit_settings( All parameters are optional; omitted fields keep their current value. Set a fence value to 0.0 to clear it (returns to max-range detection). """ - self._ensure_channel() - assert self._hds is not None - if isinstance(idu, str): - snapshot = await self.get_snapshot() - resolved = next((u for u in snapshot.indoor_units if u.id == idu), None) - if resolved is None: - raise QuiltError(f"Indoor unit {idu!r} not found") - idu = resolved - return await self._hds.update_indoor_unit_settings( + hds = self._require_hds() + idu = await self._resolve_snapshot_item( + idu, + items=lambda snapshot: snapshot.indoor_units, + kind="Indoor unit", + ) + return await hds.update_indoor_unit_settings( idu, fence_left_m=fence_left_m, fence_right_m=fence_right_m, @@ -384,15 +432,13 @@ async def update_comfort_setting( setting: A ``ComfortSetting`` object (no snapshot lookup needed) **or** a setting ID string (snapshot resolves the object). """ - self._ensure_channel() - assert self._hds is not None - if isinstance(setting, str): - snapshot = await self.get_snapshot() - resolved = next((s for s in snapshot.comfort_settings if s.id == setting), None) - if resolved is None: - raise QuiltError(f"Comfort setting {setting!r} not found") - setting = resolved - return await self._hds.update_comfort_setting( + hds = self._require_hds() + setting = await self._resolve_snapshot_item( + setting, + items=lambda snapshot: snapshot.comfort_settings, + kind="Comfort setting", + ) + return await hds.update_comfort_setting( setting, name=name, hvac_mode=hvac_mode, @@ -410,10 +456,9 @@ async def create_schedule_day( events: list[ScheduleEvent], ) -> ScheduleDay: """Create a new schedule day program from domain schedule events.""" - self._ensure_channel() - assert self._hds is not None - system_id = await self.get_system_id() - return await self._hds.create_schedule_day( + hds = self._require_hds() + system_id = await self._resolve_system_id() + return await hds.create_schedule_day( system_id=system_id, space_id=space_id, name=name, @@ -426,10 +471,9 @@ async def create_schedule_week( days: list[ScheduleWeekDay] | None = None, ) -> ScheduleWeek: """Create a new schedule week from domain weekday mappings.""" - self._ensure_channel() - assert self._hds is not None - system_id = await self.get_system_id() - return await self._hds.create_schedule_week( + hds = self._require_hds() + system_id = await self._resolve_system_id() + return await hds.create_schedule_week( system_id=system_id, space_id=space_id, days=days, @@ -442,10 +486,9 @@ async def update_schedule_week( days: list[ScheduleWeekDay], ) -> ScheduleWeek: """Update an existing schedule week with domain weekday mappings.""" - self._ensure_channel() - assert self._hds is not None - system_id = await self.get_system_id() - return await self._hds.update_schedule_week( + hds = self._require_hds() + system_id = await self._resolve_system_id() + return await hds.update_schedule_week( schedule_week_id=schedule_week_id, system_id=system_id, space_id=space_id, @@ -454,9 +497,7 @@ async def update_schedule_week( async def delete_schedule_day(self, schedule_day_id: str) -> None: """Delete a schedule day program.""" - self._ensure_channel() - assert self._hds is not None - await self._hds.delete_schedule_day(schedule_day_id) + await self._require_hds().delete_schedule_day(schedule_day_id) async def update_schedule_day( self, @@ -466,10 +507,9 @@ async def update_schedule_day( events: list[ScheduleEvent] | None = None, ) -> ScheduleDay: """Update an existing schedule day using domain schedule events.""" - self._ensure_channel() - assert self._hds is not None - system_id = await self.get_system_id() - return await self._hds.update_schedule_day( + hds = self._require_hds() + system_id = await self._resolve_system_id() + return await hds.update_schedule_day( schedule_day_id=schedule_day_id, system_id=system_id, space_id=space_id, @@ -479,9 +519,7 @@ async def update_schedule_day( async def delete_schedule_week(self, schedule_week_id: str) -> None: """Delete a schedule week.""" - self._ensure_channel() - assert self._hds is not None - await self._hds.delete_schedule_week(schedule_week_id) + await self._require_hds().delete_schedule_week(schedule_week_id) async def set_schedule_execution(self, paused: bool) -> None: """Globally pause or resume all schedules for the primary location. @@ -489,13 +527,12 @@ async def set_schedule_execution(self, paused: bool) -> None: Args: paused: True to pause all schedules, False to resume. """ - self._ensure_channel() - assert self._hds is not None + hds = self._require_hds() snapshot = await self.get_snapshot() loc = snapshot.primary_location if loc is None: raise QuiltError("No location found for this system.") - await self._hds.update_location_schedule_execution( + await hds.update_location_schedule_execution( location_id=loc.id, system_id=loc.system_id, paused=paused, @@ -510,10 +547,8 @@ async def get_energy( system_id: str | None = None, ) -> list[SpaceEnergyMetrics]: """Fetch energy metrics for a time range.""" - self._ensure_channel() - assert self._sysinfo is not None - sid = system_id or await self.get_system_id() - return await self._sysinfo.get_energy_metrics(sid, start, end) + sid = await self._resolve_system_id(system_id) + return await self._require_sysinfo().get_energy_metrics(sid, start, end) # --- Streaming --- @@ -523,6 +558,7 @@ def stream( *, max_reconnects: int = -1, reconnect_delay_s: float = 1.0, + debounce_s: float = 0.0, ) -> NotifierStream: """Create a NotifierStream for real-time updates. @@ -533,6 +569,9 @@ def stream( means unlimited (the default). reconnect_delay_s: Initial back-off in seconds before reconnecting. Doubles on each attempt, capped at 60 s. + debounce_s: Quiet period in seconds for coalescing updates by + entity before dispatching the latest event. ``0.0`` disables + debouncing. Returns a ``NotifierStream`` that can be used as: @@ -549,7 +588,7 @@ def stream( s.on_space_update(my_callback) await s.run_forever() """ - channel = self._ensure_channel() + channel = self._require_channel() return NotifierStream.create( channel, topics, @@ -557,15 +596,14 @@ def stream( authenticate=self.refresh_token, max_reconnects=max_reconnects, reconnect_delay_s=reconnect_delay_s, + debounce_s=debounce_s, ) # --- User --- async def get_current_user(self) -> User: """Get the currently authenticated user.""" - self._ensure_channel() - assert self._user_svc is not None - return await self._user_svc.get_current_user() + return await self._require_user_service().get_current_user() async def update_current_user( self, @@ -575,9 +613,7 @@ async def update_current_user( phone_number: str | None = None, ) -> User: """Update current user's first/last name and optional phone number.""" - self._ensure_channel() - assert self._user_svc is not None - return await self._user_svc.update_current_user( + return await self._require_user_service().update_current_user( first_name=first_name, last_name=last_name, phone_number=phone_number, @@ -585,9 +621,7 @@ async def update_current_user( async def get_user_attributes(self) -> UserAttributes: """Get current user's additional attributes.""" - self._ensure_channel() - assert self._user_svc is not None - return await self._user_svc.get_user_attributes() + return await self._require_user_service().get_user_attributes() async def patch_user_attributes( self, @@ -595,9 +629,7 @@ async def patch_user_attributes( declared_user_type: DeclaredUserType, ) -> UserAttributes: """Patch current user's additional attributes.""" - self._ensure_channel() - assert self._user_svc is not None - return await self._user_svc.patch_user_attributes( + return await self._require_user_service().patch_user_attributes( declared_user_type=declared_user_type, ) @@ -608,6 +640,9 @@ async def close(self) -> None: if self._channel is not None: await self._channel.close() self._channel = None + self._hds = None + self._sysinfo = None + self._user_svc = None async def __aenter__(self) -> Self: return self diff --git a/src/quilt_hp/models/__init__.py b/src/quilt_hp/models/__init__.py index 40e0af4..855c1b0 100644 --- a/src/quilt_hp/models/__init__.py +++ b/src/quilt_hp/models/__init__.py @@ -17,6 +17,7 @@ LightPreset, LouverAngle, LouverMode, + MetricBucketStatus, OccupancyMode, OccupancyState, RemoteSensorControlMode, @@ -65,6 +66,7 @@ "Location", "LouverAngle", "LouverMode", + "MetricBucketStatus", "OccupancyMode", "OccupancyState", "OutdoorUnit", diff --git a/src/quilt_hp/models/_helpers.py b/src/quilt_hp/models/_helpers.py new file mode 100644 index 0000000..ca7e313 --- /dev/null +++ b/src/quilt_hp/models/_helpers.py @@ -0,0 +1,31 @@ +from __future__ import annotations + + +def lookup_hardware(hw_map: dict[str, object], hardware_id: str | None) -> object | None: + """Resolve hardware objects across common ID formats.""" + if not hardware_id: + return None + raw = hardware_id.strip() + if not raw: + return None + keys = ( + raw, + raw.rsplit("/", 1)[-1], + raw.rsplit(":", 1)[-1], + raw.casefold(), + raw.rsplit("/", 1)[-1].casefold(), + raw.rsplit(":", 1)[-1].casefold(), + ) + for key in keys: + hw = hw_map.get(key) + if hw is not None: + return hw + return None + + +def parse_wifi_state(proto: object) -> tuple[str | None, str | None, int | None]: + """Extract WiFi fields while preserving explicit zero signal values.""" + ssid = getattr(proto, "ssid", "") or None + ip = getattr(proto, "ipv4_address", None) or None + signal = getattr(proto, "signal_level_dbm", None) + return ssid, ip, signal if signal is not None else None diff --git a/src/quilt_hp/models/controller.py b/src/quilt_hp/models/controller.py index bf25d16..41e1141 100644 --- a/src/quilt_hp/models/controller.py +++ b/src/quilt_hp/models/controller.py @@ -7,33 +7,13 @@ from typing import Any, cast from quilt_hp.const import PROTO_TIMESTAMP_UNSET_SECONDS +from quilt_hp.models._helpers import lookup_hardware, parse_wifi_state from quilt_hp.models.enums import RemoteSensorControlMode from quilt_hp.models.qsm import WifiInfo _ONLINE_THRESHOLD_S = 5 * 60 # 5 minutes, matching KMP IS_ONLINE_THRESHOLD_MINUTES -def _lookup_hw(hw_map: dict[str, object], hardware_id: str | None) -> object | None: - if not hardware_id: - return None - raw = hardware_id.strip() - if not raw: - return None - keys = ( - raw, - raw.rsplit("/", 1)[-1], - raw.rsplit(":", 1)[-1], - raw.casefold(), - raw.rsplit("/", 1)[-1].casefold(), - raw.rsplit(":", 1)[-1].casefold(), - ) - for key in keys: - hw = hw_map.get(key) - if hw is not None: - return hw - return None - - @dataclass(slots=True) class Controller: """A Quilt controller (Dial thermostat).""" @@ -119,11 +99,13 @@ def _wifi(wstate: object) -> WifiInfo | None: info = WifiInfo.from_proto(wstate) return info if info.connected else None + wifi_ssid, wifi_ip, wifi_signal_dbm = parse_wifi_state(w) + serial: str | None = None model_sku: str | None = None fw_ver: str | None = None if hw_map: - hw = _lookup_hw(hw_map, p.relationships.hardware_id) + hw = lookup_hardware(hw_map, p.relationships.hardware_id) if hw is not None: a = cast("Any", hw).attributes serial = a.serial_number or None @@ -139,9 +121,9 @@ def _wifi(wstate: object) -> WifiInfo | None: pcb_temperature_a_c=p.state.temperature_f3, pcb_temperature_b_c=p.state.temperature_f4, calibrated_ambient_c=p.state.temperature_f5, - wifi_ssid=w.ssid or None, - wifi_ip=w.ipv4_address or None, - wifi_signal_dbm=w.signal_level_dbm or None, + wifi_ssid=wifi_ssid, + wifi_ip=wifi_ip, + wifi_signal_dbm=wifi_signal_dbm, wifi_freq_mhz=w.frequency_mhz or None, wifi_last_seen=wifi_last_seen, ap_wifi=_wifi(p.ap_wifi_state), diff --git a/src/quilt_hp/models/energy.py b/src/quilt_hp/models/energy.py index d54a445..c9fd59d 100644 --- a/src/quilt_hp/models/energy.py +++ b/src/quilt_hp/models/energy.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from quilt_hp.models.enums import MetricBucketStatus + if TYPE_CHECKING: from datetime import datetime @@ -16,16 +18,21 @@ class EnergyBucket: start_time: datetime energy_kwh: float - status: int # 0=UNSPECIFIED, 1=COMPLETE, 2=INCOMPLETE + status: MetricBucketStatus @property def has_missing_energy_value(self) -> bool: - """True when energy_kwh is NaN (wire sentinel for missing/error data).""" - return math.isnan(self.energy_kwh) + """True when energy_kwh is missing: either not a float or NaN sentinel.""" + return not isinstance(self.energy_kwh, float) or math.isnan(self.energy_kwh) + + @property + def is_valid(self) -> bool: + """True when this bucket carries a usable numeric energy value.""" + return self.energy_kwh_or_none is not None @property def energy_kwh_or_none(self) -> float | None: - """Energy value, or None when this bucket carries a NaN sentinel.""" + """Energy value, or None when this bucket is missing or NaN.""" return None if self.has_missing_energy_value else self.energy_kwh diff --git a/src/quilt_hp/models/enums.py b/src/quilt_hp/models/enums.py index c2549ed..d6c7f6a 100644 --- a/src/quilt_hp/models/enums.py +++ b/src/quilt_hp/models/enums.py @@ -250,6 +250,17 @@ def __str__(self) -> str: return self.name +class MetricBucketStatus(IntEnum): + """Energy-metric bucket completeness state.""" + + UNSPECIFIED = 0 + COMPLETE = 1 + INCOMPLETE = 2 + + def __str__(self) -> str: + return self.name + + class BoostMode(IntEnum): """Boost (turbo) mode override for a space.""" diff --git a/src/quilt_hp/models/indoor_unit.py b/src/quilt_hp/models/indoor_unit.py index 06af33e..5748d55 100644 --- a/src/quilt_hp/models/indoor_unit.py +++ b/src/quilt_hp/models/indoor_unit.py @@ -392,7 +392,9 @@ def _idu_from_proto(proto: object) -> IndoorUnit: fan_speed=FanSpeed.from_wire(c.fan_speed_mode, c.fan_speed_percent), fan_speed_mode_raw=c.fan_speed_mode, fan_speed_percent_raw=c.fan_speed_percent, - louver_mode=LouverMode(c.louver_mode) if c.louver_mode else LouverMode.UNSPECIFIED, + louver_mode=( + LouverMode(c.louver_mode) if c.louver_mode is not None else LouverMode.UNSPECIFIED + ), louver_fixed_position=c.louver_fixed_position, led_color_code=c.led_color_code, led_brightness=c.led_color_brightness_percent, diff --git a/src/quilt_hp/models/outdoor_unit.py b/src/quilt_hp/models/outdoor_unit.py index 8d24802..0a8b392 100644 --- a/src/quilt_hp/models/outdoor_unit.py +++ b/src/quilt_hp/models/outdoor_unit.py @@ -4,26 +4,19 @@ from dataclasses import dataclass +from quilt_hp.models._helpers import lookup_hardware -def _lookup_hw(hw_map: dict[str, object], hardware_id: str | None) -> object | None: - if not hardware_id: - return None - raw = hardware_id.strip() - if not raw: - return None - keys = ( - raw, - raw.rsplit("/", 1)[-1], - raw.rsplit(":", 1)[-1], - raw.casefold(), - raw.rsplit("/", 1)[-1].casefold(), - raw.rsplit(":", 1)[-1].casefold(), - ) - for key in keys: - hw = hw_map.get(key) - if hw is not None: - return hw - return None + +def _has_performance_data(proto: object) -> bool: + if not hasattr(proto, "performance_data"): + return False + has_field = getattr(proto, "HasField", None) + if callable(has_field): + try: + return bool(has_field("performance_data")) + except ValueError: + pass + return True @dataclass(slots=True) @@ -58,24 +51,21 @@ class OutdoorUnit: def from_proto(cls, proto: object, hw_map: dict[str, object] | None = None) -> OutdoorUnit: """Construct from a protobuf OutdoorUnit message.""" hw_id = proto.relationships.hardware_id # type: ignore[attr-defined] - hw = _lookup_hw(hw_map, hw_id) if hw_map else None + hw = lookup_hardware(hw_map, hw_id) if hw_map else None pd = None - if hasattr(proto, "performance_data"): - p = proto.performance_data - # Server does not reliably set updated_ts on - # OutdoorUnitPerformanceData. Gate on any non-zero value instead. - if p.ambient_temperature_c or p.compressor_frequency_hz or p.energy_measurement_j: - pd = OutdoorUnitPerformanceData( - measurement_interval_s=p.measurement_interval_s, - energy_measurement_j=p.energy_measurement_j, - compressor_frequency_hz=p.compressor_frequency_hz, - ambient_temperature_c=p.ambient_temperature_c, - coil_temperature_c=p.coil_temperature_c, - exhaust_temperature_c=p.exhaust_temperature_c, - high_pressure_kpa=p.high_pressure_kpa, - low_pressure_kpa=p.low_pressure_kpa, - ) + if _has_performance_data(proto): + p = proto.performance_data # type: ignore[attr-defined] + pd = OutdoorUnitPerformanceData( + measurement_interval_s=p.measurement_interval_s, + energy_measurement_j=p.energy_measurement_j, + compressor_frequency_hz=p.compressor_frequency_hz, + ambient_temperature_c=p.ambient_temperature_c, + coil_temperature_c=p.coil_temperature_c, + exhaust_temperature_c=p.exhaust_temperature_c, + high_pressure_kpa=p.high_pressure_kpa, + low_pressure_kpa=p.low_pressure_kpa, + ) return cls( id=proto.header.object_id, # type: ignore[attr-defined] diff --git a/src/quilt_hp/models/qsm.py b/src/quilt_hp/models/qsm.py index ecfccdf..762e5d4 100644 --- a/src/quilt_hp/models/qsm.py +++ b/src/quilt_hp/models/qsm.py @@ -11,6 +11,8 @@ from dataclasses import dataclass +from quilt_hp.models._helpers import parse_wifi_state + @dataclass(slots=True) class WifiInfo: @@ -26,10 +28,8 @@ def connected(self) -> bool: @classmethod def from_proto(cls, proto: object) -> WifiInfo: - ssid = getattr(proto, "ssid", None) or None - ip = getattr(proto, "ipv4_address", None) or None - sig = getattr(proto, "signal_level_dbm", None) or None - return cls(ssid=ssid, ip=ip, signal_dbm=sig) + ssid, ip, signal_dbm = parse_wifi_state(proto) + return cls(ssid=ssid, ip=ip, signal_dbm=signal_dbm) @dataclass(slots=True) diff --git a/src/quilt_hp/models/schedule.py b/src/quilt_hp/models/schedule.py index 15778e2..7e1a6a1 100644 --- a/src/quilt_hp/models/schedule.py +++ b/src/quilt_hp/models/schedule.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from quilt_hp.const import EMPTY_COMFORT_SETTING_ID_SENTINEL, UNKNOWN_SCHEDULE_SORT_ORDER_SENTINEL +from quilt_hp.models.enums import HVACMode _WEEKDAY_NAMES = { 0: "?", @@ -24,7 +25,7 @@ class ScheduleEvent: start_s: int # seconds from midnight comfort_setting_id: str - hvac_mode: int + hvac_mode: HVACMode heating_setpoint_c: float cooling_setpoint_c: float precondition: bool @@ -61,7 +62,7 @@ def from_proto(cls, proto: object) -> ScheduleDay: ScheduleEvent( start_s=ev.start_s, comfort_setting_id=ev.comfort_setting_id, - hvac_mode=ev.hvac_mode, + hvac_mode=HVACMode(ev.hvac_mode), heating_setpoint_c=ev.heating_temperature_setpoint_c, cooling_setpoint_c=ev.cooling_temperature_setpoint_c, precondition=ev.precondition, diff --git a/src/quilt_hp/models/sensor.py b/src/quilt_hp/models/sensor.py index bb05a3c..b98802d 100644 --- a/src/quilt_hp/models/sensor.py +++ b/src/quilt_hp/models/sensor.py @@ -21,11 +21,15 @@ def _parse_state( s: object, ) -> tuple[float | None, float | None, float | None, int | None]: """Return ambient temp, humidity, battery, and signal from proto state.""" + ambient_temperature_c = getattr(s, "ambient_temperature_c", None) + humidity_percent = getattr(s, "humidity_percent", None) + battery_level_percent = getattr(s, "battery_level_percent", None) + signal_level_dbm = getattr(s, "signal_level_dbm", None) return ( - s.ambient_temperature_c or None, # type: ignore[attr-defined] - s.humidity_percent or None, # type: ignore[attr-defined] - s.battery_level_percent or None, # type: ignore[attr-defined] - s.signal_level_dbm or None, # type: ignore[attr-defined] + ambient_temperature_c if ambient_temperature_c is not None else None, + humidity_percent if humidity_percent is not None else None, + battery_level_percent if battery_level_percent is not None else None, + signal_level_dbm if signal_level_dbm is not None else None, ) diff --git a/src/quilt_hp/models/space.py b/src/quilt_hp/models/space.py index 7f46853..ae77041 100644 --- a/src/quilt_hp/models/space.py +++ b/src/quilt_hp/models/space.py @@ -79,14 +79,18 @@ def fmt(val_c: float) -> str: mode = self.hvac_mode if mode in (HVACMode.STANDBY, HVACMode.UNSPECIFIED, HVACMode.FAN): return "--" - if mode == HVACMode.COOL and self.cooling_setpoint_c: + if mode == HVACMode.COOL: return fmt(self.cooling_setpoint_c) - if mode == HVACMode.HEAT and self.heating_setpoint_c: + if mode == HVACMode.HEAT: return fmt(self.heating_setpoint_c) - if mode == HVACMode.AUTO and self.cooling_setpoint_c and self.heating_setpoint_c: + if mode == HVACMode.AUTO: return f"{fmt(self.heating_setpoint_c)}–{fmt(self.cooling_setpoint_c)}" - best = self.temperature_setpoint_c or self.cooling_setpoint_c or self.heating_setpoint_c - return fmt(best) if best else "--" + best = self.temperature_setpoint_c + if best is None: + best = self.cooling_setpoint_c + if best is None: + best = self.heating_setpoint_c + return fmt(best) if best is not None else "--" @property def has_standby_sentinel_setpoints(self) -> bool: diff --git a/src/quilt_hp/services/__init__.py b/src/quilt_hp/services/__init__.py index 8d74951..d28958c 100644 --- a/src/quilt_hp/services/__init__.py +++ b/src/quilt_hp/services/__init__.py @@ -1 +1,113 @@ """Service layer — thin async wrappers around gRPC stubs.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import Any, cast + +import grpc +import grpc.aio + +from quilt_hp.exceptions import QuiltConnectionError, QuiltError + +logger = logging.getLogger(__name__) + +_TRANSIENT_GRPC_CODES = { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, +} + + +class _GrpcCallContext: + def __init__( + self, + operation: str, + *, + max_retries: int = 0, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + ) -> None: + self._operation = operation + self._max_retries = max_retries + self._retry_delay = retry_delay + self._retry_backoff = retry_backoff + + async def __aenter__(self) -> Callable[..., Awaitable[Any]]: + return self.run + + async def __aexit__(self, exc_type: object, exc: BaseException | None, tb: object) -> bool: + del exc_type, tb + if exc is None: + return False + translated = self._translate_exception(exc) + if translated is exc: + raise exc + raise translated from exc + + async def run(self, func: Callable[..., Awaitable[Any]], /, *args: Any, **kwargs: Any) -> Any: + attempt = 0 + delay = self._retry_delay + while True: + try: + return await func(*args, **kwargs) + except Exception as exc: + wrapped = self._translate_exception(exc) + if not self._should_retry(exc, attempt): + if wrapped is exc: + raise + raise wrapped from exc + attempt += 1 + logger.warning( + "%s failed with %s; retrying in %.1fs (%d/%d)", + self._operation, + cast("grpc.aio.AioRpcError", exc).code(), + delay, + attempt, + self._max_retries, + ) + await asyncio.sleep(delay) + delay *= self._retry_backoff + + def _should_retry(self, exc: BaseException, attempt: int) -> bool: + return ( + isinstance(exc, grpc.aio.AioRpcError) + and exc.code() in _TRANSIENT_GRPC_CODES + and attempt < self._max_retries + ) + + def _translate_exception(self, exc: BaseException) -> QuiltError: + if isinstance(exc, QuiltError): + return exc + if isinstance(exc, grpc.aio.AioRpcError): + if exc.code() in _TRANSIENT_GRPC_CODES: + return QuiltConnectionError(f"{self._operation} failed: {exc.details()}") + return QuiltError(f"{self._operation} failed: {exc.details()}") + logger.debug("Unexpected error in %s: %s", self._operation, exc) + return QuiltError(f"{self._operation} failed: {exc}") + + +def grpc_call( + operation: str, + *, + max_retries: int = 0, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, +) -> _GrpcCallContext: + """Translate gRPC errors and optionally retry transient unary calls. + + Usage:: + + async with grpc_call("UpdateSpace"): + result = await stub.UpdateSpace(request) + + async with grpc_call("ListSystems", max_retries=2) as call: + result = await call(stub.ListSystems, request) + """ + return _GrpcCallContext( + operation, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + ) diff --git a/src/quilt_hp/services/hds.py b/src/quilt_hp/services/hds.py index f637b4a..d99cb64 100644 --- a/src/quilt_hp/services/hds.py +++ b/src/quilt_hp/services/hds.py @@ -5,6 +5,7 @@ from __future__ import annotations +import logging import time from collections.abc import Callable, Sequence from typing import Protocol, cast @@ -22,6 +23,8 @@ from quilt_hp.models.space import Space from quilt_hp.models.system import SystemSnapshot +logger = logging.getLogger(__name__) + class _HomeDatastoreServiceStub(Protocol): async def GetHomeDatastoreSystem( @@ -81,6 +84,7 @@ def __init__(self, channel: grpc.aio.Channel) -> None: async def get_system(self, system_id: str) -> SystemSnapshot: """Fetch a full system snapshot.""" + logger.debug("RPC GetHomeDatastoreSystem system_id=%s", system_id) try: snap = await self._stub.GetHomeDatastoreSystem( hds.GetHomeDatastoreSystemRequest(system_id=system_id) @@ -118,13 +122,13 @@ async def update_space( heat = heat_setpoint_c if heat_setpoint_c is not None else c.heating_setpoint_c cool = cool_setpoint_c if cool_setpoint_c is not None else c.cooling_setpoint_c - # Mode-relevant setpoint routing (from KX.java) - temp_setpoint = heat if mode_enum == HVACMode.HEAT else cool - # AUTO mode: enforce cool - heat >= 2.5°C if mode_enum == HVACMode.AUTO and cool - heat < 2.5: cool = heat + 2.5 + # Mode-relevant setpoint routing (from KX.java) + temp_setpoint = heat if mode_enum == HVACMode.HEAT else cool + # Setting to STANDBY explicitly means "turn off" — clear the comfort # setting so occupancy cannot reactivate the room (i.e. not AWAY mode). if mode_enum == HVACMode.STANDBY: @@ -149,6 +153,9 @@ async def update_space( comfort_setting_id_string=cs_id, ), ) + logger.debug( + "RPC UpdateSpace space_id=%s system_id=%s", snapshot_space.id, snapshot_space.system_id + ) try: result = await self._stub.UpdateSpace(hds.UpdateSpaceRequest(diff=diff)) except grpc.aio.AioRpcError as exc: @@ -190,6 +197,11 @@ async def update_space_settings( updated_ts=_now_ts(), ), ) + logger.debug( + "RPC UpdateSpace settings space_id=%s system_id=%s", + snapshot_space.id, + snapshot_space.system_id, + ) try: result = await self._stub.UpdateSpace(hds.UpdateSpaceRequest(diff=diff)) except grpc.aio.AioRpcError as exc: @@ -238,6 +250,7 @@ async def update_indoor_unit( ), ), ) + logger.debug("RPC UpdateIndoorUnit indoor_unit_id=%s system_id=%s", idu.id, idu.system_id) try: result = await self._stub.UpdateIndoorUnit(hds.UpdateIndoorUnitRequest(diff=diff)) except grpc.aio.AioRpcError as exc: @@ -289,6 +302,9 @@ async def update_indoor_unit_settings( ), ), ) + logger.debug( + "RPC UpdateIndoorUnit settings indoor_unit_id=%s system_id=%s", idu.id, idu.system_id + ) try: result = await self._stub.UpdateIndoorUnit(hds.UpdateIndoorUnitRequest(diff=diff)) except grpc.aio.AioRpcError as exc: @@ -332,6 +348,11 @@ async def update_comfort_setting( type=cast("hds.ComfortSettingType.ValueType", setting.type.value), ), ) + logger.debug( + "RPC UpdateComfortSetting comfort_setting_id=%s system_id=%s", + setting.id, + setting.system_id, + ) try: result = await self._stub.UpdateComfortSetting( hds.UpdateComfortSettingRequest(comfort_setting=diff) @@ -355,6 +376,7 @@ async def create_schedule_day( relationships=hds.ScheduleDayRelationships(space_id=space_id), events=wire_events, ) + logger.debug("RPC CreateScheduleDay system_id=%s space_id=%s", system_id, space_id) try: result = await self._stub.CreateScheduleDay( hds.CreateScheduleDayRequest(schedule_day=diff) @@ -376,6 +398,7 @@ async def create_schedule_week( relationships=hds.ScheduleWeekRelationships(space_id=space_id), days=wire_days, ) + logger.debug("RPC CreateScheduleWeek system_id=%s space_id=%s", system_id, space_id) try: result = await self._stub.CreateScheduleWeek( hds.CreateScheduleWeekRequest(schedule_week=diff) @@ -401,6 +424,12 @@ async def update_schedule_week( relationships=hds.ScheduleWeekRelationships(space_id=space_id), days=wire_days, ) + logger.debug( + "RPC UpdateScheduleWeek schedule_week_id=%s system_id=%s space_id=%s", + schedule_week_id, + system_id, + space_id, + ) try: result = await self._stub.UpdateScheduleWeek( hds.UpdateScheduleWeekRequest(schedule_week=diff) @@ -411,6 +440,7 @@ async def update_schedule_week( async def delete_schedule_day(self, schedule_day_id: str) -> None: """Delete a schedule day program.""" + logger.debug("RPC DeleteScheduleDay schedule_day_id=%s", schedule_day_id) try: await self._stub.DeleteScheduleDay( hds.DeleteScheduleDayRequest(schedule_day_id=schedule_day_id) @@ -438,6 +468,12 @@ async def update_schedule_day( diff.attributes.CopyFrom(hds.ScheduleDayAttributes(name=name)) if events is not None: diff.events.extend(_to_wire_schedule_event(event) for event in events) + logger.debug( + "RPC UpdateScheduleDay schedule_day_id=%s system_id=%s space_id=%s", + schedule_day_id, + system_id, + space_id, + ) try: result = await self._stub.UpdateScheduleDay( hds.UpdateScheduleDayRequest(schedule_day=diff) @@ -448,6 +484,7 @@ async def update_schedule_day( async def delete_schedule_week(self, schedule_week_id: str) -> None: """Delete a schedule week.""" + logger.debug("RPC DeleteScheduleWeek schedule_week_id=%s", schedule_week_id) try: await self._stub.DeleteScheduleWeek( hds.DeleteScheduleWeekRequest(schedule_week_id=schedule_week_id) @@ -474,6 +511,7 @@ async def update_location_schedule_execution( ), controls=hds.LocationControls(schedule_execution=execution), ) + logger.debug("RPC UpdateLocation location_id=%s system_id=%s", location_id, system_id) try: await self._stub.UpdateLocation(hds.UpdateLocationRequest(location=diff)) except grpc.aio.AioRpcError as exc: diff --git a/src/quilt_hp/services/streaming.py b/src/quilt_hp/services/streaming.py index 2e55ee5..d4541e4 100644 --- a/src/quilt_hp/services/streaming.py +++ b/src/quilt_hp/services/streaming.py @@ -11,6 +11,7 @@ import contextlib import inspect import logging +import time from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from dataclasses import dataclass, field from typing import Any, Protocol, cast @@ -55,6 +56,9 @@ def Subscribe( RefreshCallback = Callable[[], Awaitable[None]] | Callable[[TokenRefreshContext], Awaitable[None]] +type _EventKey = tuple[str, str] +type _AnyCallback = Callable[[Any], Awaitable[None] | None] + async def _invoke_refresh_callback( refresh_callback: RefreshCallback, context: TokenRefreshContext @@ -139,6 +143,14 @@ class StreamEvent: raw_bytes: bytes | None = None +@dataclass(slots=True) +class _PendingDispatch: + value: Any + callbacks: tuple[_AnyCallback, ...] + error_message: str + task: asyncio.Task[None] + + @dataclass class NotifierStream: """Async manager for the NotifierService bidirectional stream. @@ -167,6 +179,9 @@ class NotifierStream: reconnect_delay_s: Initial back-off delay in seconds before the first reconnect. Doubles on each subsequent attempt, capped at 60 s. Default: ``1.0``. + debounce_s: Quiet period in seconds for coalescing updates by entity + type and ID before dispatching the latest event. Default: ``0.0`` + (dispatch immediately). """ _channel: grpc.aio.Channel @@ -175,6 +190,7 @@ class NotifierStream: _authenticate: RefreshCallback | None = None _max_reconnects: int = -1 _reconnect_delay_s: float = 1.0 + _debounce_s: float = 0.0 _space_callbacks: list[SpaceCallback] = field(default_factory=list, init=False) _idu_callbacks: list[IndoorUnitCallback] = field(default_factory=list, init=False) @@ -186,9 +202,19 @@ class NotifierStream: _sui_callbacks: list[SoftwareUpdateInfoCallback] = field(default_factory=list, init=False) _error_callbacks: list[ErrorCallback] = field(default_factory=list, init=False) _request_queue: asyncio.Queue[notifier.SubscribeRequest] = field(init=False) + _subscription_lock: asyncio.Lock = field(init=False) + _lifecycle_lock: asyncio.Lock = field(init=False) + _pending_dispatch_lock: asyncio.Lock = field(init=False) + _stop_event: asyncio.Event = field(init=False) _running: bool = field(default=False, init=False) _task: asyncio.Task[None] | None = field(default=None, init=False) + _active_call: Any | None = field(default=None, init=False) + _pending_dispatches: dict[_EventKey, _PendingDispatch] = field( + default_factory=dict, init=False + ) _error: Exception | None = field(default=None, init=False) + _last_event_at: float | None = field(default=None, init=False) + _stream_state: str = field(default="idle", init=False) def __post_init__(self) -> None: factory = cast( @@ -197,6 +223,10 @@ def __post_init__(self) -> None: ) self._stub: _NotifierServiceStub = factory(self._channel) self._request_queue = asyncio.Queue() + self._subscription_lock = asyncio.Lock() + self._lifecycle_lock = asyncio.Lock() + self._pending_dispatch_lock = asyncio.Lock() + self._stop_event = asyncio.Event() # --- Public constructor (friendlier than dataclass __init__) --- @@ -210,6 +240,7 @@ def create( authenticate: RefreshCallback | None = None, max_reconnects: int = -1, reconnect_delay_s: float = 1.0, + debounce_s: float = 0.0, ) -> NotifierStream: """Create a NotifierStream with named parameters.""" return cls( @@ -219,6 +250,7 @@ def create( _authenticate=authenticate, _max_reconnects=max_reconnects, _reconnect_delay_s=reconnect_delay_s, + _debounce_s=debounce_s, ) # --- Callback registration --- @@ -264,29 +296,48 @@ def error(self) -> Exception | None: """The last fatal stream error, or None if the stream is healthy.""" return self._error + @property + def is_connected(self) -> bool: + """Whether the stream currently has an active connection.""" + return self._stream_state == "connected" + + @property + def last_event_at(self) -> float | None: + """Monotonic timestamp of the last received non-heartbeat event.""" + return self._last_event_at + + @property + def stream_state(self) -> str: + """Current stream lifecycle state.""" + return self._stream_state + # --- Subscription management --- async def subscribe(self, topics: list[str]) -> None: """Add more topics to the subscription (after stream is started).""" - self._topics.extend(topics) - await self._request_queue.put(_make_subscribe_request(topics)) + async with self._subscription_lock: + self._topics.extend(topics) + await self._request_queue.put(_make_subscribe_request(topics)) async def unsubscribe(self, topics: list[str]) -> None: """Remove topics from the subscription.""" - for t in topics: - if t in self._topics: - self._topics.remove(t) req = notifier.SubscribeRequest( remove=notifier.TopicsMessage( subscriptions=[notifier.Subscription(topic=t) for t in topics] ) ) - await self._request_queue.put(req) + async with self._subscription_lock: + for t in topics: + if t in self._topics: + self._topics.remove(t) + await self._request_queue.put(req) # --- Internal stream machinery --- async def _request_iterator( self, + topics: list[str], + request_queue: asyncio.Queue[notifier.SubscribeRequest], ) -> AsyncIterator[notifier.SubscribeRequest]: """Yield SubscribeRequests from initial subscription, then queue. @@ -294,10 +345,10 @@ async def _request_iterator( without re-sending the topic list; gRPC channel keepalives (configured in GRPC_CHANNEL_OPTIONS) handle the underlying TCP connection. """ - yield _make_subscribe_request(self._topics) + yield _make_subscribe_request(topics) while self._running: try: - req = await asyncio.wait_for(self._request_queue.get(), timeout=30.0) + req = await asyncio.wait_for(request_queue.get(), timeout=30.0) yield req except TimeoutError: continue # keepalive handled by gRPC channel options @@ -390,70 +441,157 @@ def _parse_event(self, evt: object) -> StreamEvent | None: return event + async def _invoke_callbacks[T]( + self, + callbacks: Sequence[Callable[[T], Awaitable[None] | None]], + arg: T, + error_message: str, + ) -> None: + for callback in callbacks: + try: + await _dispatch(callback, arg) + except Exception: + logger.exception(error_message) + + async def _dispatch_debounced(self, key: _EventKey) -> None: + try: + await asyncio.sleep(self._debounce_s) + async with self._pending_dispatch_lock: + pending = self._pending_dispatches.get(key) + if pending is None or pending.task is not asyncio.current_task(): + return + self._pending_dispatches.pop(key, None) + await self._invoke_callbacks(pending.callbacks, pending.value, pending.error_message) + except asyncio.CancelledError: + raise + + async def _queue_debounced_dispatch[T]( + self, + entity_type: str, + entity: T, + callbacks: Sequence[Callable[[T], Awaitable[None] | None]], + error_message: str, + ) -> None: + key = (entity_type, str(getattr(cast("Any", entity), "id", ""))) + callback_snapshot = tuple(cast("Sequence[_AnyCallback]", callbacks)) + async with self._pending_dispatch_lock: + existing = self._pending_dispatches.get(key) + if existing is not None: + existing.task.cancel() + task = asyncio.create_task(self._dispatch_debounced(key)) + self._pending_dispatches[key] = _PendingDispatch( + value=entity, + callbacks=callback_snapshot, + error_message=error_message, + task=task, + ) + + async def _cancel_pending_dispatches(self) -> None: + async with self._pending_dispatch_lock: + pending = list(self._pending_dispatches.values()) + self._pending_dispatches.clear() + for item in pending: + item.task.cancel() + if pending: + await asyncio.gather(*(item.task for item in pending), return_exceptions=True) + + async def _dispatch_entity[T]( + self, + entity_type: str, + entity: T, + callbacks: Sequence[Callable[[T], Awaitable[None] | None]], + error_message: str, + ) -> None: + if self._debounce_s <= 0: + await self._invoke_callbacks(callbacks, entity, error_message) + return + await self._queue_debounced_dispatch(entity_type, entity, callbacks, error_message) + + async def _dispatch_parsed_event(self, parsed: StreamEvent) -> None: + if parsed.space is not None: + await self._dispatch_entity( + "space", parsed.space, self._space_callbacks, "Error in space callback" + ) + if parsed.indoor_unit is not None: + await self._dispatch_entity( + "indoor_unit", + parsed.indoor_unit, + self._idu_callbacks, + "Error in indoor unit callback", + ) + if parsed.outdoor_unit is not None: + await self._dispatch_entity( + "outdoor_unit", + parsed.outdoor_unit, + self._odu_callbacks, + "Error in outdoor unit callback", + ) + if parsed.controller is not None: + await self._dispatch_entity( + "controller", + parsed.controller, + self._ctrl_callbacks, + "Error in controller callback", + ) + if parsed.qsm is not None: + await self._dispatch_entity( + "qsm", parsed.qsm, self._qsm_callbacks, "Error in QSM callback" + ) + if parsed.remote_sensor is not None: + await self._dispatch_entity( + "remote_sensor", + parsed.remote_sensor, + self._rs_callbacks, + "Error in remote sensor callback", + ) + if parsed.controller_remote_sensor is not None: + await self._dispatch_entity( + "controller_remote_sensor", + parsed.controller_remote_sensor, + self._crs_callbacks, + "Error in controller remote sensor callback", + ) + if parsed.software_update_info is not None: + await self._dispatch_entity( + "software_update_info", + parsed.software_update_info, + self._sui_callbacks, + "Error in software update info callback", + ) + async def _run_one_stream(self) -> None: """Run a single stream connection until it ends or errors.""" metadata = self._metadata_provider() if self._metadata_provider else None - call = self._stub.Subscribe( - self._request_iterator(), - metadata=metadata, - ) - async for response in call: - for ctrl in response.control_events: - event_name = notifier.ControlEventType.Name(ctrl.type) - logger.debug("Control event: %s topics=%s", event_name, list(ctrl.topics)) - - for evt in response.notifier_events: - parsed = self._parse_event(evt) - if parsed is None: - continue - if parsed.space is not None: - for space_cb in self._space_callbacks: - try: - await _dispatch(space_cb, parsed.space) - except Exception: - logger.exception("Error in space callback") - if parsed.indoor_unit is not None: - for idu_cb in self._idu_callbacks: - try: - await _dispatch(idu_cb, parsed.indoor_unit) - except Exception: - logger.exception("Error in indoor unit callback") - if parsed.outdoor_unit is not None: - for odu_cb in self._odu_callbacks: - try: - await _dispatch(odu_cb, parsed.outdoor_unit) - except Exception: - logger.exception("Error in outdoor unit callback") - if parsed.controller is not None: - for ctrl_cb in self._ctrl_callbacks: - try: - await _dispatch(ctrl_cb, parsed.controller) - except Exception: - logger.exception("Error in controller callback") - if parsed.qsm is not None: - for qsm_cb in self._qsm_callbacks: - try: - await _dispatch(qsm_cb, parsed.qsm) - except Exception: - logger.exception("Error in QSM callback") - if parsed.remote_sensor is not None: - for rs_cb in self._rs_callbacks: - try: - await _dispatch(rs_cb, parsed.remote_sensor) - except Exception: - logger.exception("Error in remote sensor callback") - if parsed.controller_remote_sensor is not None: - for crs_cb in self._crs_callbacks: - try: - await _dispatch(crs_cb, parsed.controller_remote_sensor) - except Exception: - logger.exception("Error in controller remote sensor callback") - if parsed.software_update_info is not None: - for sui_cb in self._sui_callbacks: - try: - await _dispatch(sui_cb, parsed.software_update_info) - except Exception: - logger.exception("Error in software update info callback") + async with self._subscription_lock: + # Snapshot topics and queue together so reconnect queue swaps and + # subscribe/unsubscribe calls cannot interleave between them. + topics = list(self._topics) + request_queue = self._request_queue + call = self._stub.Subscribe( + self._request_iterator(topics, request_queue), + metadata=metadata, + ) + self._active_call = call + self._stream_state = "connected" + try: + async for response in call: + saw_event = False + for ctrl in response.control_events: + saw_event = True + event_name = notifier.ControlEventType.Name(ctrl.type) + logger.debug("Control event: %s topics=%s", event_name, list(ctrl.topics)) + + for evt in response.notifier_events: + parsed = self._parse_event(evt) + if parsed is None: + continue + saw_event = True + await self._dispatch_parsed_event(parsed) + if saw_event: + self._last_event_at = time.monotonic() + finally: + if self._active_call is call: + self._active_call = None async def _run_stream_with_reconnect(self) -> None: """Run the stream with automatic reconnect and exponential back-off.""" @@ -473,6 +611,7 @@ async def _run_stream_with_reconnect(self) -> None: can_retry = self._max_reconnects < 0 or attempt < self._max_reconnects if is_unauth and self._authenticate is not None and can_retry: + self._stream_state = "reconnecting" logger.warning( "Stream got UNAUTHENTICATED; refreshing token (attempt %d)", attempt + 1, @@ -487,8 +626,10 @@ async def _run_stream_with_reconnect(self) -> None: except Exception: logger.exception("Token refresh failed; giving up stream") self._error = exc + self._stream_state = "error" break elif can_retry: + self._stream_state = "reconnecting" logger.warning( "Stream error %s: %s; reconnecting in %.1fs (attempt %d)", exc.code(), @@ -503,13 +644,25 @@ async def _run_stream_with_reconnect(self) -> None: exc.details(), ) self._error = QuiltStreamError(f"Stream error: {exc.code()} - {exc.details()}") + self._stream_state = "error" break - await asyncio.sleep(delay) + if await self._wait_for_stop(delay): + break delay = min(delay * 2, 60.0) attempt += 1 - # Reset request queue so the next connection re-subscribes. - self._request_queue = asyncio.Queue() + async with self._subscription_lock: + logger.info( + "Resetting subscription queue before reconnect; " + "tracked topics will be re-subscribed on the next stream" + ) + # _topics is the source of truth. The next request iterator + # snapshots the current topics and sends them as its first + # request, so discarding any stale queued requests is safe. + self._request_queue = asyncio.Queue() + + if self._error is None and self._stream_state != "stopped": + self._stream_state = "stopped" if self._error is not None: for cb in self._error_callbacks: @@ -521,20 +674,55 @@ async def _run_stream_with_reconnect(self) -> None: # Propagate to the task so the caller can observe it raise self._error + async def _wait_for_stop(self, delay: float) -> bool: + sleep_task = asyncio.create_task(asyncio.sleep(delay)) + stop_task = asyncio.create_task(self._stop_event.wait()) + done, pending = await asyncio.wait( + {sleep_task, stop_task}, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if pending: + await asyncio.gather(*pending, return_exceptions=True) + return stop_task in done + + async def _run_until_stopped(self) -> None: + try: + await self._run_stream_with_reconnect() + finally: + await self._cancel_pending_dispatches() + async with self._lifecycle_lock: + self._running = False + self._active_call = None + if self._task is asyncio.current_task(): + self._task = None + if self._error is None and self._stream_state != "error": + self._stream_state = "stopped" + # --- Lifecycle --- async def run_forever(self) -> None: """Run the stream inline (blocking) until cancelled or fatal error.""" - self._running = True - await self._run_stream_with_reconnect() + async with self._lifecycle_lock: + if self._running: + return + self._running = True + self._error = None + self._stream_state = "idle" + self._stop_event.clear() + await self._run_until_stopped() async def start(self) -> None: """Start the stream listener as a background task.""" - if self._running: - return - self._running = True - self._task = asyncio.create_task(self._run_stream_with_reconnect()) - self._task.add_done_callback(self._on_task_done) + async with self._lifecycle_lock: + if self._running: + return + self._running = True + self._error = None + self._stream_state = "idle" + self._stop_event.clear() + self._task = asyncio.create_task(self._run_until_stopped()) + self._task.add_done_callback(self._on_task_done) def _on_task_done(self, task: asyncio.Task[None]) -> None: """Log unhandled task exceptions so they aren't silently swallowed.""" @@ -546,12 +734,24 @@ def _on_task_done(self, task: asyncio.Task[None]) -> None: async def stop(self) -> None: """Stop the stream listener.""" - self._running = False - if self._task is not None: - self._task.cancel() - with contextlib.suppress(asyncio.CancelledError, QuiltStreamError): - await self._task + async with self._lifecycle_lock: + self._running = False + self._stream_state = "stopped" + self._stop_event.set() + task = self._task self._task = None + active_call = self._active_call + + cancel = getattr(active_call, "cancel", None) + if callable(cancel): + cancel() + + if task is not None: + task.cancel() + with contextlib.suppress(asyncio.CancelledError, QuiltStreamError): + await task + + await self._cancel_pending_dispatches() async def __aenter__(self) -> NotifierStream: await self.start() diff --git a/src/quilt_hp/services/system.py b/src/quilt_hp/services/system.py index 4c57dff..5650b32 100644 --- a/src/quilt_hp/services/system.py +++ b/src/quilt_hp/services/system.py @@ -6,22 +6,35 @@ from __future__ import annotations import datetime +import logging from collections.abc import Callable from typing import TYPE_CHECKING, Protocol, cast -import grpc.aio from google.protobuf.timestamp_pb2 import Timestamp +if TYPE_CHECKING: + import grpc.aio + from quilt_hp._proto import quilt_services_pb2 as svc from quilt_hp._proto import quilt_services_pb2_grpc as svc_grpc -from quilt_hp.exceptions import QuiltError from quilt_hp.models.energy import EnergyBucket, SpaceEnergyMetrics +from quilt_hp.models.enums import MetricBucketStatus from quilt_hp.models.system import SystemInfo +from quilt_hp.services import grpc_call + +logger = logging.getLogger(__name__) if TYPE_CHECKING: from datetime import datetime as _datetime +def _safe_bucket_status(value: int) -> MetricBucketStatus: + try: + return MetricBucketStatus(value) + except ValueError: + return MetricBucketStatus.UNSPECIFIED + + class _SystemInformationServiceStub(Protocol): async def ListSystems( self, request: svc.ListSystemInformationRequest @@ -44,10 +57,9 @@ def __init__(self, channel: grpc.aio.Channel) -> None: async def list_systems(self) -> list[SystemInfo]: """List all systems the authenticated user has access to.""" - try: + logger.debug("Listing systems") + async with grpc_call("ListSystems"): resp = await self._stub.ListSystems(svc.ListSystemInformationRequest()) - except grpc.aio.AioRpcError as exc: - raise QuiltError(f"ListSystems failed: {exc.details()}") from exc return [ SystemInfo( id=s.id, @@ -64,12 +76,13 @@ async def get_energy_metrics( end: _datetime, ) -> list[SpaceEnergyMetrics]: """Fetch hourly energy metrics for all spaces in a time range.""" + logger.debug("Fetching energy metrics for system %s", system_id) start_ts = Timestamp() start_ts.FromSeconds(int(start.timestamp())) end_ts = Timestamp() end_ts.FromSeconds(int(end.timestamp())) - try: + async with grpc_call("GetEnergyMetrics"): result = await self._stub.GetEnergyMetrics( svc.GetEnergyMetricsRequest( system_id=system_id, @@ -78,8 +91,6 @@ async def get_energy_metrics( preferred_resolution=svc.TIME_RESOLUTION_HOURLY, ) ) - except grpc.aio.AioRpcError as exc: - raise QuiltError(f"GetEnergyMetrics failed: {exc.details()}") from exc metrics = [] for sm in result.space_energy_metrics: @@ -87,7 +98,7 @@ async def get_energy_metrics( EnergyBucket( start_time=b.start_time.ToDatetime(tzinfo=datetime.UTC), energy_kwh=b.energy_kwh, - status=b.status, + status=_safe_bucket_status(b.status), ) for b in sm.energy_buckets ] diff --git a/src/quilt_hp/services/user.py b/src/quilt_hp/services/user.py index 887eb94..d95c25d 100644 --- a/src/quilt_hp/services/user.py +++ b/src/quilt_hp/services/user.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from collections.abc import Callable from dataclasses import dataclass from enum import IntEnum @@ -13,6 +14,8 @@ from quilt_hp._proto import quilt_services_pb2_grpc as svc_grpc from quilt_hp.exceptions import QuiltError +logger = logging.getLogger(__name__) + class DeclaredUserType(IntEnum): """Declared user type used by UserAttributes.""" @@ -75,6 +78,7 @@ def __init__(self, channel: grpc.aio.Channel) -> None: async def get_current_user(self) -> User: """Get the currently authenticated user.""" + logger.debug("Getting current user") try: response = cast( "Any", @@ -92,6 +96,7 @@ async def update_current_user( phone_number: str | None = None, ) -> User: """Update first/last name and optional phone number for current user.""" + logger.debug("Updating current user") try: response = cast( "Any", @@ -109,6 +114,7 @@ async def update_current_user( async def get_user_attributes(self) -> UserAttributes: """Get current user's additional attributes.""" + logger.debug("Getting user attributes") try: response = cast( "Any", @@ -124,6 +130,7 @@ async def patch_user_attributes( declared_user_type: DeclaredUserType, ) -> UserAttributes: """Patch user attributes for the current user.""" + logger.debug("Patching user attributes") try: response = cast( "Any", diff --git a/src/quilt_hp/transport.py b/src/quilt_hp/transport.py index f6ce5e2..7adb1e8 100644 --- a/src/quilt_hp/transport.py +++ b/src/quilt_hp/transport.py @@ -3,6 +3,8 @@ from __future__ import annotations import inspect +import logging +import weakref from collections.abc import Awaitable, Callable from typing import cast @@ -23,6 +25,9 @@ ) type TokenProviderLike = Callable[[], str] | CurrentTokenProvider +logger = logging.getLogger(__name__) +_REFRESH_CALLBACK_HAS_PARAMS: weakref.WeakKeyDictionary[object, bool] = weakref.WeakKeyDictionary() + def _resolve_token_provider(token_provider: TokenProviderLike) -> Callable[[], str]: if callable(token_provider): @@ -34,11 +39,18 @@ async def _invoke_refresh_callback( refresh_callback: RefreshCallback, context: TokenRefreshContext ) -> None: try: - has_params = bool(inspect.signature(refresh_callback).parameters) + has_params = _REFRESH_CALLBACK_HAS_PARAMS.get(refresh_callback) except TypeError: - has_params = False - except ValueError: - has_params = False + has_params = None # non-weakrefable callable — skip cache + if has_params is None: + try: + has_params = bool(inspect.signature(refresh_callback).parameters) + except TypeError, ValueError: + has_params = False + try: + _REFRESH_CALLBACK_HAS_PARAMS[refresh_callback] = has_params + except TypeError: + pass # non-weakrefable callable — skip caching if has_params: await cast("Callable[[TokenRefreshContext], Awaitable[None]]", refresh_callback)(context) return @@ -66,6 +78,7 @@ def __init__( self._refresh_callback = refresh_callback def _metadata(self) -> list[tuple[str, str]]: + logger.debug("Attaching auth metadata") return [ ("authorization", self._token_provider()), ("x-quilt-app-version", APP_VERSION), @@ -125,6 +138,7 @@ async def intercept_unary_unary( return await continuation(self._patch(client_call_details), request) except grpc.aio.AioRpcError as exc: if exc.code() == grpc.StatusCode.UNAUTHENTICATED and self._refresh_callback: + logger.warning("Retrying unary RPC after UNAUTHENTICATED response") return await self._refresh_and_retry(continuation, client_call_details, request) raise @@ -141,6 +155,7 @@ async def intercept_unary_stream( return await continuation(self._patch(client_call_details), request) except grpc.aio.AioRpcError as exc: if exc.code() == grpc.StatusCode.UNAUTHENTICATED and self._refresh_callback: + logger.warning("Retrying streaming RPC setup after UNAUTHENTICATED response") return await self._refresh_and_retry(continuation, client_call_details, request) raise @@ -184,6 +199,7 @@ def create_channel( An async gRPC channel with TLS and auth interceptor. """ host = grpc_host(environment) + logger.debug("Creating gRPC channel for host %s", host) creds = grpc.ssl_channel_credentials() interceptors = [_AuthInterceptor(token_provider, refresh_callback)] return grpc.aio.secure_channel( @@ -200,6 +216,7 @@ def auth_metadata(token_provider: TokenProviderLike) -> list[tuple[str, str]]: Useful for stream-stream RPCs where the channel interceptor may not fire. """ resolved_provider = _resolve_token_provider(token_provider) + logger.debug("Building auth metadata") return [ ("authorization", resolved_provider()), ("x-quilt-app-version", APP_VERSION), diff --git a/tests/conftest.py b/tests/conftest.py index d21d717..252b617 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,87 @@ """Shared test fixtures.""" from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from quilt_hp.models.enums import ( + ComfortSettingOverride, + ComfortSettingType, + HVACMode, + HVACState, + OccupancyMode, + SafetyHeatingMode, +) +from quilt_hp.models.space import Space, SpaceControls, SpaceSettings, SpaceState +from quilt_hp.models.system import Location, SystemSnapshot + + +def _ns(**kwargs: object) -> SimpleNamespace: + """Build a SimpleNamespace recursively from keyword args.""" + return SimpleNamespace(**kwargs) + + +def _make_header(object_id: str = "obj-1", system_id: str = "sys-1") -> SimpleNamespace: + return _ns(object_id=object_id, system_id=system_id) + + +@pytest.fixture +def fake_space() -> Space: + return Space( + id="space-1", + system_id="sys-1", + name="Living Room", + parent_space_id="root-1", + settings=SpaceSettings( + name="Living Room", + timezone="UTC", + occupancy_mode=OccupancyMode.ENABLED, + occupied_timeout_s=180.0, + unoccupied_timeout_s=1200.0, + safety_heating=SafetyHeatingMode.ENABLED, + ), + controls=SpaceControls( + hvac_mode=HVACMode.HEAT, + temperature_setpoint_c=20.0, + cooling_setpoint_c=24.0, + heating_setpoint_c=20.0, + comfort_setting_id="comfort-1", + comfort_setting_override=ComfortSettingOverride.NONE, + ), + state=SpaceState( + ambient_temperature_c=21.0, + hvac_state=HVACState.HEAT, + setpoint_c=20.0, + comfort_setting_id="comfort-1", + ), + active_comfort_setting_type=ComfortSettingType.ACTIVE, + ) + + +@pytest.fixture +def fake_snapshot(fake_space: Space) -> SystemSnapshot: + return SystemSnapshot( + spaces=[fake_space], + indoor_units=[], + outdoor_units=[], + controllers=[], + quilt_smart_modules=[], + comfort_settings=[], + schedule_weeks=[], + schedule_days=[], + remote_sensors=[], + controller_remote_sensors=[], + software_update_infos=[], + locations=[ + Location( + id="loc-1", + name="Home", + system_id="sys-1", + timezone="UTC", + schedule_paused=False, + ) + ], + timezone="UTC", + ) diff --git a/tests/test_auth_store_settings_edges.py b/tests/test_auth_store_settings_edges.py index fa32376..a4105ff 100644 --- a/tests/test_auth_store_settings_edges.py +++ b/tests/test_auth_store_settings_edges.py @@ -118,5 +118,5 @@ def test_settings_store_corruption_and_schema_edges(tmp_path: Path) -> None: settings = store.load() assert settings.email is None assert settings.home is None - assert settings.use_fahrenheit is True + assert settings.use_fahrenheit is False assert settings.dark is None diff --git a/tests/test_cli_login.py b/tests/test_cli_login.py index 46d3956..7403270 100644 --- a/tests/test_cli_login.py +++ b/tests/test_cli_login.py @@ -8,6 +8,7 @@ from unittest.mock import patch from quilt_hp.cli import main as cli_main +from quilt_hp.exceptions import QuiltAuthError class _FakeClient: @@ -27,7 +28,7 @@ async def __aexit__(self, *_args: object) -> None: async def login(self, otp_callback: Callable[[str], object] | None = None) -> None: if otp_callback is None: _FakeClient.events.append("silent-login") - raise RuntimeError("need OTP") + raise QuiltAuthError("need OTP") _FakeClient.events.append("otp-login-start") otp = otp_callback("user@example.com") diff --git a/tests/test_cli_surfaces_extra.py b/tests/test_cli_surfaces_extra.py index 54d5154..d229549 100644 --- a/tests/test_cli_surfaces_extra.py +++ b/tests/test_cli_surfaces_extra.py @@ -7,6 +7,7 @@ from typer.testing import CliRunner from quilt_hp.cli import main as cli_main +from quilt_hp.exceptions import QuiltAuthError, QuiltError from quilt_hp.models.enums import HVACMode runner = CliRunner() @@ -195,3 +196,33 @@ async def get_snapshot(self) -> SimpleNamespace: assert result.exit_code == 0 assert "DAY" in result.stdout + + +def test_info_command_handles_auth_errors() -> None: + class _AuthErrorClient(_FakeClient): + async def login(self) -> None: + raise QuiltAuthError("bad credentials") + + with ( + patch.object(cli_main, "_resolve", return_value=("user@example.com", None)), + patch.object(cli_main, "QuiltClient", _AuthErrorClient), + ): + result = runner.invoke(cli_main.app, ["info"]) + + assert result.exit_code == 1 + assert "Authentication failed: bad credentials" in result.stdout + + +def test_info_command_handles_quilt_errors() -> None: + class _QuiltErrorClient(_FakeClient): + async def get_snapshot(self) -> SimpleNamespace: + raise QuiltError("snapshot unavailable") + + with ( + patch.object(cli_main, "_resolve", return_value=("user@example.com", None)), + patch.object(cli_main, "QuiltClient", _QuiltErrorClient), + ): + result = runner.invoke(cli_main.app, ["info"]) + + assert result.exit_code == 1 + assert "Error: snapshot unavailable" in result.stdout diff --git a/tests/test_client_service_error_paths.py b/tests/test_client_service_error_paths.py index c6b55ea..81398fb 100644 --- a/tests/test_client_service_error_paths.py +++ b/tests/test_client_service_error_paths.py @@ -112,6 +112,35 @@ async def _fake_authenticate(*_args: object, **kwargs: object) -> str: assert fake_channel.close.await_count == 1 +@pytest.mark.asyncio +async def test_client_requires_login_and_close_clears_services() -> None: + client = QuiltClient("user@example.com") + + with pytest.raises(QuiltError, match=r"Client not connected\. Call login\(\) first\."): + await client.list_systems() + with pytest.raises(QuiltError, match=r"Client not connected\. Call login\(\) first\."): + await client.get_snapshot() + with pytest.raises(QuiltError, match=r"Client not connected\. Call login\(\) first\."): + await client.get_current_user() + with pytest.raises(QuiltError, match=r"Client not connected\. Call login\(\) first\."): + client.stream(["hds/space/space-1"]) + + fake_channel = MagicMock() + fake_channel.close = AsyncMock() + client._channel = fake_channel + client._hds = MagicMock() + client._sysinfo = MagicMock() + client._user_svc = MagicMock() + + await client.close() + + fake_channel.close.assert_awaited_once() + assert client._channel is None + assert client._hds is None + assert client._sysinfo is None + assert client._user_svc is None + + @pytest.mark.asyncio async def test_client_wrapper_methods_and_context_manager(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("quilt_hp.client.authenticate", AsyncMock(return_value="jwt-token")) @@ -172,6 +201,11 @@ def _fake_create_channel(*_args: object, **_kwargs: object) -> object: async with client: assert client.get_current_token() == "jwt-token" + assert client._channel is None + assert client._hds is None + assert client._sysinfo is None + assert client._user_svc is None + @pytest.mark.asyncio async def test_system_service_success_and_error_paths(monkeypatch: pytest.MonkeyPatch) -> None: @@ -214,6 +248,7 @@ async def test_system_service_success_and_error_paths(monkeypatch: pytest.Monkey ) assert metrics[0].space_id == "space-1" assert len(metrics[0].buckets) == 2 + assert metrics[0].buckets[0].status == system_service.MetricBucketStatus.COMPLETE err_stub = MagicMock( ListSystems=AsyncMock(side_effect=_FakeRpcError(grpc.StatusCode.UNKNOWN, "x")) diff --git a/tests/test_grpc_retry.py b/tests/test_grpc_retry.py new file mode 100644 index 0000000..538aad2 --- /dev/null +++ b/tests/test_grpc_retry.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import logging + +import grpc +import pytest + +from quilt_hp.exceptions import QuiltConnectionError, QuiltError +from quilt_hp.services import grpc_call + + +class _FakeRpcError(grpc.aio.AioRpcError): + def __init__(self, code: grpc.StatusCode, details: str = "") -> None: + self._code = code + self._details = details + + def code(self) -> grpc.StatusCode: # type: ignore[override] + return self._code + + def details(self) -> str: # type: ignore[override] + return self._details + + +@pytest.mark.asyncio +async def test_grpc_call_translates_transient_errors_without_retries() -> None: + with pytest.raises(QuiltConnectionError, match="ListSystems failed: down"): + async with grpc_call("ListSystems"): + raise _FakeRpcError(grpc.StatusCode.UNAVAILABLE, "down") + + +@pytest.mark.asyncio +async def test_grpc_call_retries_transient_errors( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + calls = 0 + sleep_calls: list[float] = [] + + async def _flaky(_request: object) -> str: + nonlocal calls + calls += 1 + if calls < 3: + raise _FakeRpcError(grpc.StatusCode.UNAVAILABLE, "down") + return "ok" + + async def _fake_sleep(delay: float) -> None: + sleep_calls.append(delay) + + monkeypatch.setattr("quilt_hp.services.asyncio.sleep", _fake_sleep) + + with caplog.at_level(logging.WARNING): + async with grpc_call( + "ListSystems", max_retries=2, retry_delay=0.5, retry_backoff=3.0 + ) as call: + result = await call(_flaky, object()) + + assert result == "ok" + assert calls == 3 + assert sleep_calls == [0.5, 1.5] + assert "retrying in 0.5s (1/2)" in caplog.text + assert "retrying in 1.5s (2/2)" in caplog.text + + +@pytest.mark.asyncio +async def test_grpc_call_stops_after_max_retries(monkeypatch: pytest.MonkeyPatch) -> None: + calls = 0 + sleep_calls: list[float] = [] + + async def _always_fails() -> None: + nonlocal calls + calls += 1 + raise _FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED, "timeout") + + async def _fake_sleep(delay: float) -> None: + sleep_calls.append(delay) + + monkeypatch.setattr("quilt_hp.services.asyncio.sleep", _fake_sleep) + + with pytest.raises(QuiltConnectionError, match="GetEnergyMetrics failed: timeout"): + async with grpc_call("GetEnergyMetrics", max_retries=1) as call: + await call(_always_fails) + + assert calls == 2 + assert sleep_calls == [1.0] + + +@pytest.mark.asyncio +async def test_grpc_call_does_not_retry_non_transient(monkeypatch: pytest.MonkeyPatch) -> None: + sleep_calls: list[float] = [] + + async def _unknown() -> None: + raise _FakeRpcError(grpc.StatusCode.UNKNOWN, "boom") + + async def _fake_sleep(delay: float) -> None: + sleep_calls.append(delay) + + monkeypatch.setattr("quilt_hp.services.asyncio.sleep", _fake_sleep) + + with pytest.raises(QuiltError, match="UpdateSpace failed: boom"): + async with grpc_call("UpdateSpace", max_retries=3) as call: + await call(_unknown) + + assert sleep_calls == [] + + +@pytest.mark.asyncio +async def test_grpc_call_preserves_existing_quilt_errors() -> None: + with pytest.raises(QuiltError, match="already wrapped"): + async with grpc_call("UpdateSpace", max_retries=3) as call: + await call(_raise_wrapped) + + +async def _raise_wrapped() -> None: + raise QuiltError("already wrapped") diff --git a/tests/test_hds_payloads.py b/tests/test_hds_payloads.py new file mode 100644 index 0000000..9cce10a --- /dev/null +++ b/tests/test_hds_payloads.py @@ -0,0 +1,124 @@ +"""Payload-shaping tests for HomeDatastoreService update_space.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from google.protobuf.timestamp_pb2 import Timestamp + +from quilt_hp._proto import quilt_hds_pb2 as hds +from quilt_hp.models.enums import HVACMode +from quilt_hp.models.space import Space +from quilt_hp.models.system import SystemSnapshot +from quilt_hp.services import hds as hds_service + + +def _fixed_timestamp() -> Timestamp: + ts = Timestamp() + ts.FromSeconds(123) + return ts + + +async def _capture_update_space_diff( + monkeypatch: pytest.MonkeyPatch, + space: Space, + **kwargs: object, +) -> hds.Space: + captured: dict[str, hds.UpdateSpaceRequest] = {} + + async def _update_space(request: hds.UpdateSpaceRequest) -> hds.Space: + captured["request"] = request + return request.diff + + class _Stub: + def __init__(self) -> None: + self.UpdateSpace = _update_space + + monkeypatch.setattr(hds_service.hds_grpc, "HomeDatastoreServiceStub", lambda _ch: _Stub()) + monkeypatch.setattr(hds_service, "_now_ts", _fixed_timestamp) + monkeypatch.setattr(hds_service.Space, "from_proto", lambda proto: proto) + + service = hds_service.HomeDatastoreService(MagicMock()) + diff = await service.update_space(space, **kwargs) + return captured["request"].diff if isinstance(diff, hds.Space) else diff + + +@pytest.mark.parametrize( + ("mode", "expected_temp", "expected_override", "expected_comfort_setting_id"), + [ + ( + HVACMode.HEAT, + 20.0, + hds.COMFORT_SETTING_OVERRIDE_UNTIL_NEXT_SCHEDULE, + "comfort-1", + ), + ( + HVACMode.COOL, + 24.0, + hds.COMFORT_SETTING_OVERRIDE_UNTIL_NEXT_SCHEDULE, + "comfort-1", + ), + ( + HVACMode.AUTO, + 24.0, + hds.COMFORT_SETTING_OVERRIDE_UNTIL_NEXT_SCHEDULE, + "comfort-1", + ), + (HVACMode.STANDBY, 24.0, hds.COMFORT_SETTING_OVERRIDE_NONE, ""), + ], +) +async def test_update_space_builds_expected_payload_for_modes( + monkeypatch: pytest.MonkeyPatch, + fake_snapshot: SystemSnapshot, + mode: HVACMode, + expected_temp: float, + expected_override: int, + expected_comfort_setting_id: str, +) -> None: + diff = await _capture_update_space_diff( + monkeypatch, + fake_snapshot.spaces[0], + mode=mode, + ) + + assert diff.header.object_id == "space-1" + assert diff.header.system_id == "sys-1" + assert diff.controls.hvac_mode == mode.value + assert diff.controls.temperature_setpoint_c == pytest.approx(expected_temp) + assert diff.controls.heating_temperature_setpoint_c == pytest.approx(20.0) + assert diff.controls.cooling_temperature_setpoint_c == pytest.approx(24.0) + assert diff.controls.comfort_setting_override == expected_override + assert diff.controls.comfort_setting_id_string == expected_comfort_setting_id + assert diff.controls.updated_ts.seconds == 123 + + +async def test_update_space_auto_deadband_clamps_cooling_setpoint( + monkeypatch: pytest.MonkeyPatch, + fake_snapshot: SystemSnapshot, +) -> None: + diff = await _capture_update_space_diff( + monkeypatch, + fake_snapshot.spaces[0], + mode=HVACMode.AUTO, + heat_setpoint_c=21.0, + cool_setpoint_c=22.0, + ) + + assert diff.controls.heating_temperature_setpoint_c == pytest.approx(21.0) + assert diff.controls.cooling_temperature_setpoint_c == pytest.approx(23.5) + assert diff.controls.temperature_setpoint_c == pytest.approx(23.5) + + +async def test_update_space_standby_clears_comfort_setting( + monkeypatch: pytest.MonkeyPatch, + fake_snapshot: SystemSnapshot, +) -> None: + diff = await _capture_update_space_diff( + monkeypatch, + fake_snapshot.spaces[0], + mode=HVACMode.STANDBY, + ) + + assert diff.controls.comfort_setting_id_string == "" + assert diff.controls.comfort_setting_override == hds.COMFORT_SETTING_OVERRIDE_NONE diff --git a/tests/test_models_extra.py b/tests/test_models_extra.py new file mode 100644 index 0000000..30c818d --- /dev/null +++ b/tests/test_models_extra.py @@ -0,0 +1,285 @@ +"""Additional model conversion coverage.""" + +from __future__ import annotations + +import math +from datetime import UTC, datetime + +import pytest + +from quilt_hp.models.comfort import ComfortSetting +from quilt_hp.models.controller import Controller +from quilt_hp.models.energy import EnergyBucket, SpaceEnergyMetrics +from quilt_hp.models.enums import ( + ComfortSettingType, + FanSpeed, + HVACMode, + LouverMode, + RemoteSensorControlMode, +) +from quilt_hp.models.outdoor_unit import OutdoorUnit +from quilt_hp.models.sensor import ControllerRemoteSensor, RemoteSensor +from quilt_hp.models.software_update import SoftwareUpdateInfo +from tests.conftest import _make_header, _ns + + +@pytest.mark.parametrize( + ("state", "status", "current_version", "target_version", "current", "total", "unit"), + [ + (0, 0, "", "", 0.0, 0.0, 0), + (2, 3, "1.0.0", "1.1.0", 45.0, 100.0, 1), + ], +) +def test_software_update_info_from_proto( + state: int, + status: int, + current_version: str, + target_version: str, + current: float, + total: float, + unit: int, +) -> None: + proto = _ns( + header=_make_header("update-1"), + attributes=_ns( + state=state, + status=status, + current_version=current_version, + target_version=target_version, + current_progress=current, + total_progress=total, + progress_unit=unit, + ), + ) + + info = SoftwareUpdateInfo.from_proto(proto) + + assert info.id == "update-1" + assert info.state == state + assert info.status == status + assert info.current_version == current_version + assert info.target_version == target_version + assert info.current_progress == current + assert info.total_progress == total + assert info.progress_unit == unit + + +def test_energy_bucket_is_valid_and_total_kwh() -> None: + now = datetime.now(UTC) + valid_bucket = EnergyBucket(start_time=now, energy_kwh=1.25, status=1) + zero_bucket = EnergyBucket(start_time=now, energy_kwh=0.0, status=1) + invalid_bucket = EnergyBucket(start_time=now, energy_kwh=math.nan, status=2) + metrics = SpaceEnergyMetrics( + space_id="space-1", + buckets=[valid_bucket, zero_bucket, invalid_bucket], + ) + + assert valid_bucket.is_valid is True + assert zero_bucket.is_valid is True + assert invalid_bucket.is_valid is False + assert metrics.total_kwh == pytest.approx(1.25) + + +@pytest.mark.parametrize( + ("comfort_type", "hvac_mode", "fan_mode", "fan_percent", "louver_mode", "expected_fan"), + [ + (ComfortSettingType.ACTIVE, HVACMode.HEAT, 1, 0.0, LouverMode.AUTO, FanSpeed.AUTO), + (ComfortSettingType.AWAY, HVACMode.COOL, 2, 0.60, LouverMode.SWEEP, FanSpeed.MEDIUM), + (ComfortSettingType.CUSTOM, HVACMode.AUTO, 2, 0.80, LouverMode.FIXED, FanSpeed.HIGH), + ], +) +def test_comfort_setting_from_proto_different_types( + comfort_type: ComfortSettingType, + hvac_mode: HVACMode, + fan_mode: int, + fan_percent: float, + louver_mode: LouverMode, + expected_fan: FanSpeed, +) -> None: + proto = _ns( + header=_make_header("comfort-1"), + relationships=_ns(space_id="space-1"), + attributes=_ns( + name=comfort_type.name.title(), + type=comfort_type, + hvac_mode=hvac_mode, + heating_temperature_setpoint_c=20.0, + cooling_temperature_setpoint_c=25.0, + fan_speed_mode=fan_mode, + fan_speed_percent=fan_percent, + louver_mode=louver_mode, + louver_fixed_position=0.4, + ), + ) + + setting = ComfortSetting.from_proto(proto) + + assert setting.id == "comfort-1" + assert setting.space_id == "space-1" + assert setting.type == comfort_type + assert setting.hvac_mode == hvac_mode + assert setting.fan_speed == expected_fan + assert setting.louver_mode == louver_mode + + +def test_controller_from_proto_includes_wifi_remote_sensor_and_hardware() -> None: + now = int(datetime.now(tz=UTC).timestamp()) + proto = _ns( + header=_make_header("ctrl-1"), + relationships=_ns( + space_id="space-1", + hardware_id="controllers/HW-1", + software_update_info_id="sw-1", + firmware_update_info_id="fw-1", + ), + settings=_ns(name="Hall Dial"), + state=_ns( + updated_ts=_ns(seconds=now), + ambient_temperature_c=22.1, + temperature_f3=34.5, + temperature_f4=48.0, + temperature_f5=21.7, + ), + hosted_wifi_state=_ns( + ssid="HomeNet", + ipv4_address="192.168.1.10", + signal_level_dbm=-58, + frequency_mhz=2437, + updated_ts=_ns(seconds=now), + ), + ap_wifi_state=_ns( + ssid="Dial-AP", + ipv4_address="192.168.4.1", + signal_level_dbm=-30, + updated_ts=_ns(seconds=now), + ), + p2p_wifi_state=_ns( + ssid="Dial-Direct", + ipv4_address="169.254.1.1", + signal_level_dbm=-40, + updated_ts=_ns(seconds=now), + ), + controls=_ns(remote_sensor_control_mode=RemoteSensorControlMode.ENABLED), + ) + hw_map = { + "hw-1": _ns( + attributes=_ns( + serial_number="SN123", + model_sku="DIAL-V1", + firmware_version="9.9.9", + ) + ) + } + + controller = Controller.from_proto(proto, hw_map=hw_map) + + assert controller.name == "Hall Dial" + assert controller.wifi_ssid == "HomeNet" + assert controller.wifi_ip == "192.168.1.10" + assert controller.wifi_signal_dbm == -58 + assert controller.wifi_band == "2.4 GHz" + assert controller.ap_wifi is not None + assert controller.ap_wifi.ssid == "Dial-AP" + assert controller.p2p_wifi is not None + assert controller.p2p_wifi.ssid == "Dial-Direct" + assert controller.remote_sensor_mode == RemoteSensorControlMode.ENABLED + assert controller.software_update_info_id == "sw-1" + assert controller.firmware_update_info_id == "fw-1" + assert controller.serial_number == "SN123" + assert controller.model_sku == "DIAL-V1" + assert controller.firmware_version == "9.9.9" + assert controller.state_updated_at is not None + assert controller.wifi_last_seen is not None + + +@pytest.mark.parametrize( + ("ambient", "compressor", "energy", "high_pressure", "low_pressure"), + [ + (19.5, 55.0, 7200.0, 2450.0, 780.0), + (0.0, 0.0, 0.0, 0.0, 0.0), + ], +) +def test_outdoor_unit_from_proto_with_performance_data( + ambient: float, + compressor: float, + energy: float, + high_pressure: float, + low_pressure: float, +) -> None: + proto = _ns( + header=_make_header("odu-1"), + relationships=_ns( + space_id="space-1", + hardware_id="outdoor/HW-ODU-1", + firmware_update_info_id="fw-odu-1", + ), + state=_ns(hvac_state=HVACMode.COOL), + performance_data=_ns( + measurement_interval_s=5.0, + energy_measurement_j=energy, + compressor_frequency_hz=compressor, + ambient_temperature_c=ambient, + coil_temperature_c=8.0, + exhaust_temperature_c=42.0, + high_pressure_kpa=high_pressure, + low_pressure_kpa=low_pressure, + ), + ) + hw_map = { + "hw-odu-1": _ns( + attributes=_ns( + model_sku="ODU-24K", + serial_number="ODU123", + firmware_version="3.2.1", + ) + ) + } + + outdoor_unit = OutdoorUnit.from_proto(proto, hw_map=hw_map) + + assert outdoor_unit.model_sku == "ODU-24K" + assert outdoor_unit.serial_number == "ODU123" + assert outdoor_unit.firmware_version == "3.2.1" + assert outdoor_unit.firmware_update_info_id == "fw-odu-1" + assert outdoor_unit.performance_data is not None + assert outdoor_unit.performance_data.ambient_temperature_c == ambient + assert outdoor_unit.performance_data.compressor_frequency_hz == compressor + assert outdoor_unit.performance_data.energy_measurement_j == energy + assert outdoor_unit.performance_data.high_pressure_kpa == high_pressure + assert outdoor_unit.performance_data.low_pressure_kpa == low_pressure + + +@pytest.mark.parametrize( + ("model_cls", "relationship_field", "relationship_value"), + [ + (RemoteSensor, "indoor_unit_id", "idu-1"), + (ControllerRemoteSensor, "controller_id", "ctrl-1"), + ], +) +def test_remote_sensor_models_from_proto( + model_cls: type[RemoteSensor] | type[ControllerRemoteSensor], + relationship_field: str, + relationship_value: str, +) -> None: + proto = _ns( + header=_make_header("sensor-1"), + relationships=_ns(**{relationship_field: relationship_value}), + attributes=_ns(mac=""), + controls=_ns(control_mode=RemoteSensorControlMode.DISABLED), + state=_ns( + ambient_temperature_c=0.0, + humidity_percent=47.5, + battery_level_percent=91.0, + signal_level_dbm=0, + ), + ) + + sensor = model_cls.from_proto(proto) + + assert getattr(sensor, relationship_field) == relationship_value + assert sensor.mac is None + assert sensor.ambient_temperature_c == 0.0 + assert sensor.humidity_percent == 47.5 + assert sensor.battery_level_percent == 91.0 + assert sensor.signal_level_dbm == 0 + assert sensor.control_mode == RemoteSensorControlMode.DISABLED diff --git a/tests/test_models_from_proto.py b/tests/test_models_from_proto.py index bb6d19f..393690b 100644 --- a/tests/test_models_from_proto.py +++ b/tests/test_models_from_proto.py @@ -27,24 +27,17 @@ SafetyHeatingMode, ) from quilt_hp.models.indoor_unit import IndoorUnit -from quilt_hp.models.qsm import QuiltSmartModule +from quilt_hp.models.outdoor_unit import OutdoorUnit +from quilt_hp.models.qsm import QuiltSmartModule, WifiInfo from quilt_hp.models.schedule import ScheduleDay, ScheduleEvent, ScheduleWeek from quilt_hp.models.sensor import RemoteSensor from quilt_hp.models.space import Space, SpaceControls, SpaceSettings from quilt_hp.models.system import Location, SystemSnapshot +from tests.conftest import _make_header, _ns # ─── helpers ──────────────────────────────────────────────────────────────── -def _ns(**kwargs: object) -> SimpleNamespace: - """Build a SimpleNamespace recursively from keyword args.""" - return SimpleNamespace(**kwargs) - - -def _make_header(object_id: str = "obj-1", system_id: str = "sys-1") -> SimpleNamespace: - return _ns(object_id=object_id, system_id=system_id) - - # ─── Space ────────────────────────────────────────────────────────────────── @@ -264,6 +257,18 @@ def test_display_setpoint_standby() -> None: assert c.display_setpoint == "--" +def test_display_setpoint_zero_value_is_preserved() -> None: + c = SpaceControls( + hvac_mode=HVACMode.HEAT, + temperature_setpoint_c=0.0, + cooling_setpoint_c=26.0, + heating_setpoint_c=0.0, + comfort_setting_id="", + comfort_setting_override=0, + ) + assert c.display_setpoint == "0.0°C" + + def test_space_controls_comfort_setting_id_sentinel() -> None: c = SpaceControls( hvac_mode=HVACMode.COOL, @@ -714,12 +719,55 @@ def test_controller_no_wifi() -> None: ctrl = Controller.from_proto(proto) assert ctrl.wifi_ssid is None assert ctrl.wifi_ip is None - assert ctrl.wifi_signal_dbm is None + assert ctrl.wifi_signal_dbm == 0 assert ctrl.wifi_band is None assert ctrl.wifi_last_seen is None assert ctrl.is_online # seconds=0 → no timestamp → unknown → assume online (fail-open) +def test_controller_wifi_signal_zero_is_preserved() -> None: + proto = _ns( + header=_make_header("ctrl-3"), + relationships=_ns( + space_id="space-1", + software_update_info_id="", + firmware_update_info_id="", + ), + settings=_ns(name="Dial"), + state=_ns( + updated_ts=_ns(seconds=0), + ambient_temperature_c=20.0, + temperature_f3=33.0, + temperature_f4=47.0, + temperature_f5=20.0, + ), + hosted_wifi_state=_ns( + ssid="MyNet", + ipv4_address="192.168.1.42", + signal_level_dbm=0, + frequency_mhz=2412, + updated_ts=_ns(seconds=0), + ), + ap_wifi_state=_ns( + ssid="", + ipv4_address="", + signal_level_dbm=0, + frequency_mhz=0, + updated_ts=_ns(seconds=0), + ), + p2p_wifi_state=_ns( + ssid="", + ipv4_address="", + signal_level_dbm=0, + frequency_mhz=0, + updated_ts=_ns(seconds=0), + ), + controls=_ns(remote_sensor_control_mode=0), + ) + ctrl = Controller.from_proto(proto) + assert ctrl.wifi_signal_dbm == 0 + + # ─── QuiltSmartModule ──────────────────────────────────────────────────────── @@ -785,6 +833,14 @@ def test_qsm_from_proto_no_sensors() -> None: assert qsm.hosted_wifi is None +def test_wifi_info_zero_signal_is_preserved() -> None: + info = WifiInfo.from_proto( + _ns(ssid="HomeNet", ipv4_address="192.168.1.50", signal_level_dbm=0) + ) + assert info.signal_dbm == 0 + assert info.connected is True + + # ─── RemoteSensor ──────────────────────────────────────────────────────────── @@ -810,7 +866,7 @@ def test_remote_sensor_from_proto() -> None: assert rs.control_mode == RemoteSensorControlMode.ENABLED -def test_remote_sensor_missing_fields() -> None: +def test_remote_sensor_zero_values_are_preserved() -> None: proto = _ns( header=_make_header("rs-2"), relationships=_ns(indoor_unit_id="idu-1"), @@ -825,9 +881,10 @@ def test_remote_sensor_missing_fields() -> None: ) rs = RemoteSensor.from_proto(proto) assert rs.mac is None - assert rs.ambient_temperature_c is None - assert rs.battery_level_percent is None - assert rs.signal_level_dbm is None + assert rs.ambient_temperature_c == 0.0 + assert rs.humidity_percent == 0.0 + assert rs.battery_level_percent == 0.0 + assert rs.signal_level_dbm == 0 # ─── ScheduleDay / ScheduleWeek ───────────────────────────────────────────── @@ -858,6 +915,7 @@ def test_schedule_day_from_proto_sorted() -> None: day = ScheduleDay.from_proto(proto) assert day.id == "day-1" assert day.name == "Weekday" + assert all(isinstance(ev.hvac_mode, HVACMode) for ev in day.events) times = [ev.start_time for ev in day.events] assert times == ["07:00", "09:00", "18:00"] @@ -915,6 +973,12 @@ def test_energy_bucket_nan_sentinel_handling() -> None: assert metrics.total_kwh == 1.25 +def test_energy_bucket_none_is_missing() -> None: + bucket = EnergyBucket(start_time=datetime.now(UTC), energy_kwh=None, status=0) # type: ignore[arg-type] + assert bucket.has_missing_energy_value is True + assert bucket.energy_kwh_or_none is None + + # ─── Location ──────────────────────────────────────────────────────────────── @@ -1416,13 +1480,27 @@ def _make_odu_proto(odu_id: str, space_id: str = "space-1") -> SimpleNamespace: relationships=_ns(space_id=space_id, hardware_id="", firmware_update_info_id=""), state=_ns(hvac_state=0), performance_data=_ns( + measurement_interval_s=0.0, ambient_temperature_c=0.0, compressor_frequency_hz=0.0, energy_measurement_j=0.0, + coil_temperature_c=0.0, + exhaust_temperature_c=0.0, + high_pressure_kpa=0.0, + low_pressure_kpa=0.0, ), ) +def test_outdoor_unit_zero_performance_values_are_preserved() -> None: + proto = _make_odu_proto("odu-1") + odu = OutdoorUnit.from_proto(proto) + assert odu.performance_data is not None + assert odu.performance_data.ambient_temperature_c == 0.0 + assert odu.performance_data.compressor_frequency_hz == 0.0 + assert odu.performance_data.energy_measurement_j == 0.0 + + def _make_snap_with_multiple_odus() -> SystemSnapshot: from quilt_hp._proto import quilt_hds_pb2 as hds diff --git a/tests/test_streaming.py b/tests/test_streaming.py index cea20fe..c4d4871 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -159,6 +159,23 @@ async def test_unsubscribe_removes_topics() -> None: assert "topic-b" in stream._topics +@pytest.mark.asyncio +async def test_subscribe_after_queue_reset_resubscribes_from_topics() -> None: + stream = _make_stream(["topic-a"]) + stream._request_queue = asyncio.Queue() + + await stream.subscribe(["topic-b"]) + + stream._running = True + request_iterator = stream._request_iterator(list(stream._topics), stream._request_queue) + initial = await anext(request_iterator) + queued = await anext(request_iterator) + await request_iterator.aclose() + + assert [sub.topic for sub in initial.append.subscriptions] == ["topic-a", "topic-b"] + assert [sub.topic for sub in queued.append.subscriptions] == ["topic-b"] + + # ─── lifecycle ─────────────────────────────────────────────────────────────── diff --git a/tests/test_streaming_concurrency.py b/tests/test_streaming_concurrency.py new file mode 100644 index 0000000..7d4d93f --- /dev/null +++ b/tests/test_streaming_concurrency.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +import grpc +import pytest + +from quilt_hp.services.streaming import NotifierStream + + +class _FakeRpcError(grpc.aio.AioRpcError): + def __init__(self, code: grpc.StatusCode, details: str = "") -> None: + self._code = code + self._details = details + + def code(self) -> grpc.StatusCode: # type: ignore[override] + return self._code + + def details(self) -> str: # type: ignore[override] + return self._details + + +class _BlockingCall: + def __init__(self) -> None: + self.started = asyncio.Event() + self._cancelled = asyncio.Event() + + def cancel(self) -> None: + self._cancelled.set() + + def __aiter__(self) -> _BlockingCall: + return self + + async def __anext__(self) -> object: + self.started.set() + await self._cancelled.wait() + raise StopAsyncIteration + + +def _make_stream(topics: list[str] | None = None) -> NotifierStream: + with patch("quilt_hp.services.streaming.notifier_grpc.NotifierServiceStub"): + return NotifierStream.create(MagicMock(), topics or ["topic-a"]) + + +@pytest.mark.asyncio +async def test_concurrent_start_stop_calls_do_not_deadlock_or_raise() -> None: + stream = _make_stream() + + async def _wait_for_stop() -> None: + await stream._stop_event.wait() + + stream._run_stream_with_reconnect = _wait_for_stop # type: ignore[method-assign] + + tasks = [ + asyncio.create_task(stream.start()), + asyncio.create_task(stream.start()), + asyncio.create_task(stream.stop()), + asyncio.create_task(stream.stop()), + ] + results = await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=1.0) + + assert not [result for result in results if isinstance(result, Exception)] + + await stream.stop() + assert stream._task is None + + +@pytest.mark.asyncio +async def test_subscribe_during_active_reconnect_keeps_topics() -> None: + stream = _make_stream(["topic-a"]) + stream._running = True + stream._max_reconnects = 1 + + attempts = 0 + reconnect_waiting = asyncio.Event() + continue_reconnect = asyncio.Event() + + async def _flaky() -> None: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise _FakeRpcError(grpc.StatusCode.UNAVAILABLE, "down") + stream._running = False + + async def _fake_wait_for_stop(_delay: float) -> bool: + reconnect_waiting.set() + await continue_reconnect.wait() + return False + + stream._wait_for_stop = _fake_wait_for_stop # type: ignore[method-assign] + stream._run_one_stream = _flaky # type: ignore[method-assign] + + task = asyncio.create_task(stream._run_stream_with_reconnect()) + await reconnect_waiting.wait() + await stream.subscribe(["topic-b"]) + continue_reconnect.set() + await task + + assert stream._topics == ["topic-a", "topic-b"] + + +@pytest.mark.asyncio +async def test_rapid_subscribe_and_unsubscribe_preserve_final_topics() -> None: + stream = _make_stream(["topic-a"]) + + async def _subscribe(topics: list[str], delay: float) -> None: + await asyncio.sleep(delay) + await stream.subscribe(topics) + + async def _unsubscribe(topics: list[str], delay: float) -> None: + await asyncio.sleep(delay) + await stream.unsubscribe(topics) + + await asyncio.gather( + _subscribe(["topic-b"], 0.0), + _unsubscribe(["topic-a"], 0.001), + _subscribe(["topic-c"], 0.002), + _unsubscribe(["topic-b"], 0.003), + _subscribe(["topic-d"], 0.004), + _unsubscribe(["topic-c"], 0.005), + ) + + assert set(stream._topics) == {"topic-d"} + + +@pytest.mark.asyncio +async def test_stop_during_active_run_forever_terminates_cleanly() -> None: + stream = _make_stream(["topic-a"]) + blocking_call = _BlockingCall() + stream._stub = MagicMock(Subscribe=lambda *_args, **_kwargs: blocking_call) + + run_task = asyncio.create_task(stream.run_forever()) + await blocking_call.started.wait() + + await stream.stop() + await asyncio.wait_for(run_task, timeout=1.0) + + assert stream._running is False + assert stream.stream_state == "stopped" diff --git a/tests/test_streaming_debounce.py b/tests/test_streaming_debounce.py new file mode 100644 index 0000000..f90c5d9 --- /dev/null +++ b/tests/test_streaming_debounce.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from types import SimpleNamespace as _ns +from unittest.mock import MagicMock, patch + +import pytest + +from quilt_hp.services.streaming import NotifierStream, StreamEvent + + +def _make_stream(*, debounce_s: float) -> NotifierStream: + with patch("quilt_hp.services.streaming.notifier_grpc.NotifierServiceStub"): + return NotifierStream.create( + MagicMock(), + ["hds/space/space-1"], + debounce_s=debounce_s, + ) + + +def _response() -> object: + return _ns(control_events=[], notifier_events=[object()]) + + +@pytest.mark.asyncio +async def test_debounce_zero_dispatches_immediately() -> None: + stream = _make_stream(debounce_s=0.0) + seen: list[int] = [] + stream.on_space_update(lambda space: seen.append(space.value)) + stream._parse_event = MagicMock( + return_value=StreamEvent(topic="topic", space=_ns(id="space-1", value=72)) + ) + + async def _iter() -> AsyncIterator[object]: + yield _response() + + stream._stub = MagicMock(Subscribe=lambda *_args, **_kwargs: _iter()) + + await stream._run_one_stream() + + assert seen == [72] + + +@pytest.mark.asyncio +async def test_debounce_coalesces_rapid_events() -> None: + stream = _make_stream(debounce_s=0.05) + seen: list[int] = [] + stream.on_space_update(lambda space: seen.append(space.value)) + stream._parse_event = MagicMock( + side_effect=[ + StreamEvent(topic="topic", space=_ns(id="space-1", value=70)), + StreamEvent(topic="topic", space=_ns(id="space-1", value=71)), + StreamEvent(topic="topic", space=_ns(id="space-1", value=72)), + ] + ) + + async def _iter() -> AsyncIterator[object]: + yield _response() + yield _response() + yield _response() + + stream._stub = MagicMock(Subscribe=lambda *_args, **_kwargs: _iter()) + + await stream._run_one_stream() + assert seen == [] + + await asyncio.sleep(0.07) + + assert seen == [72] + + +@pytest.mark.asyncio +async def test_debounce_dispatches_final_value_after_quiet_period() -> None: + stream = _make_stream(debounce_s=0.05) + seen: list[int] = [] + stream.on_space_update(lambda space: seen.append(space.value)) + stream._parse_event = MagicMock( + side_effect=[ + StreamEvent(topic="topic", space=_ns(id="space-1", value=68)), + StreamEvent(topic="topic", space=_ns(id="space-1", value=69)), + ] + ) + + async def _iter() -> AsyncIterator[object]: + yield _response() + await asyncio.sleep(0.03) + yield _response() + + stream._stub = MagicMock(Subscribe=lambda *_args, **_kwargs: _iter()) + + await stream._run_one_stream() + assert seen == [] + + await asyncio.sleep(0.03) + assert seen == [] + + await asyncio.sleep(0.04) + assert seen == [69] diff --git a/tests/test_streaming_health.py b/tests/test_streaming_health.py new file mode 100644 index 0000000..51114e6 --- /dev/null +++ b/tests/test_streaming_health.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import grpc +import pytest + +from quilt_hp.exceptions import QuiltStreamError +from quilt_hp.services.streaming import NotifierStream, StreamEvent + + +class _FakeRpcError(grpc.aio.AioRpcError): + def __init__(self, code: grpc.StatusCode, details: str = "") -> None: + self._code = code + self._details = details + + def code(self) -> grpc.StatusCode: # type: ignore[override] + return self._code + + def details(self) -> str: # type: ignore[override] + return self._details + + +def _make_stream() -> NotifierStream: + with patch("quilt_hp.services.streaming.notifier_grpc.NotifierServiceStub"): + return NotifierStream.create(MagicMock(), ["hds/space/space-1"]) + + +def test_health_properties_default_to_idle() -> None: + stream = _make_stream() + + assert stream.is_connected is False + assert stream.last_event_at is None + assert stream.stream_state == "idle" + + +@pytest.mark.asyncio +async def test_run_one_stream_marks_connected_and_tracks_last_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + stream = _make_stream() + stream._running = True + states: list[tuple[bool, str]] = [] + + async def _space_cb(_space: object) -> None: + states.append((stream.is_connected, stream.stream_state)) + + stream.on_space_update(_space_cb) + stream._parse_event = MagicMock(return_value=StreamEvent(topic="topic", space=object())) + monkeypatch.setattr("quilt_hp.services.streaming.time.monotonic", lambda: 123.4) + + response = SimpleNamespace(control_events=[], notifier_events=[object()]) + + async def _iter() -> AsyncIterator[object]: + yield response + + stream._stub = MagicMock(Subscribe=lambda *_args, **_kwargs: _iter()) + await stream._run_one_stream() + + assert states == [(True, "connected")] + assert stream.last_event_at == 123.4 + assert stream.is_connected is True + assert stream.stream_state == "connected" + + +@pytest.mark.asyncio +async def test_reconnect_state_is_exposed_during_backoff(monkeypatch: pytest.MonkeyPatch) -> None: + stream = _make_stream() + stream._running = True + stream._max_reconnects = 1 + + calls = 0 + + async def _flaky() -> None: + nonlocal calls + calls += 1 + if calls == 1: + raise _FakeRpcError(grpc.StatusCode.UNAVAILABLE, "down") + stream._running = False + + seen_states: list[str] = [] + + async def _fake_sleep(_delay: float) -> None: + seen_states.append(stream.stream_state) + + stream._run_one_stream = _flaky # type: ignore[method-assign] + monkeypatch.setattr("quilt_hp.services.streaming.asyncio.sleep", _fake_sleep) + + await stream._run_stream_with_reconnect() + + assert seen_states == ["reconnecting"] + assert stream.is_connected is False + assert stream.stream_state == "stopped" + + +@pytest.mark.asyncio +async def test_fatal_stream_error_sets_error_state() -> None: + stream = _make_stream() + stream._running = True + stream._max_reconnects = 0 + stream._run_one_stream = AsyncMock( + side_effect=_FakeRpcError(grpc.StatusCode.UNAVAILABLE, "down") + ) + + with pytest.raises(QuiltStreamError): + await stream._run_stream_with_reconnect() + + assert stream.is_connected is False + assert stream.stream_state == "error" + + +@pytest.mark.asyncio +async def test_stop_marks_stream_stopped() -> None: + stream = _make_stream() + + async def _noop() -> None: + await asyncio.sleep(3600) + + stream._run_stream_with_reconnect = _noop # type: ignore[method-assign] + await stream.start() + await stream.stop() + + assert stream.is_connected is False + assert stream.stream_state == "stopped" diff --git a/tests/test_streaming_reconnect_dispatch_extra.py b/tests/test_streaming_reconnect_dispatch_extra.py index 0f12ba2..c5a50dc 100644 --- a/tests/test_streaming_reconnect_dispatch_extra.py +++ b/tests/test_streaming_reconnect_dispatch_extra.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch @@ -94,7 +95,9 @@ async def _iter() -> asyncio.AsyncIterator[object]: @pytest.mark.asyncio -async def test_reconnect_retries_then_resubscribes(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_reconnect_retries_then_resubscribes( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: stream = _make_stream() stream._running = True stream._max_reconnects = 1 @@ -118,11 +121,13 @@ async def _fake_sleep(delay: float) -> None: monkeypatch.setattr("quilt_hp.services.streaming.asyncio.sleep", _fake_sleep) - await stream._run_stream_with_reconnect() + with caplog.at_level(logging.INFO): + await stream._run_stream_with_reconnect() assert calls == 2 assert sleep_calls == [1.0] assert stream._request_queue is not old_queue + assert "Resetting subscription queue before reconnect" in caplog.text @pytest.mark.asyncio diff --git a/tests/test_transport_interceptor_extra.py b/tests/test_transport_interceptor_extra.py index 35a33a6..1c54604 100644 --- a/tests/test_transport_interceptor_extra.py +++ b/tests/test_transport_interceptor_extra.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from unittest.mock import MagicMock import grpc @@ -43,6 +44,38 @@ async def _legacy() -> None: assert called == ["legacy"] +@pytest.mark.asyncio +async def test_invoke_refresh_callback_caches_signature( + monkeypatch: pytest.MonkeyPatch, +) -> None: + transport._REFRESH_CALLBACK_HAS_PARAMS.clear() + called: list[transport.TokenRefreshContext] = [] + + async def _with_context(context: transport.TokenRefreshContext) -> None: + called.append(context) + + signature = inspect.Signature( + parameters=[ + inspect.Parameter( + "context", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + ) + inspect_signature = MagicMock(return_value=signature) + monkeypatch.setattr(transport.inspect, "signature", inspect_signature) + + context = transport.TokenRefreshContext( + reason=transport.TokenRefreshReason.TRANSPORT_UNAUTHENTICATED, + source="test", + ) + await transport._invoke_refresh_callback(_with_context, context) + await transport._invoke_refresh_callback(_with_context, context) + + assert called == [context, context] + assert inspect_signature.call_count == 1 + + @pytest.mark.asyncio async def test_auth_interceptor_retry_paths() -> None: refreshed: list[str] = []