diff --git a/subt/main.py b/subt/main.py index a4d09e332..2a656c32d 100644 --- a/subt/main.py +++ b/subt/main.py @@ -10,6 +10,7 @@ from datetime import timedelta from collections import defaultdict from io import StringIO +from random import Random import numpy as np @@ -146,6 +147,7 @@ class SubTChallenge: def __init__(self, config, bus): self.bus = bus bus.register("desired_speed", "pose2d", "artf_xyz", "pose3d", "stdout", "request_origin") + self.random = Random(0) self.start_pose = None self.traveled_dist = 0.0 self.time = None diff --git a/subt/monitors.py b/subt/monitors.py new file mode 100644 index 000000000..a93387c61 --- /dev/null +++ b/subt/monitors.py @@ -0,0 +1,96 @@ +import logging + +g_logger = logging.getLogger(__name__) + + +class TimeoutReached(Exception): + pass + + +class TimeoutMonitor: + def __init__(self, robot, timeout): + self.robot = robot + self.timeout = timeout + self.fired = False + + def update(self): + if (self.robot.time - self.start_time) > self.timeout: + raise TimeoutReached(self) + + def __enter__(self): + self.start_time = self.robot.time + self.fired = False + self.handle = self.robot.register(self.update) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.robot.unregister(self.handle) + if exc_val is not None and isinstance(exc_val, TimeoutReached): + if exc_val.args[0] == self: + self.fired = True + g_logger.info("Timeout {self.timeout} reached.") + return True # don't reraise + + def __bool__(self): + return self.fired + + +class PitchError(Exception): + pass + + +class RollError(Exception): + pass + + +class PitchMonitor: + def __init__(self, robot, pitch_limit): + self.robot = robot + self.pitch_limit = pitch_limit + self.fired = False + + def update(self): + if self.pitch_limit is not None and self.robot.pitch is not None: + if abs(self.robot.pitch) > self.pitch_limit: + raise PitchError(self) + + def __enter__(self): + self.fired = False + self.handle = self.robot.register(self.update) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.robot.unregister(self.handle) + if exc_val is not None and isinstance(exc_val, PitchError): + if exc_val.args[0] == self: + self.fired = True + g_logger.info("Pitch limit {math.degrees(self.pitch_limit)} reached.") + return True # don't reraise + + def __bool__(self): + return self.fired + + +class RollMonitor: + def __init__(self, robot, roll_limit): + self.robot = robot + self.roll_limit = roll_limit + self.fired = False + + def update(self): + if self.roll_limit is not None and self.robot.roll is not None: + if abs(self.robot.roll) > self.roll_limit: + raise RollError(self) + + def __enter__(self): + self.fired = False + self.handle = self.robot.register(self.update) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.robot.unregister(self.handle) + if exc_val is not None and isinstance(exc_val, RollError): + if exc_val.args[0] == self: + self.fired = True + g_logger.info("Pitch limit {math.degrees(self.roll_limit)} reached.") + return True # don't reraise + + def __bool__(self): + return self.fired diff --git a/subt/test_monitors.py b/subt/test_monitors.py new file mode 100644 index 000000000..bdbb8e395 --- /dev/null +++ b/subt/test_monitors.py @@ -0,0 +1,110 @@ +import unittest +import math +from unittest import mock +from datetime import timedelta +from subt import monitors + +class Robot: + + def __init__(self): + self.callbacks = [] + self.time = timedelta() + self.pitch = 0 + self.roll = 0 + + def update(self, dt=timedelta(milliseconds=101)): + self.time += dt + self.pitch += 0.1 + self.roll += 0.15 + for f in self.callbacks: + f() + + def register(self, callback): + self.callbacks.append(callback) + return callback + + def unregister(self, handle): + assert handle in self.callbacks + self.callbacks.remove(handle) + + +class TimeoutTest(unittest.TestCase): + + def test_timeout(self): + robot = Robot() + timeout = monitors.TimeoutMonitor(robot, timedelta(seconds=1)) + with timeout: + for _ in range(100): + robot.update() + else: + self.assertTrue(False) + self.assertGreater(robot.time, timedelta(seconds=1)) + self.assertTrue(timeout) + + def test_stacked_timeout(self): + robot = Robot() + timeout_outside = monitors.TimeoutMonitor(robot, timedelta(seconds=1)) + timeout_inside = monitors.TimeoutMonitor(robot, timedelta(seconds=0.5)) + with timeout_outside: + with timeout_inside: + for _ in range(100): + robot.update() + self.assertGreater(robot.time, timedelta(seconds=0.5)) + self.assertLess(robot.time, timedelta(seconds=1)) + self.assertTrue(timeout_inside) + self.assertFalse(timeout_outside) + for _ in range(100): + robot.update() + self.assertGreater(robot.time, timedelta(seconds=1)) + self.assertTrue(timeout_outside) + + +class PitchRollTest(unittest.TestCase): + + def test_pitch(self): + robot = Robot() + pitch_limit = monitors.PitchMonitor(robot, math.radians(80)) + with pitch_limit: + for _ in range(100): + robot.update() + else: + self.assertTrue(False) + self.assertTrue(pitch_limit) + + def test_roll(self): + robot = Robot() + roll_limit = monitors.RollMonitor(robot, math.radians(80)) + with roll_limit: + for _ in range(100): + robot.update() + else: + self.assertTrue(False) + self.assertTrue(roll_limit) + + def test_stacked_roll(self): + robot = Robot() + max_roll_limit = monitors.RollMonitor(robot, math.radians(80)) + with max_roll_limit: + mid_roll_limit = monitors.RollMonitor(robot, math.radians(20)) + with mid_roll_limit: + for _ in range(100): + robot.update() + else: + self.assertTrue(False) + self.assertTrue(mid_roll_limit) + self.assertFalse(max_roll_limit) + for _ in range(100): + robot.update() + else: + self.assertTrue(False) + self.assertTrue(mid_roll_limit) + self.assertTrue(max_roll_limit) + + def test_roll_logging(self): + with mock.patch("subt.monitors.g_logger") as logger: + robot = Robot() + roll = monitors.RollMonitor(robot, math.radians(20)) + with roll: + robot.roll = 2 + robot.update() + self.assertTrue(logger.info.called)