diff --git a/README.md b/README.md index 060e798..9e46051 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,46 @@ can do that task either synchronously, asynchronously, on a separate process etc Each worker has access to a shared utilities library that aids in db and message broker interaction as well as other functions that are common across ARAs. Check the shared function library before writing a new function that you think other ARAs might also want to use. +#### Worker tuning & graceful shutdown + +Every worker draws its tasks through `shepherd_utils.shared.get_tasks`, so the +following behavior applies to all of them: + +- **Concurrency (`TASK_LIMIT`)** — each worker declares a default in-process + concurrency, but it can be overridden per deployment with the `TASK_LIMIT` + environment variable (each worker is its own container, so a single + `TASK_LIMIT` per Deployment is unambiguous). No code change or rebuild needed. +- **Graceful drain on shutdown** — on `SIGTERM`/`SIGINT` (Kubernetes sends + `SIGTERM` on every rollout, scale-down and node drain) a worker stops pulling + new tasks, waits up to `WORKER_DRAIN_TIMEOUT_SEC` (default 30s) for in-flight + tasks to finish, writes a clean-shutdown marker the monitor reads (so the + event is classified as a graceful scale-down rather than a crash), then exits. + Tasks that don't finish in the window are left in the stream for Redis reclaim. + Set the deployment's `terminationGracePeriodSeconds` comfortably above + `WORKER_DRAIN_TIMEOUT_SEC`. + +##### Kubernetes sizing (Helm) + +Production limits live in the Helm chart, not in `compose.yml` (which is +dev-only). 
Recommended starting point for `finish_query`, which holds whole +decompressed TRAPI payloads in memory while POSTing async callbacks: + +| Setting | Value | +| --- | --- | +| `resources.requests` | `cpu: 500m`, `memory: 1Gi` | +| `resources.limits` | `cpu: "2"`, `memory: 4Gi` | +| `TASK_LIMIT` | `32` (overrides the in-code default of 10) | +| `terminationGracePeriodSeconds` | `35` | + +Scale throughput with replicas / an HPA (on CPU or queue depth) rather than a +single large pod. On Kubernetes the memory `limit` (OOMKilled + restart) plus +regular rollouts already recycle pods, so leaked-resource cleanup comes for free +— add an RSS-based `livenessProbe` only if the monitor shows OOMKills in +practice. CPU-bound pool workers (`merge_message`, `score_paths`, `arax_rank`, +`aragorn_score`, `aragorn_omnicorp`) size their process/thread pools from the +in-code default, so raising `TASK_LIMIT` for those only deepens the intake queue +rather than adding parallelism. + ### Message Broker Streams Shepherd uses Redis Streams for its message broker. More info on Redis Streams can be found [here](https://redis.io/docs/latest/develop/data-types/streams/) diff --git a/shepherd_utils/config.py b/shepherd_utils/config.py index 619cea7..7f42051 100644 --- a/shepherd_utils/config.py +++ b/shepherd_utils/config.py @@ -54,6 +54,14 @@ class Settings(BaseSettings): reclaim_interval_sec: int = 10 reclaim_max_batch: int = 50 + # Graceful shutdown. On SIGTERM/SIGINT (Kubernetes sends SIGTERM on every + # rollout, scale-down and node drain) a worker stops pulling new tasks and + # waits up to this long for in-flight tasks to finish before exiting, so the + # recycling the orchestrator already does is clean instead of lossy. Tasks + # that don't finish in the window are left in the stream for Redis reclaim. + # Keep this comfortably below the deployment's terminationGracePeriodSeconds.
+ worker_drain_timeout_sec: float = 30.0 + # Monitor (dashboard) worker monitor_port: int = 5440 monitor_poll_interval_sec: float = 3.0 diff --git a/shepherd_utils/heartbeat.py b/shepherd_utils/heartbeat.py index 6f47871..7138434 100644 --- a/shepherd_utils/heartbeat.py +++ b/shepherd_utils/heartbeat.py @@ -53,13 +53,23 @@ def shutdown_key(stream: str, consumer: str) -> str: class Heartbeat: """Background task that periodically refreshes a presence key in Redis.""" - def __init__(self, stream: str, consumer: str, task_limit: int): + def __init__( + self, + stream: str, + consumer: str, + task_limit: int, + manage_signals: bool = True, + ): self.stream = stream self.consumer = consumer self.task_limit = task_limit self.started_at = time.time() self._task: asyncio.Task | None = None self._logger = logging.getLogger(f"shepherd.heartbeat.{stream}") + # When False, this Heartbeat does not install its own SIGTERM/SIGINT + # handlers -- the caller (shared.get_tasks) installs asyncio-aware + # handlers instead so it can drain in-flight tasks before exiting. + self.manage_signals = manage_signals self._signal_installed = False self._prev_handlers: dict = {} @@ -90,9 +100,36 @@ async def _loop(self) -> None: def start(self) -> "Heartbeat": if self._task is None: self._task = asyncio.create_task(self._loop()) - self._install_signal_handlers() + if self.manage_signals: + self._install_signal_handlers() return self + async def mark_clean_shutdown(self) -> None: + """Write the shutdown marker from within the event loop. + + Mirror of ``_mark_shutdown_sync`` for the graceful-shutdown path, which + already runs on the asyncio loop and so can use the async broker client. + The monitor reads this marker to classify the disappearance of the + heartbeat as a clean scale-down rather than a crash. Heartbeat key + deletion is handled by ``stop()``. 
+ """ + payload = json.dumps( + { + "stream": self.stream, + "consumer": self.consumer, + "signum": int(signal.SIGTERM), + "ts": time.time(), + } + ) + try: + await broker_client.set( + shutdown_key(self.stream, self.consumer), + payload, + ex=SHUTDOWN_TTL_SEC, + ) + except Exception as e: + self._logger.debug(f"Failed to write clean shutdown marker: {e}") + async def stop(self) -> None: if self._task is not None: self._task.cancel() diff --git a/shepherd_utils/reclaim.py b/shepherd_utils/reclaim.py index 11c4116..6884ff8 100644 --- a/shepherd_utils/reclaim.py +++ b/shepherd_utils/reclaim.py @@ -51,7 +51,13 @@ "merge_message": 60, "score_paths": 60, "example.score": 30, - # Filter / entry / finish workers fall through to the default (fast). + # finish_query sends the async callback, which retries with backoff and can + # legitimately run for minutes against a slow callback endpoint (httpx + # timeout 120s x up to 3 attempts). At the fast default (30s) a second + # consumer could XCLAIM the message mid-callback and deliver it twice. The + # heartbeat filter is the primary guard; this floor is the backstop (NOTE(review): 3 x 120s can outlast 240s -- confirm the value). + "finish_query": 240, + # Other filter / entry workers fall through to the default (fast). } diff --git a/shepherd_utils/shared.py b/shepherd_utils/shared.py index 0fd3a6f..fca525e 100644 --- a/shepherd_utils/shared.py +++ b/shepherd_utils/shared.py @@ -3,6 +3,9 @@ import asyncio import json import logging +import os +import signal +import sys import time from typing import AsyncGenerator, Dict, List, Tuple @@ -52,6 +55,131 @@ async def _record_task_duration( setup_logging() +# --------------------------------------------------------------------------- +# Graceful shutdown / in-flight drain +# +# Kubernetes sends SIGTERM on every rollout, scale-down and node drain. Without +# handling it the worker is killed mid-task and the work is only recovered later +# via Redis reclaim.
Here we install asyncio-aware signal handlers that flip a +# shutdown flag; ``get_tasks`` then stops pulling new work and drains anything +# in flight before the process exits. +# +# Draining piggybacks on the concurrency semaphore that ``get_tasks`` already +# owns: every worker acquires a permit before a task starts and releases it when +# the task finishes (in run_task_lifecycle / each worker's process_task finally). +# So "the drain holds every permit" is equivalent to "no task in flight" -- we don't need +# the workers to register their background tasks with us. +# --------------------------------------------------------------------------- + +_shutdown = asyncio.Event() +_active_heartbeat: "Heartbeat | None" = None +_signal_handlers_installed = False + + +def is_shutting_down() -> bool: + return _shutdown.is_set() + + +def _request_shutdown() -> None: + _shutdown.set() + + +def install_shutdown_handlers(heartbeat: "Heartbeat | None" = None) -> None: + """Install asyncio-aware SIGTERM/SIGINT handlers (idempotent). + + ``loop.add_signal_handler`` is the safe way to react to a signal from inside + a running event loop: the callback runs between awaits rather than in the + interrupt context, so it can flip the ``asyncio.Event`` that ``get_tasks`` polls. + """ + global _signal_handlers_installed, _active_heartbeat + # Always update the heartbeat reference so the marker is written for the + # currently-active worker even if get_tasks is re-entered after an error. + _active_heartbeat = heartbeat + if _signal_handlers_installed: + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _request_shutdown) + except (NotImplementedError, RuntimeError, ValueError): + # Platforms without loop signal support fall back to signal.signal.
+ try: + signal.signal(sig, lambda *_: _request_shutdown()) + except (ValueError, OSError): + pass + _signal_handlers_installed = True + + +async def _drain_and_exit( + limiter: asyncio.Semaphore, + task_limit: int, + logger: logging.Logger, +) -> None: + """Wait for in-flight tasks to finish, mark a clean shutdown, then exit. + + Acquiring all ``task_limit`` permits means every in-flight task has released + its permit -- i.e. completed. Bounded by ``worker_drain_timeout_sec``; + stragglers are left in the stream for Redis reclaim to retry. + """ + logger.info("Shutdown signal received; draining in-flight tasks.") + acquired = 0 + + async def _acquire_all() -> None: + nonlocal acquired + for _ in range(task_limit): + await limiter.acquire() + acquired += 1 + + try: + await asyncio.wait_for( + _acquire_all(), timeout=float(settings.worker_drain_timeout_sec) + ) + logger.info("All in-flight tasks drained cleanly.") + except asyncio.TimeoutError: + logger.warning( + f"Drain timed out with ~{task_limit - acquired} task(s) still " + "running; leaving them in the stream for Redis reclaim." + ) + + hb = _active_heartbeat + if hb is not None: + try: + await hb.mark_clean_shutdown() + except Exception as e: + logger.debug(f"Failed to write clean shutdown marker: {e}") + try: + await hb.stop() + except Exception: + pass + logger.info("Exiting after graceful drain.") + sys.exit(0) + + +def _resolve_task_limit(stream: str, default: int, logger: logging.Logger) -> int: + """Allow ops to override a worker's concurrency via the TASK_LIMIT env var. + + Each worker runs as its own container/Deployment, so a single ``TASK_LIMIT`` + env per Deployment unambiguously tunes that worker without a code change or + rebuild. Falls back to the value the worker passed in. 
+ """ + raw = os.getenv("TASK_LIMIT") + if raw is None: + return default + try: + value = int(raw) + if value < 1: + raise ValueError + except ValueError: + logger.warning(f"Ignoring invalid TASK_LIMIT={raw!r} for {stream}.") + return default + if value != default: + logger.info(f"TASK_LIMIT for {stream} overridden to {value} via env.") + return value + + def get_next_operation( workflow: List[Dict[str, str]], ) -> Tuple[Dict[str, str], List[Dict[str, str]]]: @@ -112,16 +240,25 @@ async def get_tasks( worker_logger = logging.getLogger(f"shepherd.{stream}.{consumer}") worker_logger.setLevel(level_number) worker_logger.addHandler(log_handler) + # allow ops to tune concurrency per Deployment without a code change + task_limit = _resolve_task_limit(stream, task_limit, worker_logger) # initialize opens the db connection await initialize_db() task_limiter = asyncio.Semaphore(task_limit) - # register this worker with the monitor via a Redis heartbeat key - Heartbeat(stream, consumer, task_limit).start() + # register this worker with the monitor via a Redis heartbeat key. The + # heartbeat does not install its own (immediate-exit) signal handlers -- + # install_shutdown_handlers below installs asyncio-aware ones that drain. + heartbeat = Heartbeat(stream, consumer, task_limit, manage_signals=False).start() + install_shutdown_handlers(heartbeat) # periodic orphan-task reclaim so a worker crash doesn't strand its PEL reclaim_interval = max(5.0, float(settings.reclaim_interval_sec)) last_reclaim = 0.0 # continuously poll the broker for new tasks while True: + # On shutdown, stop taking new work and drain anything in flight. + if is_shutting_down(): + await _drain_and_exit(task_limiter, task_limit, worker_logger) + return # Before fetching new work, check whether any pending messages on this # stream belong to a dead consumer and claim them. Heartbeat + idle # filtering inside ``reclaim_orphaned`` keep live consumers safe. 
@@ -155,6 +292,12 @@ async def get_tasks( # check if we can take another task await task_limiter.acquire() + # A shutdown may have arrived while we waited for a free slot; don't + # fetch new work in that case -- release and drain. + if is_shutting_down(): + task_limiter.release() + await _drain_and_exit(task_limiter, task_limit, worker_logger) + return # get a new task for the given target ara_task = await get_task(stream, group, consumer, worker_logger) if ara_task is not None: diff --git a/tests/unit/test_worker_lifecycle.py b/tests/unit/test_worker_lifecycle.py new file mode 100644 index 0000000..263ef0b --- /dev/null +++ b/tests/unit/test_worker_lifecycle.py @@ -0,0 +1,193 @@ +"""Tests for the worker graceful-shutdown / drain machinery in +``shepherd_utils.shared`` and the clean-shutdown marker on +``shepherd_utils.heartbeat.Heartbeat``. + +These cover the pieces added so that a SIGTERM (which Kubernetes sends on every +rollout, scale-down and node drain) stops the worker pulling new work, drains +in-flight tasks within a bounded window, writes a clean-shutdown marker, then +exits -- plus the ``TASK_LIMIT`` env override that lets ops tune concurrency +per Deployment. 
+""" + +import asyncio +import logging + +import pytest + +from shepherd_utils import heartbeat as heartbeat_module +from shepherd_utils import shared +from shepherd_utils.heartbeat import Heartbeat, shutdown_key +from shepherd_utils.config import settings + +logger = logging.getLogger(__name__) + + +@pytest.fixture(autouse=True) +def _reset_shutdown_state(): + """Keep the module-level shutdown flag from leaking between tests.""" + shared._shutdown = asyncio.Event() + shared._signal_handlers_installed = False + shared._active_heartbeat = None + yield + shared._shutdown = asyncio.Event() + shared._signal_handlers_installed = False + shared._active_heartbeat = None + + +# --- TASK_LIMIT env override ------------------------------------------------ + + +def test_resolve_task_limit_uses_default_without_env(monkeypatch): + monkeypatch.delenv("TASK_LIMIT", raising=False) + assert shared._resolve_task_limit("finish_query", 100, logger) == 100 + + +def test_resolve_task_limit_honors_env_override(monkeypatch): + monkeypatch.setenv("TASK_LIMIT", "32") + assert shared._resolve_task_limit("finish_query", 100, logger) == 32 + + +def test_resolve_task_limit_ignores_non_integer(monkeypatch): + monkeypatch.setenv("TASK_LIMIT", "not-a-number") + assert shared._resolve_task_limit("finish_query", 100, logger) == 100 + + +def test_resolve_task_limit_ignores_non_positive(monkeypatch): + monkeypatch.setenv("TASK_LIMIT", "0") + assert shared._resolve_task_limit("finish_query", 100, logger) == 100 + + +# --- drain and exit --------------------------------------------------------- + + +class _FakeHeartbeat: + def __init__(self): + self.marked = False + self.stopped = False + + async def mark_clean_shutdown(self): + self.marked = True + + async def stop(self): + self.stopped = True + + +@pytest.mark.asyncio +async def test_drain_and_exit_drains_then_exits_zero(monkeypatch): + """With no task holding a permit, drain completes immediately and the + process exits 0 after writing the 
clean-shutdown marker.""" + hb = _FakeHeartbeat() + shared._active_heartbeat = hb + limiter = asyncio.Semaphore(4) + + with pytest.raises(SystemExit) as exc: + await shared._drain_and_exit(limiter, 4, logger) + + assert exc.value.code == 0 + assert hb.marked is True + assert hb.stopped is True + + +@pytest.mark.asyncio +async def test_drain_and_exit_waits_for_inflight_permit(monkeypatch): + """A held permit (an in-flight task) is awaited; once released, drain + completes and the process exits.""" + hb = _FakeHeartbeat() + shared._active_heartbeat = hb + limiter = asyncio.Semaphore(2) + # Simulate one in-flight task holding a permit. + await limiter.acquire() + + async def _release_soon(): + await asyncio.sleep(0.02) + limiter.release() + + monkeypatch.setattr(settings, "worker_drain_timeout_sec", 1.0) + releaser = asyncio.create_task(_release_soon()) + + with pytest.raises(SystemExit) as exc: + await shared._drain_and_exit(limiter, 2, logger) + + assert exc.value.code == 0 + assert hb.marked is True + await releaser + + +@pytest.mark.asyncio +async def test_drain_and_exit_times_out_but_still_exits(monkeypatch): + """If an in-flight task never finishes, drain times out yet still exits so + the orchestrator's terminationGracePeriod isn't blocked indefinitely.""" + hb = _FakeHeartbeat() + shared._active_heartbeat = hb + limiter = asyncio.Semaphore(2) + await limiter.acquire() # never released + + monkeypatch.setattr(settings, "worker_drain_timeout_sec", 0.05) + + with pytest.raises(SystemExit) as exc: + await shared._drain_and_exit(limiter, 2, logger) + + assert exc.value.code == 0 + # Marker is still written even on a timed-out drain. 
+ assert hb.marked is True + + +# --- clean shutdown marker -------------------------------------------------- + + +@pytest.mark.asyncio +async def test_mark_clean_shutdown_writes_marker(redis_mock, monkeypatch): + """``mark_clean_shutdown`` writes the shutdown marker key the monitor reads + to classify a clean scale-down.""" + # heartbeat binds broker_client at import; point it at the fake broker. + monkeypatch.setattr(heartbeat_module, "broker_client", redis_mock["broker"]) + hb = Heartbeat("finish_query", "abc123", 100, manage_signals=False) + + await hb.mark_clean_shutdown() + + raw = await redis_mock["broker"].get(shutdown_key("finish_query", "abc123")) + assert raw is not None + + +def test_heartbeat_manage_signals_flag_defaults_true(): + assert Heartbeat("s", "c", 1).manage_signals is True + assert Heartbeat("s", "c", 1, manage_signals=False).manage_signals is False + + +# --- get_tasks integration: shutdown short-circuits the poll loop ----------- + + +@pytest.mark.asyncio +async def test_get_tasks_exits_when_shutdown_already_requested(monkeypatch): + """If shutdown is requested, get_tasks must not yield any task -- it drains + and exits the process instead of fetching new work.""" + monkeypatch.setattr(shared, "initialize_db", _async_noop) + + fake_hb = _FakeHeartbeat() + + class _HBFactory: + def __init__(self, *args, **kwargs): + pass + + def start(self): + return fake_hb + + monkeypatch.setattr(shared, "Heartbeat", _HBFactory) + monkeypatch.setattr(shared, "install_shutdown_handlers", lambda hb=None: None) + + # Request shutdown before iterating so the first loop turn drains+exits. 
+ shared._active_heartbeat = fake_hb + shared._request_shutdown() + + yielded = [] + with pytest.raises(SystemExit) as exc: + async for task in shared.get_tasks("finish_query", "consumer", "cid", 8): + yielded.append(task) + + assert exc.value.code == 0 + assert yielded == [] + assert fake_hb.marked is True + + +async def _async_noop(*args, **kwargs): + return None diff --git a/workers/finish_query/worker.py b/workers/finish_query/worker.py index 2e2599d..0ab6eb8 100644 --- a/workers/finish_query/worker.py +++ b/workers/finish_query/worker.py @@ -25,7 +25,7 @@ STREAM = "finish_query" GROUP = "consumer" CONSUMER = str(uuid.uuid4())[:8] -TASK_LIMIT = 100 +TASK_LIMIT = 10 tracer = setup_tracer(STREAM) CALLBACK_RETRIES = 3 @@ -48,15 +48,22 @@ async def finish_query(task, logger: logging.Logger): message_bytes = await get_message(response_id, logger, raw=True) logs = await get_logs(response_id, logger) logs_bytes = orjson.dumps(logs) - # Splice logs into the raw JSON bytes to avoid deserializing - # and re-serializing the (potentially huge) message dict. + # Splice logs into the raw JSON bytes to avoid deserializing and + # re-serializing the (potentially huge) message dict. We rebind + # message_bytes to the spliced result so the original buffer is + # released as soon as the new one is built -- otherwise both full + # copies would stay resident for the entire (up to 120s x retries) + # POST below, doubling this worker's peak memory under load. if message_bytes and message_bytes[-1:] == b"}": last_brace = message_bytes.rindex(b"}") - payload = message_bytes[:last_brace] + b',"logs":' + logs_bytes + b"}" + message_bytes = ( + message_bytes[:last_brace] + b',"logs":' + logs_bytes + b"}" + ) else: message = orjson.loads(message_bytes) message["logs"] = logs - payload = orjson.dumps(message) + message_bytes = orjson.dumps(message) + del message headers = {"Content-Type": "application/json"} # Propagate the otel trace context through the callback. 
# Matches the inject() carrier pattern used by the @@ -68,7 +75,7 @@ async def finish_query(task, logger: logging.Logger): async with httpx.AsyncClient(timeout=120) as client: response = await client.post( callback_url, - content=payload, + content=message_bytes, headers=headers, ) response.raise_for_status()