From 977f787dbd03fec41649c034c8a47797f30e200c Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 6 Mar 2026 06:45:24 +0530 Subject: [PATCH 1/6] fix: Clean shutdown for multithreaded unary Map Signed-off-by: Sreekanth --- .../mapper/_servicer/_sync_servicer.py | 54 ++++++++++++++----- .../pynumaflow/mapper/sync_server.py | 7 +++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index be5f7199..61700f1b 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -2,11 +2,12 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterator +import grpc from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow._metadata import _user_and_system_metadata_from_proto -from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING +from pynumaflow._constants import NUM_THREADS_DEFAULT, _LOGGER, ERR_UDF_EXCEPTION_STRING from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from pynumaflow.shared.synciter import SyncIterator @@ -26,6 +27,8 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False): self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def MapFn( self, @@ -36,6 +39,7 @@ def MapFn( Applies a function to each datum element. The pascal case function name comes from the proto map_pb2_grpc.py file. """ + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -57,10 +61,13 @@ def MapFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, res, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -69,12 +76,22 @@ def MapFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"UDFError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, err, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( @@ -91,9 +108,20 @@ def _process_requests( # wait for all tasks to finish after all requests exhausted self.executor.shutdown(wait=True) # Indicate to the result queue that no more messages left to process - result_queue.put(STREAM_EOF) + result_queue.close() + except grpc.RpcError: + # The only error that can occur here is the gRPC stream closing + # (e.g. client disconnected). UDF exceptions are caught inside _invoke_map + # and never propagate here. + _LOGGER.warning("gRPC stream closed in reader thread, shutting down the server.") + # Let already-submitted UDF tasks finish within the graceful shutdown period + self.executor.shutdown(wait=True) + # Signal MapFn's read_iterator() loop to exit cleanly + result_queue.close() + # Trigger server shutdown (not a UDF error, so self.error is not set) + self.shutdown_event.set() except BaseException as e: - _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) + _LOGGER.critical("MapFn Error while reading requests from gRPC stream", exc_info=True) # Surface the error to the consumer; MapFn will handle and exit result_queue.put(e) diff --git a/packages/pynumaflow/pynumaflow/mapper/sync_server.py b/packages/pynumaflow/pynumaflow/mapper/sync_server.py index 9c2431b6..a96ceb7f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/sync_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/sync_server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -112,4 +114,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) From a43164a7c46c2c52d84473c3a2b84cef78b0ac76 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 6 Mar 2026 18:01:43 +0530 Subject: [PATCH 2/6] Avoid invoking UDF while shutting down Signed-off-by: Sreekanth --- .../pynumaflow/mapper/_servicer/_sync_servicer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index 61700f1b..c9b60570 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -62,7 +62,6 @@ def MapFn( # if error handler accordingly if isinstance(res, BaseException): err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" - _LOGGER.critical(err_msg, exc_info=True) update_context_err(context, res, err_msg) # Unblock the reader thread if it is waiting on queue.put() result_queue.close() @@ -132,6 +131,11 @@ def _invoke_map( result_queue: SyncIterator, ): try: + # Skip processing if a shutdown is already in progress + # (e.g. a prior invocation raised an exception) + if self.shutdown_event.is_set(): + return + (user_metadata, system_metadata) = _user_and_system_metadata_from_proto( request.request.metadata ) From 9f74b6ecf30acb25731ea1d451be95f4cbc437b6 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 6 Mar 2026 18:15:22 +0530 Subject: [PATCH 3/6] Rewrite tests with pytest Signed-off-by: Sreekanth --- .../pynumaflow/tests/map/test_sync_mapper.py | 192 +++++++----------- 1 file changed, 73 insertions(+), 119 deletions(-) diff --git a/packages/pynumaflow/tests/map/test_sync_mapper.py b/packages/pynumaflow/tests/map/test_sync_mapper.py index edafa835..668b7cd9 100644 --- a/packages/pynumaflow/tests/map/test_sync_mapper.py +++ b/packages/pynumaflow/tests/map/test_sync_mapper.py @@ -1,7 +1,5 @@ -import unittest -from unittest.mock import patch - import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time @@ -9,88 +7,72 @@ from pynumaflow.mapper import MapServer from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, ExampleMap, get_test_datums -from tests.testing_utils import ( - mock_terminate_on_stop, -) - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestSyncMapper(unittest.TestCase): - # @patch("psutil.Process.kill", mock_terminate_on_stop) - def setUp(self) -> None: - class_instance = ExampleMap() - my_server = MapServer(mapper_instance=class_instance) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_init_with_args(self) -> None: + + +@pytest.fixture +def test_server(): + class_instance = ExampleMap() + my_server = MapServer(mapper_instance=class_instance) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + return server_from_dictionary(services, strict_real_time()) + + +def _invoke_map(test_server, handshake=True): + test_datums = get_test_datums(handshake=handshake) + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=1, + ) + for x in test_datums: + method.send_request(x) + method.requests_closed() + + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in str(err): + break + return method, responses + + +def _err_server(): + my_server = MapServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + return server_from_dictionary(services, strict_real_time()) + + +class TestSyncMapper: + def test_init_with_args(self): my_servicer = MapServer( mapper_instance=map_handler, sock_path="/tmp/test.sock", max_message_size=1024 * 1024 * 5, ) - self.assertEqual(my_servicer.sock_path, "unix:///tmp/test.sock") - self.assertEqual(my_servicer.max_message_size, 1024 * 1024 * 5) + assert my_servicer.sock_path == "unix:///tmp/test.sock" + assert my_servicer.max_message_size == 1024 * 1024 * 5 def test_udf_map_err_handshake(self): - my_server = MapServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=False) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + server = _err_server() + method, responses = _invoke_map(server, handshake=False) metadata, code, details = method.termination() - self.assertTrue("MapFn: expected handshake as the first message" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) + assert "MapFn: expected handshake as the first message" in details + assert code == grpc.StatusCode.INTERNAL def test_udf_map_error_response(self): - my_server = MapServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=True) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + server = _err_server() + method, responses = _invoke_map(server, handshake=True) metadata, code, details = method.termination() - self.assertTrue("Something is fishy!" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) + assert "Something is fishy!" in details + assert code == grpc.StatusCode.INTERNAL - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( + def test_is_ready(self, test_server): + method = test_server.invoke_unary_unary( method_descriptor=( map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["IsReady"] ), @@ -100,70 +82,42 @@ def test_is_ready(self): ) response, metadata, code, details = method.termination() - expected = map_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_map_forward_message(self): - test_datums = get_test_datums(handshake=True) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + assert response == map_pb2.ReadyResponse(ready=True) + assert code == StatusCode.OK + + def test_map_forward_message(self, test_server): + method, responses = _invoke_map(test_server, handshake=True) metadata, code, details = method.termination() # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) + assert len(responses) == 4 + assert responses[0].handshake.sot result_ids = {f"test-id-{id}" for id in range(1, 4)} - idx = 1 - while idx < len(responses): - result_ids.remove(responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " - "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", - encoding="utf-8", - ), - responses[idx].results[0].value, + for resp in responses[1:]: + result_ids.remove(resp.id) + assert resp.results[0].value == bytes( + "payload:test_mock_message " + "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", + encoding="utf-8", ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - self.assertEqual(len(result_ids), 0) - self.assertEqual(code, StatusCode.OK) + assert len(resp.results) == 1 + assert len(result_ids) == 0 + assert code == StatusCode.OK def test_invalid_input(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): MapServer() def test_max_threads(self): # max cap at 16 server = MapServer(mapper_instance=map_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) + assert server.max_threads == 16 # use argument provided server = MapServer(mapper_instance=map_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) + assert server.max_threads == 5 # defaults to 4 server = MapServer(mapper_instance=map_handler) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - unittest.main() + assert server.max_threads == 4 From bfcb3797ec79390f17e68957662c02314eb12399 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 6 Mar 2026 18:21:15 +0530 Subject: [PATCH 4/6] tests for grpc error Signed-off-by: Sreekanth --- .../pynumaflow/tests/map/test_sync_mapper.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/packages/pynumaflow/tests/map/test_sync_mapper.py b/packages/pynumaflow/tests/map/test_sync_mapper.py index 668b7cd9..d3b65eaa 100644 --- a/packages/pynumaflow/tests/map/test_sync_mapper.py +++ b/packages/pynumaflow/tests/map/test_sync_mapper.py @@ -1,3 +1,6 @@ +import threading +from unittest.mock import MagicMock + import grpc import pytest from google.protobuf import empty_pb2 as _empty_pb2 @@ -5,6 +8,7 @@ from grpc_testing import server_from_dictionary, strict_real_time from pynumaflow.mapper import MapServer +from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, ExampleMap, get_test_datums @@ -121,3 +125,36 @@ def test_max_threads(self): # defaults to 4 server = MapServer(mapper_instance=map_handler) assert server.max_threads == 4 + + def test_rpc_error_before_handshake(self): + """RpcError on the very first read triggers the except grpc.RpcError in MapFn.""" + + def failing_iterator(): + raise grpc.RpcError() + yield # make it a generator + + servicer = SyncMapServicer(handler=map_handler) + context = MagicMock() + + responses = list(servicer.MapFn(failing_iterator(), context)) + + assert responses == [] + assert servicer.shutdown_event.is_set() + + def test_rpc_error_mid_stream(self): + """RpcError while reading requests triggers the except grpc.RpcError + in _process_requests.""" + + def interrupted_iterator(): + yield map_pb2.MapRequest(handshake=map_pb2.Handshake(sot=True)) + raise grpc.RpcError() + + servicer = SyncMapServicer(handler=map_handler) + context = MagicMock() + + responses = list(servicer.MapFn(interrupted_iterator(), context)) + + # Should get the handshake response and then stop cleanly + assert len(responses) == 1 + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() From 89d7df90dc7cc9998288bcbddd6bfe18e1ec99ac Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 6 Mar 2026 18:23:38 +0530 Subject: [PATCH 5/6] fix lints Signed-off-by: Sreekanth --- packages/pynumaflow/tests/map/test_sync_mapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/pynumaflow/tests/map/test_sync_mapper.py b/packages/pynumaflow/tests/map/test_sync_mapper.py index d3b65eaa..adabfeed 100644 --- a/packages/pynumaflow/tests/map/test_sync_mapper.py +++ b/packages/pynumaflow/tests/map/test_sync_mapper.py @@ -1,4 +1,3 @@ -import threading from unittest.mock import MagicMock import grpc From 072be7edb6bde607b45e781d1c1bc15d876af64d Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 6 Mar 2026 18:29:22 +0530 Subject: [PATCH 6/6] better pytest settings Signed-off-by: Sreekanth --- packages/pynumaflow/pytest.ini | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/packages/pynumaflow/pytest.ini b/packages/pynumaflow/pytest.ini index 9d6adb92..10ef6cf4 100644 --- a/packages/pynumaflow/pytest.ini +++ b/packages/pynumaflow/pytest.ini @@ -1,6 +1,2 @@ # pytest.ini -[pytest] -log_cli = 1 -log_cli_level = DEBUG -log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s) -log_cli_date_format=%Y-%m-%d %H:%M:%S \ No newline at end of file +[pytest] \ No newline at end of file