From 1487a4dc8c939ea6a9af2b0d0a2cc9e63d6ee619 Mon Sep 17 00:00:00 2001 From: Dhinesh Ponnarasan Date: Sun, 5 Apr 2026 03:39:54 -0400 Subject: [PATCH 1/2] fix: Remove unnecessary whitespace in build_sequences_per_dataset.py and Dockerfile.linting --- docker/Dockerfile.linting | 2 +- tools/build_sequences_per_dataset.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docker/Dockerfile.linting b/docker/Dockerfile.linting index 259c0bbedcd..7f2b629cc85 100644 --- a/docker/Dockerfile.linting +++ b/docker/Dockerfile.linting @@ -20,4 +20,4 @@ FROM main as jet ARG JET_API_VERSION RUN --mount=type=secret,id=JET_INDEX_URLS \ JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ - uv pip install --no-cache-dir "jet-client~=2.0" --upgrade $JET_INDEX_URLS + uv pip install --no-cache-dir "jet-client~=2.0" --upgrade $JET_INDEX_URLS \ No newline at end of file diff --git a/tools/build_sequences_per_dataset.py b/tools/build_sequences_per_dataset.py index e2787dd6434..ebf86dcabc0 100644 --- a/tools/build_sequences_per_dataset.py +++ b/tools/build_sequences_per_dataset.py @@ -21,7 +21,6 @@ def get_paths_from_blend( blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]], ) -> List[str]: """Extract all dataset paths from blend and blend_per_split. 
- Args: blend (Optional[Tuple[List[str], Optional[List[float]]]]): A blend tuple containing a list of dataset paths and optionally a list of weights, e.g., @@ -29,12 +28,10 @@ def get_paths_from_blend( blend_per_split (Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]): A list of 3 blend tuples (for train, valid, test splits), where each element has the same structure as blend - Returns: List[str]: A list of all unique dataset paths found in blend and blend_per_split """ paths = [] - # Extract paths from blend if blend is not None: paths_list, _ = blend From 801f8ccf8408384f9263899ad0e20f4b9f9bda41 Mon Sep 17 00:00:00 2001 From: Dhinesh Ponnarasan Date: Fri, 1 May 2026 18:08:03 -0400 Subject: [PATCH 2/2] observability: add metrics instrumentation to DataParallelInferenceCoordinator Implements observability enhancements for DataParallelInferenceCoordinator as described in issue #4176. Adds a backend-agnostic CoordinatorMetrics abstraction and instruments the coordinator with 10 metrics covering routing quality, reliability, and latency. New file: megatron/core/inference/coordinator_metrics.py - CoordinatorMetrics ABC with inc(), observe(), gauge() - NoOpMetrics default (near-zero overhead when observability is disabled) - Fully decoupled from any specific metrics backend (Prometheus, StatsD, etc.) 
Modified: megatron/core/inference/data_parallel_inference_coordinator.py - metrics: CoordinatorMetrics | None = None param in __init__ and entrypoint - coordinator_active_engines gauge set at init, on engine removal, and re-registration - _log_protocol_error() centralizes error classification for structured logging - routing_cache_hit_total / routing_cache_miss_total / routing_stale_detected_total emitted from get_best_data_parallel_rank() with record_metrics guard to prevent double-counting on retry loops - coordinator_engine_unreachable_total fired in _send_to_engine() before EHOSTUNREACH - coordinator_unknown_sender_total covers SUBMIT_REQUEST, control signals, and SHUTDOWN - coordinator_all_engines_exhausted_total in for-else when every engine is unreachable - coordinator_routing_latency_seconds observed after successful engine send - coordinator_message_processing_latency_seconds in try/finally to cover every message - coordinator_invalid_message_total / coordinator_internal_error_total via _log_protocol_error New file: tests/unit_tests/inference/test_coordinator_observability.py - 32 unit tests with in-memory TestMetrics backend - Covers all 10 metrics, NoOpMetrics default, double-count prevention, entrypoint forwarding, and all unknown-sender paths No routing or protocol behavior changes. Compatible with PR #4419. 
--- .../core/inference/coordinator_metrics.py | 72 ++ .../data_parallel_inference_coordinator.py | 416 ++++++---- .../test_coordinator_observability.py | 721 ++++++++++++++++++ 3 files changed, 1047 insertions(+), 162 deletions(-) create mode 100644 megatron/core/inference/coordinator_metrics.py create mode 100644 tests/unit_tests/inference/test_coordinator_observability.py diff --git a/megatron/core/inference/coordinator_metrics.py b/megatron/core/inference/coordinator_metrics.py new file mode 100644 index 00000000000..3f979510d0b --- /dev/null +++ b/megatron/core/inference/coordinator_metrics.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Lightweight metrics abstraction for the inference coordinator. + +Provides a backend-agnostic interface so that instrumentation code is +decoupled from any specific metrics system (Prometheus, StatsD, etc.). + +Usage +----- +Pass a ``CoordinatorMetrics`` implementation into the coordinator at +construction time. When no implementation is provided the coordinator +defaults to ``NoOpMetrics``, which adds near-zero overhead. + +Metric naming conventions +-------------------------- +- ``coordinator_*`` — system-level metrics (errors, active engines, latency). +- ``routing_*`` — routing-quality metrics (cache hits, misses, fallbacks). +""" + +from abc import ABC, abstractmethod + + +class CoordinatorMetrics(ABC): + """Abstract interface for coordinator observability metrics. + + Implement this class to plug in any metrics backend without modifying + coordinator logic. + """ + + @abstractmethod + def inc(self, name: str, value: int = 1) -> None: + """Increment a counter by *value*. + + Args: + name: Metric name, e.g. ``"routing_cache_hit_total"``. + value: Amount to add (default 1). + """ + + @abstractmethod + def observe(self, name: str, value: float) -> None: + """Record a latency or distribution sample. + + Args: + name: Metric name, e.g. 
``"coordinator_routing_latency_seconds"``. + value: Observed value in the metric's natural unit. + """ + + @abstractmethod + def gauge(self, name: str, value: float) -> None: + """Set an instantaneous gauge value. + + Args: + name: Metric name, e.g. ``"coordinator_active_engines"``. + value: Current value. + """ + + +class NoOpMetrics(CoordinatorMetrics): + """No-op implementation. Adds near-zero overhead when observability is disabled. + + This is the default used by the coordinator when no metrics backend is + supplied. + """ + + def inc(self, name: str, value: int = 1) -> None: + pass + + def observe(self, name: str, value: float) -> None: + pass + + def gauge(self, name: str, value: float) -> None: + pass diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index 146ecf1f1dc..f7f2d6a6a00 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -6,6 +6,7 @@ import logging import signal import socket +import time from collections import deque from enum import Enum, auto from multiprocessing import Event @@ -14,7 +15,8 @@ import torch from megatron.core.inference.config import PrefixCachingCoordinatorPolicy -from megatron.core.inference.headers import Headers, UnknownHeaderError +from megatron.core.inference.coordinator_metrics import CoordinatorMetrics, NoOpMetrics +from megatron.core.inference.headers import Headers from megatron.core.inference.inference_request import compute_block_hashes_batched from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, @@ -95,6 +97,7 @@ def __init__( PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK ), schedule_output_path: str | None = None, + metrics: CoordinatorMetrics | None = None, ): """ Initializes the inference coordinator. 
@@ -200,6 +203,11 @@ def __init__( identity: idx for idx, identity in enumerate(sorted_identities) } + # Metrics backend (defaults to no-op when not provided). + self.metrics: CoordinatorMetrics = metrics if metrics is not None else NoOpMetrics() + # Best-effort gauge update; eventual consistency is sufficient for observability. + self.metrics.gauge("coordinator_active_engines", float(len(self.identities_of_data_parallel_ranks))) + def get_next_data_parallel_rank(self): """ Selects the next data parallel rank using round-robin scheduling. @@ -214,6 +222,28 @@ def get_next_data_parallel_rank(self): self._round_robin_idx = idx + 1 return identities[idx] + def _log_protocol_error( + self, error_type: str, message: str, context: dict | None = None + ) -> None: + """Observability enhancement that complements (not replaces) PR #4419 robustness. + + Centralizes protocol error classification for structured logging and metrics attribution. + """ + context = context or {} + logging.warning( + "Coordinator protocol error | type=%s message=%s context=%s", + error_type, + message, + context, + ) + + metric_name = { + "client_error": "coordinator_invalid_message_total", + "internal_error": "coordinator_internal_error_total", + }.get(error_type) + if metric_name is not None: + self.metrics.inc(metric_name) + def _remove_engine(self, identity): """Remove a disconnected engine from the routing pool.""" self.identities_of_data_parallel_ranks.remove(identity) @@ -222,6 +252,8 @@ def _remove_engine(self, identity): identity, len(self.identities_of_data_parallel_ranks), ) + # Best-effort gauge update; eventual consistency is sufficient for observability. + self.metrics.gauge("coordinator_active_engines", float(len(self.identities_of_data_parallel_ranks))) def _send_to_engine(self, identity, payload): """Send payload to an engine, removing it from the pool if unreachable. 
@@ -233,6 +265,9 @@ def _send_to_engine(self, identity, payload): self.router_socket.send_multipart([identity, payload]) return True except zmq.error.ZMQError as e: + # We treat all send failures as "unreachable" signals for observability + # to surface communication reliability issues, without changing behavior + self.metrics.inc("coordinator_engine_unreachable_total") if e.errno == zmq.EHOSTUNREACH: self._remove_engine(identity) return False @@ -256,7 +291,7 @@ def compute_request_hashes(self, prompt): token_tensor = torch.tensor(tokens, dtype=torch.int64) return compute_block_hashes_batched(token_tensor, self.block_size_tokens) - def get_best_data_parallel_rank(self, request_hashes): + def get_best_data_parallel_rank(self, request_hashes, record_metrics: bool = False): """Select the best DP rank based on prefix cache affinity. Iterates request hashes in reverse order and picks the rank that cached @@ -285,8 +320,17 @@ def get_best_data_parallel_rank(self, request_hashes): if rank_info: # Pick the most recently assigned rank. best_rank = max(rank_info, key=rank_info.get) + # Detect stale entries: rank matched in hash table but no longer alive. + if record_metrics: + if best_rank not in self.identities_of_data_parallel_ranks: + self.metrics.inc("routing_stale_detected_total") + else: + self.metrics.inc("routing_cache_hit_total") return best_rank + # No hash match — fall back to round-robin. + if record_metrics: + self.metrics.inc("routing_cache_miss_total") return self.get_next_data_parallel_rank() def _update_rank_hashes(self, rank_identity, request_hashes): @@ -315,184 +359,229 @@ def start(self): known_clients = set() while True: sender_identity, serialized_payload = self.router_socket.recv_multipart() + _message_start = time.monotonic() - # Allow for re-registration if connecting to a running coordinator. 
- if serialized_payload == b"": - if sender_identity not in self.identities_of_data_parallel_ranks: - self.identities_of_data_parallel_ranks.append(sender_identity) - continue - - deserialized_payload = msgpack.unpackb(serialized_payload, raw=False) - header = Headers(deserialized_payload[0]) + try: - if header == Headers.CONNECT: - if sender_identity in known_clients: - logging.info( - f"Client {sender_identity} sent a duplicate connect request. Ignoring .." - ) + # Allow for re-registration if connecting to a running coordinator. + if serialized_payload == b"": + if sender_identity not in self.identities_of_data_parallel_ranks: + self.identities_of_data_parallel_ranks.append(sender_identity) + # Best-effort gauge update; eventual consistency is sufficient for observability. + self.metrics.gauge( + "coordinator_active_engines", + float(len(self.identities_of_data_parallel_ranks)), + ) continue - # print(f"New client connected: {sender_identity}") - known_clients.add(sender_identity) - self.router_socket.send_multipart( - [sender_identity, msgpack.packb([Headers.CONNECT_ACK.value], use_bin_type=True)] - ) + deserialized_payload = msgpack.unpackb(serialized_payload, raw=False) + header = Headers(deserialized_payload[0]) - elif header == Headers.SUBMIT_REQUEST: - # ToDo [Siddharth]: We might want to tokenize the prompt on the - # assigned data parallel rank for this process instead - # of the coordinator. + if header == Headers.CONNECT: + if sender_identity in known_clients: + logging.info( + f"Client {sender_identity} sent a duplicate connect request. Ignoring .." + ) + continue - # Message from a known client - if sender_identity not in known_clients: - logging.info( - f"Received message from unknown client {sender_identity}. Ignoring." 
+ # print(f"New client connected: {sender_identity}") + known_clients.add(sender_identity) + self.router_socket.send_multipart( + [ + sender_identity, + msgpack.packb([Headers.CONNECT_ACK.value], use_bin_type=True), + ] ) - continue - # this is a message from a client. - # route it to a data parallel rank - client_request_id, prompt, sampling_params = deserialized_payload[1:] - # map client request_id to server request_id - # necessary because multiple clients might have the same request_id. - request_id = self.next_request_id - self.next_request_id += 1 - self.request_id_to_client_id[request_id] = sender_identity - self.request_id_to_client_request_id[request_id] = client_request_id - - # Serialize prompt. - if isinstance(prompt, (str, list)): - pass - elif isinstance(prompt, torch.Tensor): - prompt = prompt.tolist() - else: - raise Exception("specialize for <%s> prompt." % type(prompt).__name__) - payload = msgpack.packb( - [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params], - use_bin_type=True, - ) + elif header == Headers.SUBMIT_REQUEST: + # ToDo [Siddharth]: We might want to tokenize the prompt on the + # assigned data parallel rank for this process instead + # of the coordinator. + + # Message from a known client + if sender_identity not in known_clients: + self.metrics.inc("coordinator_unknown_sender_total") + logging.info( + f"Received message from unknown client {sender_identity}. Ignoring." + ) + continue + # this is a message from a client. + # route it to a data parallel rank + client_request_id, prompt, sampling_params = deserialized_payload[1:] + # map client request_id to server request_id + # necessary because multiple clients might have the same request_id. + request_id = self.next_request_id + self.next_request_id += 1 + self.request_id_to_client_id[request_id] = sender_identity + self.request_id_to_client_request_id[request_id] = client_request_id + + # Serialize prompt. 
+ if isinstance(prompt, (str, list)): + pass + elif isinstance(prompt, torch.Tensor): + prompt = prompt.tolist() + else: + raise Exception("specialize for <%s> prompt." % type(prompt).__name__) - request_hashes = self.compute_request_hashes(prompt) - if ( - self.prefix_caching_coordinator_policy - == PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK - ): - request_hashes = request_hashes[:1] + payload = msgpack.packb( + [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params], + use_bin_type=True, + ) - # Account for the fact that some engines may have died. - for _ in range(len(self.identities_of_data_parallel_ranks)): - next_identity = self.get_best_data_parallel_rank(request_hashes) - if self._send_to_engine(next_identity, payload): - break - else: - # If all engines have died, we are in an abnormal state, and must exit cleanly. - logging.error("Coordinator: no reachable engines for request %d", request_id) - del self.request_id_to_client_id[request_id] - del self.request_id_to_client_request_id[request_id] - return - - if request_hashes: - self._update_rank_hashes(next_identity, request_hashes) - if self.schedule_records is not None: - self.schedule_records.append( - { - "request_id": request_id, - "rank_index": self.identity_to_rank_index[next_identity], - "num_hashes": len(request_hashes), - } + _routing_start = time.monotonic() + request_hashes = self.compute_request_hashes(prompt) + if ( + self.prefix_caching_coordinator_policy + == PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK + ): + request_hashes = request_hashes[:1] + + # Metrics are recorded once per request to avoid inflation + record_routing_metrics = True + # Account for the fact that some engines may have died. 
+ for _ in range(len(self.identities_of_data_parallel_ranks)): + next_identity = self.get_best_data_parallel_rank( + request_hashes, record_metrics=record_routing_metrics + ) + record_routing_metrics = False + if self._send_to_engine(next_identity, payload): + break + else: + # If all engines have died, we are in an abnormal state, and must exit cleanly. + self.metrics.inc("coordinator_all_engines_exhausted_total") + logging.error("Coordinator: no reachable engines for request %d", request_id) + del self.request_id_to_client_id[request_id] + del self.request_id_to_client_request_id[request_id] + return + + self.metrics.observe( + "coordinator_routing_latency_seconds", time.monotonic() - _routing_start ) - elif header in ( - Headers.PAUSE, - Headers.UNPAUSE, - Headers.SUSPEND, - Headers.RESUME, - Headers.SET_GENERATION_EPOCH, - Headers.STOP, - ): - # Start by checking the current state against the control signal. - if sender_identity not in known_clients: - logging.warning("Coordinator: ignoring signal from unknown client.") - continue + if request_hashes: + self._update_rank_hashes(next_identity, request_hashes) + if self.schedule_records is not None: + self.schedule_records.append( + { + "request_id": request_id, + "rank_index": self.identity_to_rank_index[next_identity], + "num_hashes": len(request_hashes), + } + ) + + elif header in ( + Headers.PAUSE, + Headers.UNPAUSE, + Headers.SUSPEND, + Headers.RESUME, + Headers.SET_GENERATION_EPOCH, + Headers.STOP, + ): + # Start by checking the current state against the control signal. 
+ if sender_identity not in known_clients: + self.metrics.inc("coordinator_unknown_sender_total") + logging.warning("Coordinator: ignoring signal from unknown client.") + continue - if header == Headers.PAUSE: - idem_states = (self.CoordinatorState.PAUSED, self.CoordinatorState.SUSPENDED) - if self.state == self.CoordinatorState.RUNNING: + if header == Headers.PAUSE: + idem_states = (self.CoordinatorState.PAUSED, self.CoordinatorState.SUSPENDED) + if self.state == self.CoordinatorState.RUNNING: + self.state = self.CoordinatorState.PAUSED + elif self.state in idem_states: + # Already paused/suspended, ignore redundant PAUSE. + continue + else: + logging.warning("Coordinator: ignoring PAUSE in state %s", self.state) + continue + elif header == Headers.UNPAUSE: + if self.state != self.CoordinatorState.PAUSED: + logging.warning("Coordinator: ignoring UNPAUSE in state %s", self.state) + continue + self.state = self.CoordinatorState.RUNNING + elif header == Headers.SUSPEND: + if self.state != self.CoordinatorState.PAUSED: + logging.warning("Coordinator: ignoring SUSPEND in state %s", self.state) + continue + self.state = self.CoordinatorState.SUSPENDED + elif header == Headers.RESUME: + if self.state != self.CoordinatorState.SUSPENDED: + logging.warning("Coordinator: ignoring RESUME in state %s", self.state) + continue self.state = self.CoordinatorState.PAUSED - elif self.state in idem_states: - # Already paused/suspended, ignore redundant PAUSE. 
- continue - else: - logging.warning("Coordinator: ignoring PAUSE in state %s", self.state) - continue - elif header == Headers.UNPAUSE: - if self.state != self.CoordinatorState.PAUSED: - logging.warning("Coordinator: ignoring UNPAUSE in state %s", self.state) - continue - self.state = self.CoordinatorState.RUNNING - elif header == Headers.SUSPEND: - if self.state != self.CoordinatorState.PAUSED: - logging.warning("Coordinator: ignoring SUSPEND in state %s", self.state) - continue - self.state = self.CoordinatorState.SUSPENDED - elif header == Headers.RESUME: - if self.state != self.CoordinatorState.SUSPENDED: - logging.warning("Coordinator: ignoring RESUME in state %s", self.state) + elif header == Headers.STOP: + good_states = (self.CoordinatorState.PAUSED, self.CoordinatorState.SUSPENDED) + if self.state not in good_states: + logging.warning("Coordinator: ignoring STOP in state %s", self.state) + continue + self.state = self.CoordinatorState.STOPPING + + # Broadcast the control signal if we're in a good state. + # Forward the full deserialized payload so that data-bearing + # signals (e.g. SET_GENERATION_EPOCH) retain their arguments. + broadcast_payload = msgpack.packb(deserialized_payload, use_bin_type=True) + for data_parallel_rank_id in list(self.identities_of_data_parallel_ranks): + self._send_to_engine(data_parallel_rank_id, broadcast_payload) + + # STOP affects engines; reset coordinator to RUNNING to allow future engines. + if header == Headers.STOP: + self.state = self.CoordinatorState.RUNNING + + elif header == Headers.ENGINE_REPLY: + # This is the output of a single engine step on some data parallel rank. 
+ if sender_identity not in self.identities_of_data_parallel_ranks: + self._log_protocol_error( + "internal_error", + "ENGINE_REPLY from unregistered engine", + context={"sender": repr(sender_identity)}, + ) + assert sender_identity in self.identities_of_data_parallel_ranks + finished_requests = deserialized_payload[1] + + for finished_request in finished_requests: + self.detokenize(finished_request) + fid = finished_request["request_id"] + client_identity = self.request_id_to_client_id[fid] + client_request_identity = self.request_id_to_client_request_id[fid] + del self.request_id_to_client_id[fid] + del self.request_id_to_client_request_id[fid] + + self.router_socket.send_multipart( + [ + client_identity, + msgpack.packb( + [header.value, client_request_identity, finished_request], + use_bin_type=True, + ), + ] + ) + + elif header == Headers.SHUTDOWN: + if sender_identity not in known_clients: + self.metrics.inc("coordinator_unknown_sender_total") + logging.warning("Coordinator: ignoring signal from unknown client.") continue - self.state = self.CoordinatorState.PAUSED - elif header == Headers.STOP: - good_states = (self.CoordinatorState.PAUSED, self.CoordinatorState.SUSPENDED) - if self.state not in good_states: - logging.warning("Coordinator: ignoring STOP in state %s", self.state) - continue - self.state = self.CoordinatorState.STOPPING - - # Broadcast the control signal if we're in a good state. - # Forward the full deserialized payload so that data-bearing - # signals (e.g. SET_GENERATION_EPOCH) retain their arguments. - broadcast_payload = msgpack.packb(deserialized_payload, use_bin_type=True) - for data_parallel_rank_id in list(self.identities_of_data_parallel_ranks): - self._send_to_engine(data_parallel_rank_id, broadcast_payload) - - # STOP affects engines; reset coordinator to RUNNING to allow future engines. 
- if header == Headers.STOP: - self.state = self.CoordinatorState.RUNNING - - elif header == Headers.ENGINE_REPLY: - # This is the output of a single engine step on some data parallel rank. - assert sender_identity in self.identities_of_data_parallel_ranks - finished_requests = deserialized_payload[1] - - for finished_request in finished_requests: - self.detokenize(finished_request) - fid = finished_request["request_id"] - client_identity = self.request_id_to_client_id[fid] - client_request_identity = self.request_id_to_client_request_id[fid] - del self.request_id_to_client_id[fid] - del self.request_id_to_client_request_id[fid] + break - self.router_socket.send_multipart( - [ - client_identity, - msgpack.packb( - [header.value, client_request_identity, finished_request], - use_bin_type=True, - ), - ] - ) + elif header == Headers.DISCONNECT: + if sender_identity in self.identities_of_data_parallel_ranks: + self._remove_engine(sender_identity) - elif header == Headers.SHUTDOWN: - if sender_identity not in known_clients: - logging.warning("Coordinator: ignoring signal from unknown client.") + else: + # Unknown headers are treated as client errors and safely ignored + # to preserve coordinator stability (consistent with protocol robustness goals) + self._log_protocol_error( + "client_error", + "Unrecognized message header", + context={"header": repr(header)}, + ) continue - break - - elif header == Headers.DISCONNECT: - if sender_identity in self.identities_of_data_parallel_ranks: - self._remove_engine(sender_identity) - - else: - raise UnknownHeaderError(header) + finally: + # Includes full message handling time (routing + processing) + # Separate routing latency metric captures hash computation, routing decision, and send overhead + self.metrics.observe( + "coordinator_message_processing_latency_seconds", + time.monotonic() - _message_start, + ) def detokenize(self, finished_request): """ @@ -533,6 +622,7 @@ def entrypoint( 
PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK ), schedule_output_path: str | None = None, + metrics: CoordinatorMetrics | None = None, ): """ Class method to instantiate and run the coordinator, for use in a separate process. @@ -551,6 +641,7 @@ def entrypoint( enable_prefix_caching (bool): Whether prefix caching is enabled. prefix_caching_coordinator_policy (PrefixCachingCoordinatorPolicy): Routing policy. schedule_output_path (Optional[str]): Path to write scheduling decisions JSON. + metrics (Optional[CoordinatorMetrics]): Metrics backend. Defaults to NoOpMetrics. """ coordinator = cls( pipe_connection, @@ -562,6 +653,7 @@ def entrypoint( enable_prefix_caching=enable_prefix_caching, prefix_caching_coordinator_policy=prefix_caching_coordinator_policy, schedule_output_path=schedule_output_path, + metrics=metrics, ) ready_event.set() try: diff --git a/tests/unit_tests/inference/test_coordinator_observability.py b/tests/unit_tests/inference/test_coordinator_observability.py new file mode 100644 index 00000000000..9312789b933 --- /dev/null +++ b/tests/unit_tests/inference/test_coordinator_observability.py @@ -0,0 +1,721 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Unit tests for coordinator observability (metrics instrumentation). + +Tests verify that the correct metric counters and observations are emitted by +the DataParallelInferenceCoordinator for each instrumented scenario. + +The TestMetrics class is used as an in-memory backend so tests can inspect +recorded values without any real metrics infrastructure. 
+""" + +from collections import defaultdict, deque +from unittest.mock import MagicMock, patch + +import pytest + +from megatron.core.inference.config import PrefixCachingCoordinatorPolicy +from megatron.core.inference.coordinator_metrics import CoordinatorMetrics, NoOpMetrics + + +# --------------------------------------------------------------------------- +# In-memory metrics implementation used by tests +# --------------------------------------------------------------------------- + + +class TestMetrics(CoordinatorMetrics): + """In-memory metrics backend for unit testing. + + Captures all metric operations so test cases can assert on their effects + without depending on a real metrics backend. + """ + + def __init__(self): + self.counters: dict[str, int] = defaultdict(int) + self.observations: dict[str, list[float]] = defaultdict(list) + self.gauges: dict[str, float] = {} + + def inc(self, name: str, value: int = 1) -> None: + self.counters[name] += value + + def observe(self, name: str, value: float) -> None: + self.observations[name].append(value) + + def gauge(self, name: str, value: float) -> None: + self.gauges[name] = value + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_minimal_coordinator( + *, + data_parallel_size: int = 1, + enable_prefix_caching: bool = False, + block_size_tokens: int = 4, + policy: PrefixCachingCoordinatorPolicy = PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK, + metrics: CoordinatorMetrics | None = None, +): + """Build a DataParallelInferenceCoordinator without ZMQ or tokenizer. + + Bypasses __init__ to avoid ZMQ socket setup. Only the fields required by + the methods under test are initialised. + """ + # Import here so that missing optional deps don't break collection. 
def _make_minimal_coordinator(
    data_parallel_size=2,
    enable_prefix_caching=False,
    block_size_tokens=4,
    policy=PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK,
    metrics=None,
):
    """Build a coordinator with just enough state for unit-level method tests.

    Bypasses ``__init__`` (which opens real ZMQ sockets) via ``object.__new__``
    and sets only the attributes the tested methods read.

    NOTE(review): this header is reconstructed from the call sites in this
    file — confirm the parameter names and defaults against the original
    helper definition, which is outside the reviewed chunk.
    """
    from megatron.core.inference.data_parallel_inference_coordinator import (
        DataParallelInferenceCoordinator,
    )

    coordinator = object.__new__(DataParallelInferenceCoordinator)

    # Mimic the state set by __init__ that the tested methods rely on.
    coordinator.data_parallel_size = data_parallel_size
    coordinator.enable_prefix_caching = enable_prefix_caching
    coordinator.block_size_tokens = block_size_tokens
    coordinator.prefix_caching_coordinator_policy = policy
    coordinator.hash_to_rank_info = {}
    coordinator._assignment_counter = 0
    coordinator._round_robin_idx = 0
    coordinator.next_request_id = 0
    coordinator.request_id_to_client_id = {}
    coordinator.request_id_to_client_request_id = {}
    coordinator.schedule_records = None
    coordinator.identity_to_rank_index = {}

    # Use two fake engine identities.
    coordinator.identities_of_data_parallel_ranks = deque([b"engine-0", b"engine-1"])
    coordinator.metrics = metrics if metrics is not None else NoOpMetrics()

    return coordinator


def _make_start_loop_coordinator(
    metrics: CoordinatorMetrics,
    num_engines: int = 1,
):
    """Build a coordinator ready to run start() with a mock ZMQ socket.

    Bypasses __init__ and populates only the state required by the start() loop.
    The caller must set mock_socket.recv_multipart.side_effect with the desired
    message sequence before calling coordinator.start().

    Returns:
        (coordinator, mock_socket) — the coordinator under test and the
        MagicMock standing in for its router socket.
    """
    from megatron.core.inference.data_parallel_inference_coordinator import (
        DataParallelInferenceCoordinator,
    )

    coordinator = object.__new__(DataParallelInferenceCoordinator)
    coordinator.data_parallel_size = num_engines
    coordinator.enable_prefix_caching = False
    coordinator.block_size_tokens = 4
    coordinator.prefix_caching_coordinator_policy = (
        PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK
    )
    coordinator.hash_to_rank_info = {}
    coordinator._assignment_counter = 0
    coordinator._round_robin_idx = 0
    coordinator.next_request_id = 0
    coordinator.request_id_to_client_id = {}
    coordinator.request_id_to_client_request_id = {}
    coordinator.schedule_records = None
    coordinator.state = DataParallelInferenceCoordinator.CoordinatorState.RUNNING

    engines = [f"engine-{i}".encode() for i in range(num_engines)]
    coordinator.identities_of_data_parallel_ranks = deque(engines)
    coordinator.identity_to_rank_index = {e: i for i, e in enumerate(engines)}
    coordinator.metrics = metrics
    coordinator.tokenizer = MagicMock()

    mock_socket = MagicMock()
    coordinator.router_socket = mock_socket

    return coordinator, mock_socket


# ---------------------------------------------------------------------------
# Tests: CoordinatorMetrics abstraction
# ---------------------------------------------------------------------------


class TestCoordinatorMetricsAbstraction:
    """Unit tests for the NoOpMetrics / TestMetrics backends themselves."""

    def test_metrics_noop_calls_no_exception(self):
        # NoOpMetrics must absorb every call silently.
        m = NoOpMetrics()
        m.inc("any_counter")
        m.observe("any_observation", 0.1)
        m.gauge("any_gauge", 42.0)

    def test_metrics_counter_increments_accumulated_value(self):
        m = TestMetrics()
        m.inc("foo")
        m.inc("foo")
        m.inc("bar", 3)
        assert m.counters["foo"] == 2
        assert m.counters["bar"] == 3

    def test_metrics_observe_appends_samples_in_order(self):
        m = TestMetrics()
        m.observe("latency", 0.5)
        m.observe("latency", 1.2)
        assert m.observations["latency"] == [0.5, 1.2]

    def test_metrics_gauge_overwrite_keeps_latest_value(self):
        m = TestMetrics()
        m.gauge("engines", 4.0)
        m.gauge("engines", 3.0)  # overwrite
        assert m.gauges["engines"] == 3.0


# ---------------------------------------------------------------------------
# Tests: _log_protocol_error
# ---------------------------------------------------------------------------


class TestLogProtocolError:
    """Error-type dispatch of _log_protocol_error onto the two error counters."""

    def test_protocol_error_client_error_increments_invalid_message_counter(self):
        m = TestMetrics()
        coordinator = _make_minimal_coordinator(metrics=m)
        coordinator._log_protocol_error("client_error", "bad header")
        assert m.counters["coordinator_invalid_message_total"] == 1

    def test_protocol_error_internal_error_increments_internal_error_counter(self):
        m = TestMetrics()
        coordinator = _make_minimal_coordinator(metrics=m)
        coordinator._log_protocol_error("internal_error", "unexpected state")
        assert m.counters["coordinator_internal_error_total"] == 1

    def test_protocol_error_unknown_type_does_not_increment_known_counters(self):
        m = TestMetrics()
        coordinator = _make_minimal_coordinator(metrics=m)
        coordinator._log_protocol_error("bogus_type", "some message")
        assert m.counters["coordinator_invalid_message_total"] == 0
        assert m.counters["coordinator_internal_error_total"] == 0

    def test_protocol_error_multiple_calls_accumulate_counters(self):
        m = TestMetrics()
        coordinator = _make_minimal_coordinator(metrics=m)
        coordinator._log_protocol_error("client_error", "err1")
        coordinator._log_protocol_error("client_error", "err2")
        coordinator._log_protocol_error("internal_error", "err3")
        assert m.counters["coordinator_invalid_message_total"] == 2
        assert m.counters["coordinator_internal_error_total"] == 1


# ---------------------------------------------------------------------------
# Tests: routing quality metrics
# ---------------------------------------------------------------------------


class TestRoutingMetrics:
    """Tests for cache hit / miss / stale routing metrics."""

    def _make_cache_coordinator(self, policy=PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK):
        # Prefix caching on, small block size, fresh TestMetrics per test.
        m = TestMetrics()
        c = _make_minimal_coordinator(
            enable_prefix_caching=True,
            block_size_tokens=4,
            policy=policy,
            metrics=m,
        )
        return c, m

    def test_routing_cache_hit_increments_hit_counter(self):
        c, m = self._make_cache_coordinator()
        # Pre-populate hash table: hash 99 → engine-0 alive in pool.
        c.hash_to_rank_info[99] = {b"engine-0": 1}

        result = c.get_best_data_parallel_rank([99], record_metrics=True)

        assert result == b"engine-0"
        assert m.counters["routing_cache_hit_total"] == 1
        assert m.counters["routing_cache_miss_total"] == 0
        assert m.counters["routing_stale_detected_total"] == 0

    def test_routing_cache_miss_increments_miss_counter(self):
        c, m = self._make_cache_coordinator()
        # hash_to_rank_info is empty → no match.

        c.get_best_data_parallel_rank([42, 43], record_metrics=True)

        assert m.counters["routing_cache_miss_total"] == 1
        assert m.counters["routing_cache_hit_total"] == 0

    def test_routing_stale_entry_increments_stale_counter(self):
        c, m = self._make_cache_coordinator()
        # Hash points to a rank that has been removed from the pool.
        c.hash_to_rank_info[99] = {b"dead-engine": 1}
        # b"dead-engine" is NOT in identities_of_data_parallel_ranks.

        c.get_best_data_parallel_rank([99], record_metrics=True)

        assert m.counters["routing_stale_detected_total"] == 1
        assert m.counters["routing_cache_hit_total"] == 0

    def test_routing_prefix_caching_disabled_records_no_quality_metrics(self):
        m = TestMetrics()
        c = _make_minimal_coordinator(enable_prefix_caching=False, metrics=m)
        c.hash_to_rank_info[99] = {b"engine-0": 1}

        c.get_best_data_parallel_rank([99], record_metrics=True)

        # Routing-quality counters are meaningless when prefix caching is off.
        assert m.counters["routing_cache_hit_total"] == 0
        assert m.counters["routing_cache_miss_total"] == 0
        assert m.counters["routing_stale_detected_total"] == 0

    def test_routing_round_robin_policy_records_no_quality_metrics(self):
        c, m = self._make_cache_coordinator(
            policy=PrefixCachingCoordinatorPolicy.ROUND_ROBIN
        )
        c.hash_to_rank_info[99] = {b"engine-0": 1}

        c.get_best_data_parallel_rank([99], record_metrics=True)

        assert m.counters["routing_cache_hit_total"] == 0
        assert m.counters["routing_cache_miss_total"] == 0

    def test_routing_empty_hashes_records_no_quality_metrics(self):
        c, m = self._make_cache_coordinator()
        # No hashes → can't do prefix match.
        c.get_best_data_parallel_rank([], record_metrics=True)

        assert m.counters["routing_cache_hit_total"] == 0
        assert m.counters["routing_cache_miss_total"] == 0

    def test_routing_longest_prefix_policy_returns_cache_hit_metric(self):
        c, m = self._make_cache_coordinator(
            policy=PrefixCachingCoordinatorPolicy.LONGEST_PREFIX
        )
        c.hash_to_rank_info[10] = {b"engine-0": 1}
        c.hash_to_rank_info[20] = {b"engine-0": 2}

        # Reversed scan finds hash 20 first (longest match).
        result = c.get_best_data_parallel_rank([10, 20], record_metrics=True)

        assert result == b"engine-0"
        assert m.counters["routing_cache_hit_total"] == 1

    def test_routing_record_metrics_false_prevents_double_counting(self):
        c, m = self._make_cache_coordinator()
        c.hash_to_rank_info[99] = {b"engine-0": 1}

        c.get_best_data_parallel_rank([99], record_metrics=True)
        c.get_best_data_parallel_rank([99], record_metrics=False)

        assert m.counters["routing_cache_hit_total"] == 1


# ---------------------------------------------------------------------------
# Tests: engine-unreachable metrics
# ---------------------------------------------------------------------------


class TestEngineUnreachableMetric:
    """Behavior of _send_to_engine failure paths and the unreachable counter."""

    def test_send_ehostunreach_failure_increments_unreachable_counter(self):
        zmq = pytest.importorskip("zmq", reason="pyzmq not installed")

        m = TestMetrics()
        c = _make_minimal_coordinator(metrics=m)

        # Fake a router socket that raises EHOSTUNREACH.
        mock_socket = MagicMock()
        mock_socket.send_multipart.side_effect = zmq.error.ZMQError(zmq.EHOSTUNREACH)
        c.router_socket = mock_socket
        c.identities_of_data_parallel_ranks = deque([b"engine-0", b"engine-1"])

        result = c._send_to_engine(b"engine-0", b"payload")

        assert result is False
        assert m.counters["coordinator_engine_unreachable_total"] == 1
        # Active engines gauge should reflect removal.
        assert m.gauges["coordinator_active_engines"] == 1.0

    def test_send_success_does_not_increment_unreachable_counter(self):
        m = TestMetrics()
        c = _make_minimal_coordinator(metrics=m)

        mock_socket = MagicMock()
        c.router_socket = mock_socket

        result = c._send_to_engine(b"engine-0", b"payload")

        assert result is True
        assert m.counters["coordinator_engine_unreachable_total"] == 0

    def test_send_non_ehostunreach_failure_still_increments_unreachable_counter(self):
        zmq = pytest.importorskip("zmq", reason="pyzmq not installed")

        m = TestMetrics()
        c = _make_minimal_coordinator(metrics=m)

        mock_socket = MagicMock()
        mock_socket.send_multipart.side_effect = zmq.error.ZMQError(zmq.EINVAL)
        c.router_socket = mock_socket

        # Non-EHOSTUNREACH errors propagate but are still counted.
        with pytest.raises(zmq.error.ZMQError):
            c._send_to_engine(b"engine-0", b"payload")

        assert m.counters["coordinator_engine_unreachable_total"] == 1


# ---------------------------------------------------------------------------
# Tests: routing latency is recorded
# ---------------------------------------------------------------------------


class TestRoutingLatencyRecorded:
    """Verify that coordinator_routing_latency_seconds is recorded by start()."""

    def test_routing_latency_recorded_by_start_loop(self):
        msgpack = pytest.importorskip("msgpack", reason="msgpack not installed")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m)
        # CONNECT → one SUBMIT_REQUEST (the only routed message) → SHUTDOWN.
        mock_socket.recv_multipart.side_effect = [
            (b"client-1", msgpack.packb([Headers.CONNECT.value], use_bin_type=True)),
            (
                b"client-1",
                msgpack.packb(
                    [Headers.SUBMIT_REQUEST.value, 0, [1, 2, 3], {}], use_bin_type=True
                ),
            ),
            (b"client-1", msgpack.packb([Headers.SHUTDOWN.value], use_bin_type=True)),
        ]

        coordinator.start()

        assert len(m.observations["coordinator_routing_latency_seconds"]) == 1
        assert m.observations["coordinator_routing_latency_seconds"][0] >= 0.0


# ---------------------------------------------------------------------------
# Tests: active engines gauge
# ---------------------------------------------------------------------------


class TestActiveEnginesGauge:
    """The coordinator_active_engines gauge tracks the engine pool size."""

    def test_active_engines_remove_engine_updates_gauge_value(self):
        m = TestMetrics()
        c = _make_minimal_coordinator(metrics=m)
        # Pool starts with 2 engines.
        assert len(c.identities_of_data_parallel_ranks) == 2

        c._remove_engine(b"engine-0")

        assert m.gauges["coordinator_active_engines"] == 1.0

    def test_active_engines_multiple_removals_decrement_gauge_value(self):
        m = TestMetrics()
        c = _make_minimal_coordinator(metrics=m)

        c._remove_engine(b"engine-0")
        c._remove_engine(b"engine-1")

        assert m.gauges["coordinator_active_engines"] == 0.0


# ---------------------------------------------------------------------------
# Tests: default metrics (NoOp) does not crash
# ---------------------------------------------------------------------------


class TestDefaultNoOpMetrics:
    """Coordinator without an explicit metrics backend must run instrumented paths."""

    def test_default_metrics_none_uses_noop_and_no_exception(self):
        """When no metrics arg is passed, coordinator uses NoOpMetrics silently."""
        c = _make_minimal_coordinator(metrics=None)
        assert isinstance(c.metrics, NoOpMetrics)

        # Execute a small basic flow with default NoOpMetrics.
        c.router_socket = MagicMock()
        assert c._send_to_engine(b"engine-0", b"payload") is True

        # Exercise instrumented paths; none should raise.
        c._log_protocol_error("client_error", "test")
        c.get_best_data_parallel_rank([], record_metrics=True)
        c._remove_engine(b"engine-0")

    def test_default_metrics_parameter_optional_defaults_to_noop(self):
        c = _make_minimal_coordinator(metrics=None)
        assert isinstance(c.metrics, NoOpMetrics)


# ---------------------------------------------------------------------------
# Tests: message processing latency
# ---------------------------------------------------------------------------


class TestMessageProcessingLatencyRecorded:
    """Verify coordinator_message_processing_latency_seconds is recorded per message."""

    def test_message_processing_latency_recorded_per_message(self):
        msgpack = pytest.importorskip("msgpack", reason="msgpack not installed")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m)
        mock_socket.recv_multipart.side_effect = [
            (b"client-1", msgpack.packb([Headers.CONNECT.value], use_bin_type=True)),
            (
                b"client-1",
                msgpack.packb(
                    [Headers.SUBMIT_REQUEST.value, 0, [1, 2, 3], {}], use_bin_type=True
                ),
            ),
            (b"client-1", msgpack.packb([Headers.SHUTDOWN.value], use_bin_type=True)),
        ]

        coordinator.start()

        # One observation per message: CONNECT + SUBMIT_REQUEST + SHUTDOWN.
        assert len(m.observations["coordinator_message_processing_latency_seconds"]) == 3
        assert all(
            v >= 0.0 for v in m.observations["coordinator_message_processing_latency_seconds"]
        )


# ---------------------------------------------------------------------------
# Tests: unknown sender metric
# ---------------------------------------------------------------------------


class TestUnknownSenderMetric:
    """Verify coordinator_unknown_sender_total increments for all unknown-sender paths."""

    def test_submit_request_from_unknown_sender_increments_counter(self):
        msgpack = pytest.importorskip("msgpack", reason="msgpack not installed")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m)
        mock_socket.recv_multipart.side_effect = [
            (
                b"unknown-client",
                msgpack.packb(
                    [Headers.SUBMIT_REQUEST.value, 0, [1, 2, 3], {}], use_bin_type=True
                ),
            ),
        ]

        # Message list exhausts → MagicMock raises StopIteration, ending the loop.
        with pytest.raises(StopIteration):
            coordinator.start()

        assert m.counters["coordinator_unknown_sender_total"] == 1

    def test_control_signal_from_unknown_sender_increments_counter(self):
        msgpack = pytest.importorskip("msgpack", reason="msgpack not installed")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m)
        mock_socket.recv_multipart.side_effect = [
            (b"unknown-client", msgpack.packb([Headers.PAUSE.value], use_bin_type=True)),
        ]

        with pytest.raises(StopIteration):
            coordinator.start()

        assert m.counters["coordinator_unknown_sender_total"] == 1

    def test_shutdown_from_unknown_sender_increments_counter(self):
        msgpack = pytest.importorskip("msgpack", reason="msgpack not installed")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m)
        mock_socket.recv_multipart.side_effect = [
            (b"unknown-client", msgpack.packb([Headers.SHUTDOWN.value], use_bin_type=True)),
        ]

        with pytest.raises(StopIteration):
            coordinator.start()

        assert m.counters["coordinator_unknown_sender_total"] == 1


# ---------------------------------------------------------------------------
# Tests: all-engines-exhausted metric
# ---------------------------------------------------------------------------


class TestAllEnginesExhaustedMetric:
    """Verify coordinator_all_engines_exhausted_total fires when every engine is unreachable."""

    def test_all_engines_exhausted_metric_increments_when_no_engine_reachable(self):
        zmq = pytest.importorskip("zmq", reason="pyzmq required")
        msgpack = pytest.importorskip("msgpack", reason="msgpack required")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m, num_engines=1)

        # CONNECT ACK is sent to b"client-1"; engine sends go to b"engine-0".
        def send_side_effect(parts):
            if parts[0] == b"engine-0":
                raise zmq.error.ZMQError(zmq.EHOSTUNREACH)

        mock_socket.send_multipart.side_effect = send_side_effect
        mock_socket.recv_multipart.side_effect = [
            (b"client-1", msgpack.packb([Headers.CONNECT.value], use_bin_type=True)),
            (
                b"client-1",
                msgpack.packb(
                    [Headers.SUBMIT_REQUEST.value, 0, [1, 2, 3], {}], use_bin_type=True
                ),
            ),
        ]

        coordinator.start()  # Returns via return after all-engines-dead path.

        assert m.counters["coordinator_all_engines_exhausted_total"] == 1


# ---------------------------------------------------------------------------
# Tests: ENGINE_REPLY internal error metric
# ---------------------------------------------------------------------------


class TestEngineReplyInternalErrorMetric:
    """Verify coordinator_internal_error_total fires for ENGINE_REPLY from unregistered engine."""

    def test_engine_reply_from_unregistered_sender_emits_internal_error_metric(self):
        msgpack = pytest.importorskip("msgpack", reason="msgpack not installed")

        from megatron.core.inference.headers import Headers

        m = TestMetrics()
        coordinator, mock_socket = _make_start_loop_coordinator(m)

        # ENGINE_REPLY from a sender not in identities_of_data_parallel_ranks.
        mock_socket.recv_multipart.side_effect = [
            (
                b"unknown-engine",
                msgpack.packb([Headers.ENGINE_REPLY.value, []], use_bin_type=True),
            ),
        ]

        with pytest.raises(AssertionError):
            coordinator.start()

        assert m.counters["coordinator_internal_error_total"] == 1


# ---------------------------------------------------------------------------
# Tests: initial active engines gauge (via __init__)
# ---------------------------------------------------------------------------


class TestInitActiveEnginesGauge:
    """Verify coordinator_active_engines gauge is written during __init__."""

    def test_active_engines_gauge_set_at_init_time(self):
        # Coordinator module imports pyzmq/msgpack at init time — skip if absent.
        pytest.importorskip("zmq", reason="pyzmq required")
        pytest.importorskip("msgpack", reason="msgpack required")

        from megatron.core.inference.data_parallel_inference_coordinator import (
            DataParallelInferenceCoordinator,
        )

        m = TestMetrics()
        mock_pipe = MagicMock()
        mock_socket = MagicMock()
        # Two engine-registration handshakes are consumed during __init__.
        mock_socket.recv_multipart.side_effect = [(b"eng-0", b""), (b"eng-1", b"")]
        mock_socket.getsockopt_string.return_value = "tcp://localhost:1234"

        mock_context = MagicMock()
        mock_context.socket.return_value = mock_socket

        with patch(
            "megatron.core.inference.data_parallel_inference_coordinator.zmq.Context",
            return_value=mock_context,
        ), patch("socket.gethostname", return_value="localhost"):
            coordinator = DataParallelInferenceCoordinator(
                pipe_connection=mock_pipe,
                data_parallel_size=2,
                tokenizer=MagicMock(),
                metrics=m,
            )

        assert m.gauges.get("coordinator_active_engines") == 2.0


# ---------------------------------------------------------------------------
# Tests: entrypoint metrics forwarding
# ---------------------------------------------------------------------------


class TestEntrypointMetricsForwarding:
    """Verify metrics passed to entrypoint() reach the constructed coordinator."""

    def test_entrypoint_forwards_metrics_to_coordinator(self):
        pytest.importorskip("zmq", reason="pyzmq required")
        pytest.importorskip("msgpack", reason="msgpack required")

        from megatron.core.inference.data_parallel_inference_coordinator import (
            DataParallelInferenceCoordinator,
        )

        m = TestMetrics()
        mock_pipe = MagicMock()
        mock_ready_event = MagicMock()
        mock_socket = MagicMock()
        mock_socket.recv_multipart.side_effect = [(b"eng-0", b"")]
        mock_socket.getsockopt_string.return_value = "tcp://localhost:1234"

        mock_context = MagicMock()
        mock_context.socket.return_value = mock_socket

        # Patch start() so entrypoint() constructs the coordinator but never loops.
        with patch(
            "megatron.core.inference.data_parallel_inference_coordinator.zmq.Context",
            return_value=mock_context,
        ), patch("socket.gethostname", return_value="localhost"), patch.object(
            DataParallelInferenceCoordinator, "start"
        ):
            DataParallelInferenceCoordinator.entrypoint(
                pipe_connection=mock_pipe,
                ready_event=mock_ready_event,
                data_parallel_size=1,
                tokenizer=MagicMock(),
                metrics=m,
            )

        mock_ready_event.set.assert_called_once()
        # If metrics were forwarded correctly, the init gauge was written.
        assert m.gauges.get("coordinator_active_engines") == 1.0