diff --git a/framework/python/src/common/logger.py b/framework/python/src/common/logger.py index 972eb2ce5..bad6b1010 100644 --- a/framework/python/src/common/logger.py +++ b/framework/python/src/common/logger.py @@ -16,6 +16,17 @@ import json import logging import os +from common.mqtt_topics import MQTTTopic + + +class TestrunLogger(logging.Logger): + def ui_info(self, msg, *args, **kwargs): + from common import mqtt # pylint: disable=import-outside-toplevel + with mqtt.MQTT() as client: + client.send_message(MQTTTopic.INFO, {'message': msg}) + self.info(msg, *args, **kwargs) + +logging.setLoggerClass(TestrunLogger) LOGGERS = {} _LOG_FORMAT = '%(asctime)s %(name)-8s %(levelname)-7s %(message)s' diff --git a/framework/python/src/common/mqtt.py b/framework/python/src/common/mqtt.py index b98b4ab1b..efff459d1 100644 --- a/framework/python/src/common/mqtt.py +++ b/framework/python/src/common/mqtt.py @@ -32,6 +32,16 @@ class MQTT: def __init__(self) -> None: self._host = WEBSOCKETS_HOST self._client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2) + self._connect() + + def __enter__(self): + self._connect() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + if exc_traceback: + LOGGER.error(exc_traceback) + self.disconnect() def _connect(self): """Establish connection to MQTT broker""" @@ -54,7 +64,6 @@ def send_message(self, topic: str, message: t.Union[str, dict]) -> None: topic (str): mqtt topic message (t.Union[str, dict]): message """ - self._connect() if isinstance(message, dict): message = json.dumps(message) self._client.publish(topic, str(message)) diff --git a/framework/python/src/common/mqtt_topics.py b/framework/python/src/common/mqtt_topics.py new file mode 100644 index 000000000..b5c8d6246 --- /dev/null +++ b/framework/python/src/common/mqtt_topics.py @@ -0,0 +1,21 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Enums for mqtt topics""" +from enum import Enum + +class MQTTTopic(str, Enum): + INFO = "info" + INTERNET_CONNECTION_TOPIC = "events/internet" + NETWORK_ADAPTERS_TOPIC = "events/adapter" + STATUS_TOPIC = "status" diff --git a/framework/python/src/common/tasks.py b/framework/python/src/common/tasks.py index 5da0b40c9..b75478e83 100644 --- a/framework/python/src/common/tasks.py +++ b/framework/python/src/common/tasks.py @@ -21,13 +21,12 @@ from fastapi import FastAPI from common import logger +from common.mqtt_topics import MQTTTopic # Check adapters period seconds # Check adapters period seconds CHECK_NETWORK_ADAPTERS_PERIOD = 5 CHECK_INTERNET_PERIOD = 2 -INTERNET_CONNECTION_TOPIC = 'events/internet' -NETWORK_ADAPTERS_TOPIC = 'events/adapter' LOGGER = logger.get_logger('tasks') @@ -48,8 +47,7 @@ def __init__( self.adapters_checker_job = self._scheduler.add_job( func=self._testrun.get_net_orc().network_adapters_checker, kwargs={ - 'mqtt_client': self._mqtt_client, - 'topic': NETWORK_ADAPTERS_TOPIC + 'topic': MQTTTopic.NETWORK_ADAPTERS_TOPIC }, trigger='interval', seconds=CHECK_NETWORK_ADAPTERS_PERIOD, @@ -59,8 +57,7 @@ def __init__( self.internet_shecker = self._scheduler.add_job( func=self._testrun.get_net_orc().internet_conn_checker, kwargs={ - 'mqtt_client': self._mqtt_client, - 'topic': INTERNET_CONNECTION_TOPIC + 'topic': MQTTTopic.INTERNET_CONNECTION_TOPIC }, trigger='interval', seconds=CHECK_INTERNET_PERIOD, diff --git a/framework/python/src/core/session.py b/framework/python/src/core/session.py index 3fc373793..83a42fd79 100644 --- a/framework/python/src/core/session.py +++ b/framework/python/src/core/session.py @@ -19,6 +19,7 @@ import os from fastapi.encoders import jsonable_encoder from common import util, logger, mqtt +from common.mqtt_topics import MQTTTopic from common.risk_profile import RiskProfile from common.statuses import TestrunStatus, TestResult, TestrunResult from net_orc.ip_control import IPControl @@ -42,7 +43,6 @@ ALLOW_DISCONNECT_KEY='allow_disconnect' CERTS_PATH = 'local/root_certs' CONFIG_FILE_PATH = 'local/system.json' -STATUS_TOPIC = 'status' MAKE_CONTROL_DIR = 'make/DEBIAN/control' @@ -65,10 +65,11 @@ def wrapper(self, *args, **kwargs): result = method(self, *args, **kwargs) if self.get_status() != TestrunStatus.IDLE and not self.pause_message: - self.get_mqtt_client().send_message( - STATUS_TOPIC, - jsonable_encoder(self.to_json()) - ) + with mqtt.MQTT() as client: + client.send_message( + MQTTTopic.STATUS_TOPIC, + jsonable_encoder(self.to_json()) + ) if self.get_status() in STATUSES_COMPLETE: self.pause_message = True @@ -168,9 +169,6 @@ def __init__(self, root_dir): self._timezone = tz[0] LOGGER.debug(f'System timezone is {self._timezone}') - # MQTT client - self._mqtt_client = mqtt.MQTT() - def start(self): self.reset() self._status = TestrunStatus.STARTING @@ -1061,8 +1059,5 @@ def detect_network_adapters_change(self) -> dict: self._ifaces = ifaces_new return adapters - def get_mqtt_client(self): - return self._mqtt_client - def get_ifaces(self): return self._ifaces diff --git a/framework/python/src/core/tasks.py b/framework/python/src/core/tasks.py index 5da0b40c9..466e74b29 100644 --- a/framework/python/src/core/tasks.py +++ b/framework/python/src/core/tasks.py @@ -48,7 +48,6 @@ def __init__( self.adapters_checker_job = self._scheduler.add_job( func=self._testrun.get_net_orc().network_adapters_checker, kwargs={ - 'mqtt_client': self._mqtt_client, 'topic': NETWORK_ADAPTERS_TOPIC }, trigger='interval', @@ -59,7 +58,6 @@ def __init__( self.internet_shecker = self._scheduler.add_job( func=self._testrun.get_net_orc().internet_conn_checker, kwargs={ - 'mqtt_client': self._mqtt_client, 'topic': INTERNET_CONNECTION_TOPIC }, trigger='interval', diff --git a/framework/python/src/core/testrun.py b/framework/python/src/core/testrun.py index 069552320..f9ef397d2 100644 --- a/framework/python/src/core/testrun.py +++ b/framework/python/src/core/testrun.py @@ -398,9 +398,6 @@ async def stop(self): self.get_session().set_status(TestrunStatus.CANCELLED) - # Disconnect before WS server stops to prevent error - self._mqtt_client.disconnect() - self._stop_network(kill=True) def _register_exits(self): @@ -468,7 +465,7 @@ def _device_discovered(self, mac_addr): if device is not None: if mac_addr != device.mac_addr: LOGGER.info(f'Found device with mac addr: {mac_addr} but was ignored') - LOGGER.info(f'Expected device mac address is {device.mac_addr}') + LOGGER.ui_info(f'Expected device mac address is {device.mac_addr}') # Ignore discovered device because it is not the target device return else: diff --git a/framework/python/src/net_orc/network_orchestrator.py b/framework/python/src/net_orc/network_orchestrator.py index 4acd5f3c3..2f8b20ec4 100644 --- a/framework/python/src/net_orc/network_orchestrator.py +++ b/framework/python/src/net_orc/network_orchestrator.py @@ -697,14 +697,15 @@ def restore_net(self): def get_session(self): return self._session - def network_adapters_checker(self, mqtt_client: mqtt.MQTT, topic: str): + def network_adapters_checker(self, topic: str): """Checks for changes in network adapters and sends a message to the frontend """ try: adapters = self._session.detect_network_adapters_change() if adapters: - mqtt_client.send_message(topic, adapters) + with mqtt.MQTT() as client: + client.send_message(topic, adapters) except Exception: # pylint: disable=W0703 LOGGER.error(traceback.format_exc()) @@ -713,7 +714,7 @@ def is_device_connected(self): return self._ip_ctrl.check_interface_status( self._session.get_device_interface()) - def internet_conn_checker(self, mqtt_client: mqtt.MQTT, topic: str): + def internet_conn_checker(self, topic: str): """Checks internet connection and sends a status to frontend""" # Default message @@ -739,8 +740,9 @@ def internet_conn_checker(self, mqtt_client: mqtt.MQTT, topic: str): if internet_connection: message['connection'] = True - # Broadcast via MQTT client - mqtt_client.send_message(topic, message) + with mqtt.MQTT() as client: + # Broadcast via MQTT client + client.send_message(topic, message) class NetworkConfig: