Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from concurrent.futures import ThreadPoolExecutor
from collections.abc import Iterator

import grpc
from google.protobuf import empty_pb2 as _empty_pb2
from pynumaflow.shared.server import exit_on_error
from pynumaflow.shared.server import update_context_err
from pynumaflow._metadata import _user_and_system_metadata_from_proto

from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING
from pynumaflow._constants import NUM_THREADS_DEFAULT, _LOGGER, ERR_UDF_EXCEPTION_STRING
from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.synciter import SyncIterator
Expand All @@ -26,6 +27,8 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
self.multiproc = multiproc
# create a thread pool for executing UDF code
self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)
self.shutdown_event: threading.Event = threading.Event()
self.error: BaseException | None = None

def MapFn(
self,
Expand All @@ -36,6 +39,7 @@ def MapFn(
Applies a function to each datum element.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
result_queue = None
try:
# The first message to be received should be a valid handshake
req = next(request_iterator)
Expand All @@ -57,10 +61,13 @@ def MapFn(
for res in result_queue.read_iterator():
# if the result is an exception, handle it accordingly
if isinstance(res, BaseException):
# Terminate the current server process due to exception
exit_on_error(
context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc
)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, res, err_msg)
# Unblock the reader thread if it is waiting on queue.put()
result_queue.close()
self.error = res
self.shutdown_event.set()
return
# return the result
yield res
Expand All @@ -69,12 +76,22 @@ def MapFn(
reader_thread.join()
self.executor.shutdown(cancel_futures=True)

except grpc.RpcError:
_LOGGER.warning("gRPC stream closed, shutting down the server.")
if result_queue is not None:
result_queue.close()
self.shutdown_event.set()
return

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
# Terminate the current server process due to exception
exit_on_error(
context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc
)
err_msg = f"UDFError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, err, err_msg)
# Unblock the reader thread if it is waiting on queue.put()
if result_queue is not None:
result_queue.close()
self.error = err
self.shutdown_event.set()
return

def _process_requests(
Expand All @@ -91,9 +108,20 @@ def _process_requests(
# wait for all tasks to finish after all requests exhausted
self.executor.shutdown(wait=True)
# Indicate to the result queue that no more messages left to process
result_queue.put(STREAM_EOF)
result_queue.close()
except grpc.RpcError:
# The only error that can occur here is the gRPC stream closing
# (e.g. client disconnected). UDF exceptions are caught inside _invoke_map
# and never propagate here.
_LOGGER.warning("gRPC stream closed in reader thread, shutting down the server.")
# Let already-submitted UDF tasks finish within the graceful shutdown period
self.executor.shutdown(wait=True)
# Signal MapFn's read_iterator() loop to exit cleanly
result_queue.close()
# Trigger server shutdown (not a UDF error, so self.error is not set)
self.shutdown_event.set()
except BaseException as e:
_LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
_LOGGER.critical("MapFn Error while reading requests from gRPC stream", exc_info=True)
# Surface the error to the consumer; MapFn will handle and exit
result_queue.put(e)

Expand Down
7 changes: 7 additions & 0 deletions packages/pynumaflow/pynumaflow/mapper/sync_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

from pynumaflow.info.types import (
ServerInfo,
MAP_MODE_KEY,
Expand Down Expand Up @@ -112,4 +114,9 @@ def start(self) -> None:
server_options=self._server_options,
udf_type=UDFType.Map,
server_info=serv_info,
shutdown_event=self.servicer.shutdown_event,
)

if self.servicer.error:
_LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error)
sys.exit(1)