From c40a1a33852262996c8eb6b7c5ce401b1ba2cf7b Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 3 Mar 2026 08:36:16 +0530 Subject: [PATCH 1/7] fix: Clean shutdown for Sink threaded server using threading.Event Signed-off-by: Sreekanth --- .../pynumaflow/pynumaflow/shared/server.py | 20 + .../pynumaflow/shared/thread_with_return.py | 18 +- .../pynumaflow/pynumaflow/sinker/server.py | 7 + .../sinker/servicer/sync_servicer.py | 15 +- packages/pynumaflow/tests/sink/test_server.py | 587 ++++++++++-------- 5 files changed, 385 insertions(+), 262 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 6c5eb491..630256e1 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -4,6 +4,7 @@ import multiprocessing import os import socket +import threading import traceback from google.protobuf import any_pb2 @@ -18,6 +19,7 @@ from pynumaflow._constants import ( _LOGGER, MULTIPROC_MAP_SOCK_ADDR, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, UDFType, ) from pynumaflow.exceptions import SocketError @@ -57,6 +59,7 @@ def sync_server_start( server_options=None, server_info: ServerInfo | None = None, udf_type: str = UDFType.Map, + shutdown_event: threading.Event | None = None, ): """ Utility function to start a sync grpc server instance. @@ -75,6 +78,7 @@ def sync_server_start( udf_type=udf_type, server_info_file=server_info_file, server_info=server_info, + shutdown_event=shutdown_event, ) @@ -86,10 +90,15 @@ def _run_server( udf_type: str, server_info_file: str | None = None, server_info: ServerInfo | None = None, + shutdown_event: threading.Event | None = None, ) -> None: """ Starts the Synchronous server instance on the given UNIX socket with given max threads. Wait for the server to terminate. 
+ + If *shutdown_event* is provided, a background daemon thread will wait + on it and then call ``server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)`` + for a cooperative graceful shutdown (no process kill). """ server = grpc.server( ThreadPoolExecutor( @@ -115,10 +124,21 @@ def _run_server( server.add_insecure_port(bind_address) # start the gRPC server server.start() + # Add the server information to the server info file if provided if server_info and server_info_file: info_server_write(server_info=server_info, info_file=server_info_file) + if shutdown_event is not None: + + def _watch_for_shutdown(): + shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + watcher = threading.Thread(target=_watch_for_shutdown, daemon=True) + watcher.start() + _LOGGER.info("GRPC Server listening on: %s %d", bind_address, os.getpid()) server.wait_for_termination() diff --git a/packages/pynumaflow/pynumaflow/shared/thread_with_return.py b/packages/pynumaflow/pynumaflow/shared/thread_with_return.py index 9b9a7643..92d9cfac 100644 --- a/packages/pynumaflow/pynumaflow/shared/thread_with_return.py +++ b/packages/pynumaflow/pynumaflow/shared/thread_with_return.py @@ -5,6 +5,7 @@ class ThreadWithReturnValue(Thread): """ A custom Thread class that allows the target function to return a value. This class extends the built-in threading.Thread class. + Exceptions raised by the target are captured and re-raised on join(). 
""" def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbose=None): @@ -23,6 +24,7 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbo Thread.__init__(self, group, target, name, args, kwargs) # Variable to store the return value of the target function self._return = None + self._exception: BaseException | None = None def run(self): """ @@ -30,25 +32,35 @@ def run(self): This method is overridden from the Thread class. It calls the target function and saves the return value. + If the target raises, the exception is captured for re-raising on join(). """ if self._target is not None: - # Execute target and store the result - self._return = self._target(*self._args, **self._kwargs) + try: + # Execute target and store the result + self._return = self._target(*self._args, **self._kwargs) + except BaseException as exc: + self._exception = exc def join(self, *args): """ Wait for the thread to complete and return the result. This method is overridden from the Thread class. - It calls the parent class's join() method and then returns the stored return value. + It calls the parent class's join() method, re-raises any captured + exception, and then returns the stored return value. Parameters: *args: Variable length argument list to pass to the join() method. Returns: The return value from the target function. + + Raises: + BaseException: If the target function raised during run(). 
""" # Call the parent class's join() method to wait for the thread to finish Thread.join(self, *args) + if self._exception is not None: + raise self._exception # Return the result of the target function return self._return diff --git a/packages/pynumaflow/pynumaflow/sinker/server.py b/packages/pynumaflow/pynumaflow/sinker/server.py index 378dff12..2dd71080 100644 --- a/packages/pynumaflow/pynumaflow/sinker/server.py +++ b/packages/pynumaflow/pynumaflow/sinker/server.py @@ -1,4 +1,5 @@ import os +import sys from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer @@ -120,6 +121,7 @@ def start(self): ) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker] + # Start the server sync_server_start( servicer=self.servicer, @@ -129,4 +131,9 @@ def start(self): server_options=self._server_options, udf_type=UDFType.Sink, server_info=serv_info, + shutdown_event=self.servicer._shutdown_event, ) + + if self.servicer._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer._error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py index bdcf98d7..e7bce225 100644 --- a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py @@ -1,9 +1,10 @@ +import threading from collections.abc import Iterator -from pynumaflow._constants import _LOGGER, STREAM_EOF +from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.shared.synciter import SyncIterator from pynumaflow.shared.thread_with_return import 
ThreadWithReturnValue from pynumaflow.sinker._dtypes import SinkSyncCallable @@ -24,6 +25,8 @@ class SyncSinkServicer(sink_pb2_grpc.SinkServicer): def __init__(self, handler: SinkSyncCallable): self.handler: SinkSyncCallable = handler + self._shutdown_event: threading.Event = threading.Event() + self._error: BaseException | None = None def SinkFn( self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext @@ -79,10 +82,11 @@ def SinkFn( cur_task.join() except BaseException as err: - # Handle exceptions - err_msg = f"UDSinkError: {repr(err)}" + err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) - exit_on_error(context, err_msg) + update_context_err(context, err, err_msg) + self._error = err + self._shutdown_event.set() return def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerContext): @@ -93,7 +97,6 @@ def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerCon except BaseException as err: err_msg = f"UDSinkError: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) - exit_on_error(context, err_msg) raise err def IsReady(self, request, context: NumaflowServicerContext) -> sink_pb2.ReadyResponse: diff --git a/packages/pynumaflow/tests/sink/test_server.py b/packages/pynumaflow/tests/sink/test_server.py index 9e517b16..79a077c8 100644 --- a/packages/pynumaflow/tests/sink/test_server.py +++ b/packages/pynumaflow/tests/sink/test_server.py @@ -1,10 +1,9 @@ import os -import unittest from collections.abc import Iterator from datetime import datetime, timezone from unittest import mock -from unittest.mock import patch +import pytest from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time from google.protobuf import empty_pb2 as _empty_pb2 @@ -21,7 +20,7 @@ from pynumaflow.proto.common import metadata_pb2 from pynumaflow.proto.sinker import sink_pb2 from pynumaflow.sinker import Responses, Datum, 
Response, SinkServer, Message, UserMetadata -from tests.testing_utils import mock_terminate_on_stop +from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer def mockenv(**envvars): @@ -79,259 +78,341 @@ def mock_watermark(): return t -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestServer(unittest.TestCase): - metadata = metadata_pb2.Metadata( - previous_vertex="test-source", - user_metadata={ - "custom_info": metadata_pb2.KeyValueGroup(key_value={"version": b"1.0.0"}), - }, - sys_metadata={ - "numaflow_version_info": metadata_pb2.KeyValueGroup(key_value={"version": b"1.0.0"}), - }, +METADATA = metadata_pb2.Metadata( + previous_vertex="test-source", + user_metadata={ + "custom_info": metadata_pb2.KeyValueGroup(key_value={"version": b"1.0.0"}), + }, + sys_metadata={ + "numaflow_version_info": metadata_pb2.KeyValueGroup(key_value={"version": b"1.0.0"}), + }, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def sink_test_server(): + server = SinkServer(sinker_instance=udsink_handler) + services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: server.servicer} + return server_from_dictionary(services, strict_real_time()) + + +@pytest.fixture() +def err_sink_test_server(): + server = SinkServer(sinker_instance=err_udsink_handler) + return server, server_from_dictionary( + {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: server.servicer}, + strict_real_time(), ) - def setUp(self) -> None: - server = SinkServer(sinker_instance=udsink_handler) - my_servicer = server.servicer - services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - 
method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["IsReady"] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = method.termination() - expected = sink_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_udsink_err_handshake(self): - server = SinkServer(sinker_instance=err_udsink_handler) - my_servicer = server.servicer - services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=mock_event_time()) - watermark_timestamp = _timestamp_pb2.Timestamp() - watermark_timestamp.FromDatetime(dt=mock_watermark()) - - test_datums = [ - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ) - ), - sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), - ] - - method = self.test_server.invoke_stream_stream( - method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] - ), - invocation_metadata={}, - timeout=1, - ) +def _make_timestamps(): + event_time_timestamp = _timestamp_pb2.Timestamp() + event_time_timestamp.FromDatetime(dt=mock_event_time()) + watermark_timestamp = _timestamp_pb2.Timestamp() + watermark_timestamp.FromDatetime(dt=mock_watermark()) + return event_time_timestamp, watermark_timestamp - method.send_request(test_datums[0]) - - metadata, code, details = method.termination() - self.assertTrue("UDSinkError: Exception('SinkFn: expected handshake message')" in details) - self.assertEqual(StatusCode.INTERNAL, code) - - def test_udsink_err(self): - server = SinkServer(sinker_instance=err_udsink_handler) - my_servicer = server.servicer - services = 
{sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=mock_event_time()) - watermark_timestamp = _timestamp_pb2.Timestamp() - watermark_timestamp.FromDatetime(dt=mock_watermark()) - - test_datums = [ - sink_pb2.SinkRequest( - handshake=sink_pb2.Handshake(sot=True), - ), - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - metadata=self.metadata, - ) - ), - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_1", - value=mock_err_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - metadata=self.metadata, - ) - ), - sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), - ] - - method = self.test_server.invoke_stream_stream( - method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] - ), - invocation_metadata={}, - timeout=1, - ) - method.send_request(test_datums[0]) - method.send_request(test_datums[1]) - method.send_request(test_datums[2]) - method.requests_closed() - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in err.__str__(): - break - - metadata, code, details = method.termination() - print(code) - self.assertEqual(StatusCode.INTERNAL, code) - - def test_forward_message(self): - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=mock_event_time()) - watermark_timestamp = _timestamp_pb2.Timestamp() - watermark_timestamp.FromDatetime(dt=mock_watermark()) - - test_datums = [ - sink_pb2.SinkRequest( - handshake=sink_pb2.Handshake(sot=True), - ), - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - metadata=self.metadata, - ) - ), - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_1", - value=mock_err_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - metadata=self.metadata, - ) - ), - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_2", - value=b"test_mock_on_success1_message", - event_time=event_time_timestamp, - watermark=watermark_timestamp, - metadata=self.metadata, - ) - ), - sink_pb2.SinkRequest( - request=sink_pb2.SinkRequest.Request( - id="test_id_3", - value=b"test_mock_on_success2_message", - event_time=event_time_timestamp, - watermark=watermark_timestamp, - metadata=self.metadata, - ) - ), - sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), - ] - - method = self.test_server.invoke_stream_stream( - method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] - ), - invocation_metadata={}, - timeout=1, +# --------------------------------------------------------------------------- +# Server tests +# --------------------------------------------------------------------------- + + +def test_is_ready(sink_test_server): + method = sink_test_server.invoke_unary_unary( + method_descriptor=(sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["IsReady"]), + 
invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + expected = sink_pb2.ReadyResponse(ready=True) + assert response == expected + assert code == StatusCode.OK + + +def test_udsink_err_handshake(err_sink_test_server): + _, test_server = err_sink_test_server + event_time_timestamp, watermark_timestamp = _make_timestamps() + + test_datums = [ + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) + ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), + ] + + method = test_server.invoke_stream_stream( + method_descriptor=(sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"]), + invocation_metadata={}, + timeout=1, + ) + + method.send_request(test_datums[0]) + + metadata, code, details = method.termination() + assert "UDSinkError" in details + assert "SinkFn: expected handshake message" in details + assert code == StatusCode.INTERNAL + + +def test_udsink_err(err_sink_test_server): + _, test_server = err_sink_test_server + event_time_timestamp, watermark_timestamp = _make_timestamps() + + test_datums = [ + sink_pb2.SinkRequest(handshake=sink_pb2.Handshake(sot=True)), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_1", + value=mock_err_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), + ] + + method = test_server.invoke_stream_stream( + method_descriptor=(sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"]), + invocation_metadata={}, + timeout=1, + ) + + 
method.send_request(test_datums[0]) + method.send_request(test_datums[1]) + method.send_request(test_datums[2]) + method.requests_closed() + + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in str(err): + break + + metadata, code, details = method.termination() + assert code == StatusCode.INTERNAL + + +def test_forward_message(sink_test_server): + event_time_timestamp, watermark_timestamp = _make_timestamps() + + test_datums = [ + sink_pb2.SinkRequest(handshake=sink_pb2.Handshake(sot=True)), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_1", + value=mock_err_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_2", + value=b"test_mock_on_success1_message", + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_3", + value=b"test_mock_on_success2_message", + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), + ] + + method = sink_test_server.invoke_stream_stream( + method_descriptor=(sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"]), + invocation_metadata={}, + timeout=1, + ) + for x in test_datums: + method.send_request(x) + method.requests_closed() + + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" 
in str(err): + break + + # 1 handshake + 1 data messages + 1 EOT + assert len(responses) == 3 + # first message should be handshake response + assert responses[0].handshake.sot + + assert len(responses[1].results) == 4 + + # assert the values for the corresponding messages + assert responses[1].results[0].id == "test_id_0" + assert responses[1].results[1].id == "test_id_1" + assert responses[1].results[2].id == "test_id_2" + assert responses[1].results[3].id == "test_id_3" + assert responses[1].results[0].status == sink_pb2.Status.SUCCESS + assert responses[1].results[1].status == sink_pb2.Status.FAILURE + assert responses[1].results[0].err_msg == "" + assert responses[1].results[1].err_msg == "mock sink message error" + + # last message should be EOT response + assert responses[2].status.eot + + _, code, _ = method.termination() + assert code == StatusCode.OK + + +def test_invalid_init(): + with pytest.raises(TypeError): + SinkServer() + + +@mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_FALLBACK_SINK) +def test_start_fallback_sink(): + server = SinkServer(sinker_instance=udsink_handler) + assert server.sock_path == f"unix://{FALLBACK_SINK_SOCK_PATH}" + assert server.server_info_file == FALLBACK_SINK_SERVER_INFO_FILE_PATH + + +@mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_ON_SUCCESS_SINK) +def test_start_on_success_sink(): + server = SinkServer(sinker_instance=udsink_handler) + assert server.sock_path == f"unix://{ON_SUCCESS_SINK_SOCK_PATH}" + assert server.server_info_file == ON_SUCCESS_SINK_SERVER_INFO_FILE_PATH + + +def test_max_threads(): + # max cap at 16 + server = SinkServer(sinker_instance=udsink_handler, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = SinkServer(sinker_instance=udsink_handler, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = SinkServer(sinker_instance=udsink_handler) + assert server.max_threads == 4 + + +# 
--------------------------------------------------------------------------- +# Shutdown-event tests +# --------------------------------------------------------------------------- + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SyncSinkServicer(handler=err_udsink_handler) + event_time_timestamp, watermark_timestamp = _make_timestamps() + + services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = [ + sink_pb2.SinkRequest(handshake=sink_pb2.Handshake(sot=True)), + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), + ] + + method = test_server.invoke_stream_stream( + method_descriptor=(sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"]), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer._shutdown_event.is_set() + assert servicer._error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + servicer = SyncSinkServicer(handler=udsink_handler) + event_time_timestamp, watermark_timestamp = _make_timestamps() + + services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = test_server.invoke_stream_stream( + method_descriptor=(sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"]), + invocation_metadata={}, + timeout=1, + ) + + 
method.send_request( + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break - - # 1 handshake + 1 data messages + 1 EOT - self.assertEqual(3, len(responses)) - # first message should be handshake response - self.assertTrue(responses[0].handshake.sot) - - self.assertEqual(4, len(responses[1].results)) - - # assert the values for the corresponding messages - self.assertEqual("test_id_0", responses[1].results[0].id) - self.assertEqual("test_id_1", responses[1].results[1].id) - self.assertEqual("test_id_2", responses[1].results[2].id) - self.assertEqual("test_id_3", responses[1].results[3].id) - self.assertEqual(responses[1].results[0].status, sink_pb2.Status.SUCCESS) - self.assertEqual(responses[1].results[1].status, sink_pb2.Status.FAILURE) - self.assertEqual("", responses[1].results[0].err_msg) - self.assertEqual("mock sink message error", responses[1].results[1].err_msg) - - # last message should be EOT response - self.assertTrue(responses[2].status.eot) - - _, code, _ = method.termination() - self.assertEqual(code, StatusCode.OK) - - def test_invalid_init(self): - with self.assertRaises(TypeError): - SinkServer() - - @mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_FALLBACK_SINK) - def test_start_fallback_sink(self): - server = SinkServer(sinker_instance=udsink_handler) - self.assertEqual(server.sock_path, f"unix://{FALLBACK_SINK_SOCK_PATH}") - self.assertEqual(server.server_info_file, FALLBACK_SINK_SERVER_INFO_FILE_PATH) - - @mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_ON_SUCCESS_SINK) - def test_start_on_success_sink(self): - server = SinkServer(sinker_instance=udsink_handler) - 
self.assertEqual(server.sock_path, f"unix://{ON_SUCCESS_SINK_SOCK_PATH}") - self.assertEqual(server.server_info_file, ON_SUCCESS_SINK_SERVER_INFO_FILE_PATH) - - def test_max_threads(self): - # max cap at 16 - server = SinkServer(sinker_instance=udsink_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = SinkServer(sinker_instance=udsink_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = SinkServer(sinker_instance=udsink_handler) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - unittest.main() + ) + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "SinkFn: expected handshake message" in details + assert servicer._shutdown_event.is_set() + assert servicer._error is not None From 287b841f6960a9d5f23bd35b849116875306716c Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 3 Mar 2026 14:29:59 +0530 Subject: [PATCH 2/7] close the request iterator on shutdown Signed-off-by: Sreekanth --- .../pynumaflow/pynumaflow/shared/server.py | 2 +- .../pynumaflow/pynumaflow/shared/synciter.py | 6 ++++- .../pynumaflow/shared/thread_with_return.py | 7 +++--- .../pynumaflow/pynumaflow/sinker/server.py | 6 ++--- .../sinker/servicer/sync_servicer.py | 24 +++++++++---------- packages/pynumaflow/tests/sink/test_server.py | 13 +++++----- 6 files changed, 30 insertions(+), 28 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 630256e1..174ec476 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -315,7 +315,7 @@ def update_context_err(context: NumaflowServicerContext, e: BaseException, err_m """ trace = get_exception_traceback_str(e) _LOGGER.critical(trace) - _LOGGER.critical(e.__str__()) + _LOGGER.critical(err_msg) grpc_status = get_grpc_status(err_msg) diff --git 
a/packages/pynumaflow/pynumaflow/shared/synciter.py b/packages/pynumaflow/pynumaflow/shared/synciter.py index b7c38455..e5deb4c9 100644 --- a/packages/pynumaflow/pynumaflow/shared/synciter.py +++ b/packages/pynumaflow/pynumaflow/shared/synciter.py @@ -4,7 +4,7 @@ class SyncIterator: - """A Sync Interator backed by a queue""" + """A Sync Iterator backed by a queue""" __slots__ = "_queue" @@ -21,3 +21,7 @@ def read_iterator(self): def put(self, item): self._queue.put(item) + + def close(self): + """Unblock any thread waiting on read_iterator() by injecting STREAM_EOF.""" + self._queue.put(STREAM_EOF) diff --git a/packages/pynumaflow/pynumaflow/shared/thread_with_return.py b/packages/pynumaflow/pynumaflow/shared/thread_with_return.py index 92d9cfac..ec662857 100644 --- a/packages/pynumaflow/pynumaflow/shared/thread_with_return.py +++ b/packages/pynumaflow/pynumaflow/shared/thread_with_return.py @@ -41,7 +41,7 @@ def run(self): except BaseException as exc: self._exception = exc - def join(self, *args): + def join(self, timeout=None): """ Wait for the thread to complete and return the result. @@ -50,7 +50,7 @@ def join(self, *args): exception, and then returns the stored return value. Parameters: - *args: Variable length argument list to pass to the join() method. + timeout: Seconds to wait (None means wait indefinitely). Returns: The return value from the target function. @@ -58,8 +58,7 @@ def join(self, *args): Raises: BaseException: If the target function raised during run(). 
""" - # Call the parent class's join() method to wait for the thread to finish - Thread.join(self, *args) + Thread.join(self, timeout=timeout) if self._exception is not None: raise self._exception # Return the result of the target function diff --git a/packages/pynumaflow/pynumaflow/sinker/server.py b/packages/pynumaflow/pynumaflow/sinker/server.py index 2dd71080..5da8a132 100644 --- a/packages/pynumaflow/pynumaflow/sinker/server.py +++ b/packages/pynumaflow/pynumaflow/sinker/server.py @@ -131,9 +131,9 @@ def start(self): server_options=self._server_options, udf_type=UDFType.Sink, server_info=serv_info, - shutdown_event=self.servicer._shutdown_event, + shutdown_event=self.servicer.shutdown_event, ) - if self.servicer._error: - _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer._error) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py index e7bce225..abec5acc 100644 --- a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py @@ -25,8 +25,8 @@ class SyncSinkServicer(sink_pb2_grpc.SinkServicer): def __init__(self, handler: SinkSyncCallable): self.handler: SinkSyncCallable = handler - self._shutdown_event: threading.Event = threading.Event() - self._error: BaseException | None = None + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def SinkFn( self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext @@ -85,19 +85,19 @@ def SinkFn( err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) update_context_err(context, err, err_msg) - self._error = err - self._shutdown_event.set() + # Unblock the handler thread if it is waiting on 
queue.get() + # (e.g. gRPC stream broke while the handler was waiting for the next message). + # This lets it exit gracefully and release any user-held resources + # before the process shuts down. + req_queue.close() + self.error = err + self.shutdown_event.set() return def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerContext): - try: - # Invoke the handler function with the request queue - rspns = self.handler(request_queue.read_iterator()) - return build_sink_resp_results(rspns) - except BaseException as err: - err_msg = f"UDSinkError: {repr(err)}" - _LOGGER.critical(err_msg, exc_info=True) - raise err + # Invoke the handler function with the request queue + rspns = self.handler(request_queue.read_iterator()) + return build_sink_resp_results(rspns) def IsReady(self, request, context: NumaflowServicerContext) -> sink_pb2.ReadyResponse: """ diff --git a/packages/pynumaflow/tests/sink/test_server.py b/packages/pynumaflow/tests/sink/test_server.py index 79a077c8..1e9fe786 100644 --- a/packages/pynumaflow/tests/sink/test_server.py +++ b/packages/pynumaflow/tests/sink/test_server.py @@ -200,9 +200,8 @@ def test_udsink_err(err_sink_test_server): timeout=1, ) - method.send_request(test_datums[0]) - method.send_request(test_datums[1]) - method.send_request(test_datums[2]) + for d in test_datums: + method.send_request(d) method.requests_closed() responses = [] @@ -382,8 +381,8 @@ def test_shutdown_event_set_on_handler_error(): _, code, _ = method.termination() assert code == StatusCode.INTERNAL - assert servicer._shutdown_event.is_set() - assert servicer._error is not None + assert servicer.shutdown_event.is_set() + assert servicer.error is not None def test_shutdown_event_set_on_handshake_error(): @@ -414,5 +413,5 @@ def test_shutdown_event_set_on_handshake_error(): _, code, details = method.termination() assert code == StatusCode.INTERNAL assert "SinkFn: expected handshake message" in details - assert servicer._shutdown_event.is_set() - assert 
servicer._error is not None + assert servicer.shutdown_event.is_set() + assert servicer.error is not None From e8474f03dafa6eacbfacf74648cbd203cd9a0dca Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 3 Mar 2026 19:44:57 +0530 Subject: [PATCH 3/7] fix req_queue scope Signed-off-by: Sreekanth --- .../pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py index abec5acc..9ed2f36a 100644 --- a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py @@ -35,6 +35,7 @@ def SinkFn( Applies a sink function to datum elements. """ + req_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -89,7 +90,8 @@ def SinkFn( # (e.g. gRPC stream broke while the handler was waiting for the next message). # This lets it exit gracefully and release any user-held resources # before the process shuts down. 
- req_queue.close() + if req_queue is not None: + req_queue.close() self.error = err self.shutdown_event.set() return From f5594c4d869f945649e5abe07f7e0f91de97fc6d Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 3 Mar 2026 19:52:37 +0530 Subject: [PATCH 4/7] handle grpc stream close Signed-off-by: Sreekanth --- .../pynumaflow/sinker/servicer/sync_servicer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py index 9ed2f36a..5b0789f7 100644 --- a/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py @@ -1,6 +1,7 @@ import threading from collections.abc import Iterator +import grpc from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 @@ -82,6 +83,13 @@ def SinkFn( if cur_task: cur_task.join() + except grpc.RpcError: + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if req_queue is not None: + req_queue.close() + self.shutdown_event.set() + return + except BaseException as err: err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) From 9de02db7cc7f929b12bdd697bdbd3261569cdc6c Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 3 Mar 2026 20:28:45 +0530 Subject: [PATCH 5/7] propagate error Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/shared/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 174ec476..3986e0dc 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -263,14 +263,14 @@ def check_instance(instance, callable_type) -> bool: return False -def
get_grpc_status(err: str, detail: str | None = None): """ Create a grpc status object with the error details. """ details = any_pb2.Any() details.Pack( error_details_pb2.DebugInfo( - detail="\n".join(traceback.format_stack()), + detail=detail if detail is not None else "\n".join(traceback.format_stack()), ) ) @@ -317,7 +317,7 @@ def update_context_err(context: NumaflowServicerContext, e: BaseException, err_m _LOGGER.critical(trace) _LOGGER.critical(err_msg) - grpc_status = get_grpc_status(err_msg) + grpc_status = get_grpc_status(err_msg, detail=trace) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(err_msg) From ae7c50e59524a20647303b342a01904ae85d0ab3 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 5 Mar 2026 07:05:36 +0530 Subject: [PATCH 6/7] more tests Signed-off-by: Sreekanth --- packages/pynumaflow/tests/sink/test_server.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/packages/pynumaflow/tests/sink/test_server.py b/packages/pynumaflow/tests/sink/test_server.py index 1e9fe786..e76a65b2 100644 --- a/packages/pynumaflow/tests/sink/test_server.py +++ b/packages/pynumaflow/tests/sink/test_server.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from unittest import mock +import grpc import pytest from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time @@ -415,3 +416,43 @@ def test_shutdown_event_set_on_handshake_error(): assert "SinkFn: expected handshake message" in details assert servicer.shutdown_event.is_set() assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, req_queue is None so close is skipped.""" + servicer = SyncSinkServicer(handler=udsink_handler) + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.SinkFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert 
servicer.shutdown_event.is_set() + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_batch(): + """grpc.RpcError mid-batch: req_queue is closed (unblocking the handler thread) and shutdown_event is set.""" + servicer = SyncSinkServicer(handler=udsink_handler) + event_time_timestamp, watermark_timestamp = _make_timestamps() + + def _cancelled_iter(): + yield sink_pb2.SinkRequest(handshake=sink_pb2.Handshake(sot=True)) + yield sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + metadata=METADATA, + ) + ) + raise grpc.RpcError() + + responses = list(servicer.SinkFn(_cancelled_iter(), mock.MagicMock())) + + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + assert servicer.error is None From 20d00d5eb43bbdba148a5f74a9e4065933615778 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 5 Mar 2026 07:07:22 +0530 Subject: [PATCH 7/7] lints Signed-off-by: Sreekanth --- packages/pynumaflow/tests/sink/test_server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/pynumaflow/tests/sink/test_server.py b/packages/pynumaflow/tests/sink/test_server.py index e76a65b2..580dfb5e 100644 --- a/packages/pynumaflow/tests/sink/test_server.py +++ b/packages/pynumaflow/tests/sink/test_server.py @@ -419,7 +419,8 @@ def test_shutdown_event_set_on_handshake_error(): def test_shutdown_event_set_on_stream_close_before_handshake(): - """grpc.RpcError on the first read (before handshake): shutdown_event set, req_queue is None so close is skipped.""" + """grpc.RpcError on the first read (before handshake): shutdown_event set, + req_queue is None so close is skipped.""" servicer = SyncSinkServicer(handler=udsink_handler) def _cancelled_iter(): @@ -434,7 +435,8 @@ def _cancelled_iter(): def test_shutdown_event_set_on_stream_close_mid_batch(): - """grpc.RpcError mid-batch: req_queue is 
closed (unblocking the handler thread) and shutdown_event is set.""" + """grpc.RpcError mid-batch: req_queue is closed (unblocking the handler thread) + and shutdown_event is set.""" servicer = SyncSinkServicer(handler=udsink_handler) event_time_timestamp, watermark_timestamp = _make_timestamps()