Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,46 @@ can do that task either synchronously, asynchronously, on a separate process etc
Each worker has access to a shared utilities library that aids in db and message broker interaction as well as other functions that are common across ARAs. Check the
shared function library before writing a new function that you think other ARAs might also want to use.

#### Worker tuning & graceful shutdown

Every worker draws its tasks through `shepherd_utils.shared.get_tasks`, so the
following behavior applies to all of them:

- **Concurrency (`TASK_LIMIT`)** — each worker declares a default in-process
concurrency, but it can be overridden per deployment with the `TASK_LIMIT`
environment variable (each worker is its own container, so a single
`TASK_LIMIT` per Deployment is unambiguous). No code change or rebuild needed.
- **Graceful drain on shutdown** — on `SIGTERM`/`SIGINT` (Kubernetes sends
`SIGTERM` on every rollout, scale-down and node drain) a worker stops pulling
new tasks, waits up to `WORKER_DRAIN_TIMEOUT_SEC` (default 30s) for in-flight
tasks to finish, writes a clean-shutdown marker the monitor reads (so the
event is classified as a graceful scale-down rather than a crash), then exits.
Tasks that don't finish in the window are left in the stream for Redis reclaim.
Set the deployment's `terminationGracePeriodSeconds` comfortably above
`WORKER_DRAIN_TIMEOUT_SEC`.

##### Kubernetes sizing (Helm)

Production limits live in the Helm chart, not in `compose.yml` (which is
dev-only). Recommended starting point for `finish_query`, which holds whole
decompressed TRAPI payloads in memory while POSTing async callbacks:

| Setting | Value |
| --- | --- |
| `resources.requests` | `cpu: 500m`, `memory: 1Gi` |
| `resources.limits` | `cpu: "2"`, `memory: 4Gi` |
| `TASK_LIMIT` | `32` (down from the in-code default of 100) |
| `terminationGracePeriodSeconds` | `35` |

Scale throughput with replicas / an HPA (on CPU or queue depth) rather than a
single large pod. On Kubernetes the memory `limit` (OOMKilled + restart) plus
regular rollouts already recycle pods, so leaked-resource cleanup comes for free
— add an RSS-based `livenessProbe` only if the monitor shows OOMKills in
practice. CPU-bound pool workers (`merge_message`, `score_paths`, `arax_rank`,
`aragorn_score`, `aragorn_omnicorp`) size their process/thread pools from the
in-code default, so raising `TASK_LIMIT` for those only deepens the intake queue
rather than adding parallelism.

### Message Broker Streams

Shepherd uses Redis Streams for its message broker. More info on Redis Streams can be found [here](https://redis.io/docs/latest/develop/data-types/streams/)
Expand Down
8 changes: 8 additions & 0 deletions shepherd_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ class Settings(BaseSettings):
reclaim_interval_sec: int = 10
reclaim_max_batch: int = 50

# Graceful shutdown. On SIGTERM/SIGINT (Kubernetes sends SIGTERM on every
# rollout, scale-down and node drain) a worker stops pulling new tasks and
# waits up to this long for in-flight tasks to finish before exiting, so the
# recycling the orchestrator already does is clean instead of lossy. Tasks
# that don't finish in the window are left in the stream for Redis reclaim.
# Keep this comfortably below the deployment's terminationGracePeriodSeconds.
worker_drain_timeout_sec: float = 30.0

# Monitor (dashboard) worker
monitor_port: int = 5440
monitor_poll_interval_sec: float = 3.0
Expand Down
41 changes: 39 additions & 2 deletions shepherd_utils/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,23 @@ def shutdown_key(stream: str, consumer: str) -> str:
class Heartbeat:
"""Background task that periodically refreshes a presence key in Redis."""

def __init__(self, stream: str, consumer: str, task_limit: int):
def __init__(
self,
stream: str,
consumer: str,
task_limit: int,
manage_signals: bool = True,
):
self.stream = stream
self.consumer = consumer
self.task_limit = task_limit
self.started_at = time.time()
self._task: asyncio.Task | None = None
self._logger = logging.getLogger(f"shepherd.heartbeat.{stream}")
# When False, this Heartbeat does not install its own SIGTERM/SIGINT
# handlers -- the caller (shared.get_tasks) installs asyncio-aware
# handlers instead so it can drain in-flight tasks before exiting.
self.manage_signals = manage_signals
self._signal_installed = False
self._prev_handlers: dict = {}

Expand Down Expand Up @@ -90,9 +100,36 @@ async def _loop(self) -> None:
def start(self) -> "Heartbeat":
if self._task is None:
self._task = asyncio.create_task(self._loop())
self._install_signal_handlers()
if self.manage_signals:
self._install_signal_handlers()
return self

async def mark_clean_shutdown(self) -> None:
"""Write the shutdown marker from within the event loop.

Mirror of ``_mark_shutdown_sync`` for the graceful-shutdown path, which
already runs on the asyncio loop and so can use the async broker client.
The monitor reads this marker to classify the disappearance of the
heartbeat as a clean scale-down rather than a crash. Heartbeat key
deletion is handled by ``stop()``.
"""
payload = json.dumps(
{
"stream": self.stream,
"consumer": self.consumer,
"signum": int(signal.SIGTERM),
"ts": time.time(),
}
)
try:
await broker_client.set(
shutdown_key(self.stream, self.consumer),
payload,
ex=SHUTDOWN_TTL_SEC,
)
except Exception as e:
self._logger.debug(f"Failed to write clean shutdown marker: {e}")

async def stop(self) -> None:
if self._task is not None:
self._task.cancel()
Expand Down
8 changes: 7 additions & 1 deletion shepherd_utils/reclaim.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@
"merge_message": 60,
"score_paths": 60,
"example.score": 30,
# Filter / entry / finish workers fall through to the default (fast).
# finish_query sends the async callback, which retries with backoff and can
# legitimately run for minutes against a slow callback endpoint (httpx
# timeout 120s x up to 3 attempts). At the fast default (30s) a second
# consumer could XCLAIM the message mid-callback and deliver it twice. The
# heartbeat filter is the primary guard; this floor is the backstop.
"finish_query": 240,
# Other filter / entry workers fall through to the default (fast).
}


Expand Down
147 changes: 145 additions & 2 deletions shepherd_utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import asyncio
import json
import logging
import os
import signal
import sys
import time
from typing import AsyncGenerator, Dict, List, Tuple

Expand Down Expand Up @@ -52,6 +55,131 @@ async def _record_task_duration(
setup_logging()


# ---------------------------------------------------------------------------
# Graceful shutdown / in-flight drain
#
# Kubernetes sends SIGTERM on every rollout, scale-down and node drain. Without
# handling it the worker is killed mid-task and the work is only recovered later
# via Redis reclaim. Here we install asyncio-aware signal handlers that flip a
# shutdown flag; ``get_tasks`` then stops pulling new work and drains anything
# in flight before the process exits.
#
# Draining piggybacks on the concurrency semaphore that ``get_tasks`` already
# owns: every worker acquires a permit before a task starts and releases it when
# the task finishes (in run_task_lifecycle / each worker's process_task finally).
# So "all permits acquired" is equivalent to "no task in flight" -- we don't need
# the workers to register their background tasks with us.
# ---------------------------------------------------------------------------

_shutdown = asyncio.Event()
_active_heartbeat: "Heartbeat | None" = None
_signal_handlers_installed = False


def is_shutting_down() -> bool:
return _shutdown.is_set()


def _request_shutdown() -> None:
_shutdown.set()


def install_shutdown_handlers(heartbeat: "Heartbeat | None" = None) -> None:
"""Install asyncio-aware SIGTERM/SIGINT handlers (idempotent).

``loop.add_signal_handler`` is the safe way to react to a signal from inside
a running event loop: the callback runs between awaits rather than in the
interrupt context, so it can flip an ``asyncio.Event`` the drain loop awaits.
"""
global _signal_handlers_installed, _active_heartbeat
# Always update the heartbeat reference so the marker is written for the
# currently-active worker even if get_tasks is re-entered after an error.
_active_heartbeat = heartbeat
if _signal_handlers_installed:
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
for sig in (signal.SIGTERM, signal.SIGINT):
try:
loop.add_signal_handler(sig, _request_shutdown)
except (NotImplementedError, RuntimeError, ValueError):
# Platforms without loop signal support fall back to signal.signal.
try:
signal.signal(sig, lambda *_: _request_shutdown())
except (ValueError, OSError):
pass
_signal_handlers_installed = True


async def _drain_and_exit(
limiter: asyncio.Semaphore,
task_limit: int,
logger: logging.Logger,
) -> None:
"""Wait for in-flight tasks to finish, mark a clean shutdown, then exit.

Acquiring all ``task_limit`` permits means every in-flight task has released
its permit -- i.e. completed. Bounded by ``worker_drain_timeout_sec``;
stragglers are left in the stream for Redis reclaim to retry.
"""
logger.info("Shutdown signal received; draining in-flight tasks.")
acquired = 0

async def _acquire_all() -> None:
nonlocal acquired
for _ in range(task_limit):
await limiter.acquire()
acquired += 1

try:
await asyncio.wait_for(
_acquire_all(), timeout=float(settings.worker_drain_timeout_sec)
)
logger.info("All in-flight tasks drained cleanly.")
except asyncio.TimeoutError:
logger.warning(
f"Drain timed out with ~{task_limit - acquired} task(s) still "
"running; leaving them in the stream for Redis reclaim."
)

hb = _active_heartbeat
if hb is not None:
try:
await hb.mark_clean_shutdown()
except Exception as e:
logger.debug(f"Failed to write clean shutdown marker: {e}")
try:
await hb.stop()
except Exception:
pass
logger.info("Exiting after graceful drain.")
sys.exit(0)


def _resolve_task_limit(stream: str, default: int, logger: logging.Logger) -> int:
"""Allow ops to override a worker's concurrency via the TASK_LIMIT env var.

Each worker runs as its own container/Deployment, so a single ``TASK_LIMIT``
env per Deployment unambiguously tunes that worker without a code change or
rebuild. Falls back to the value the worker passed in.
"""
raw = os.getenv("TASK_LIMIT")
if raw is None:
return default
try:
value = int(raw)
if value < 1:
raise ValueError
except ValueError:
logger.warning(f"Ignoring invalid TASK_LIMIT={raw!r} for {stream}.")
return default
if value != default:
logger.info(f"TASK_LIMIT for {stream} overridden to {value} via env.")
return value


def get_next_operation(
workflow: List[Dict[str, str]],
) -> Tuple[Dict[str, str], List[Dict[str, str]]]:
Expand Down Expand Up @@ -112,16 +240,25 @@ async def get_tasks(
worker_logger = logging.getLogger(f"shepherd.{stream}.{consumer}")
worker_logger.setLevel(level_number)
worker_logger.addHandler(log_handler)
# allow ops to tune concurrency per Deployment without a code change
task_limit = _resolve_task_limit(stream, task_limit, worker_logger)
# initialize opens the db connection
await initialize_db()
task_limiter = asyncio.Semaphore(task_limit)
# register this worker with the monitor via a Redis heartbeat key
Heartbeat(stream, consumer, task_limit).start()
# register this worker with the monitor via a Redis heartbeat key. The
# heartbeat does not install its own (immediate-exit) signal handlers --
# install_shutdown_handlers below installs asyncio-aware ones that drain.
heartbeat = Heartbeat(stream, consumer, task_limit, manage_signals=False).start()
install_shutdown_handlers(heartbeat)
# periodic orphan-task reclaim so a worker crash doesn't strand its PEL
reclaim_interval = max(5.0, float(settings.reclaim_interval_sec))
last_reclaim = 0.0
# continuously poll the broker for new tasks
while True:
# On shutdown, stop taking new work and drain anything in flight.
if is_shutting_down():
await _drain_and_exit(task_limiter, task_limit, worker_logger)
return
# Before fetching new work, check whether any pending messages on this
# stream belong to a dead consumer and claim them. Heartbeat + idle
# filtering inside ``reclaim_orphaned`` keep live consumers safe.
Expand Down Expand Up @@ -155,6 +292,12 @@ async def get_tasks(

# check if we can take another task
await task_limiter.acquire()
# A shutdown may have arrived while we waited for a free slot; don't
# fetch new work in that case -- release and drain.
if is_shutting_down():
task_limiter.release()
await _drain_and_exit(task_limiter, task_limit, worker_logger)
return
# get a new task for the given target
ara_task = await get_task(stream, group, consumer, worker_logger)
if ara_task is not None:
Expand Down
Loading
Loading