Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions packages/pynumaflow/pynumaflow/shared/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing
import os
import socket
import threading
import traceback

from google.protobuf import any_pb2
Expand All @@ -18,6 +19,7 @@
from pynumaflow._constants import (
_LOGGER,
MULTIPROC_MAP_SOCK_ADDR,
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
UDFType,
)
from pynumaflow.exceptions import SocketError
Expand Down Expand Up @@ -57,6 +59,7 @@ def sync_server_start(
server_options=None,
server_info: ServerInfo | None = None,
udf_type: str = UDFType.Map,
shutdown_event: threading.Event | None = None,
):
"""
Utility function to start a sync grpc server instance.
Expand All @@ -75,6 +78,7 @@ def sync_server_start(
udf_type=udf_type,
server_info_file=server_info_file,
server_info=server_info,
shutdown_event=shutdown_event,
)


Expand All @@ -86,10 +90,15 @@ def _run_server(
udf_type: str,
server_info_file: str | None = None,
server_info: ServerInfo | None = None,
shutdown_event: threading.Event | None = None,
) -> None:
"""
Starts the Synchronous server instance on the given UNIX socket
with given max threads. Wait for the server to terminate.

If *shutdown_event* is provided, a background daemon thread will wait
on it and then call ``server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)``
for a cooperative graceful shutdown (no process kill).
"""
server = grpc.server(
ThreadPoolExecutor(
Expand All @@ -115,10 +124,21 @@ def _run_server(
server.add_insecure_port(bind_address)
# start the gRPC server
server.start()

# Add the server information to the server info file if provided
if server_info and server_info_file:
info_server_write(server_info=server_info, info_file=server_info_file)

if shutdown_event is not None:

def _watch_for_shutdown():
shutdown_event.wait()
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

watcher = threading.Thread(target=_watch_for_shutdown, daemon=True)
watcher.start()

_LOGGER.info("GRPC Server listening on: %s %d", bind_address, os.getpid())
server.wait_for_termination()

Expand Down Expand Up @@ -243,14 +263,14 @@ def check_instance(instance, callable_type) -> bool:
return False


def get_grpc_status(err: str):
def get_grpc_status(err: str, detail: str | None = None):
"""
Create a grpc status object with the error details.
"""
details = any_pb2.Any()
details.Pack(
error_details_pb2.DebugInfo(
detail="\n".join(traceback.format_stack()),
detail=detail if detail is not None else "\n".join(traceback.format_stack()),
)
)

Expand Down Expand Up @@ -295,9 +315,9 @@ def update_context_err(context: NumaflowServicerContext, e: BaseException, err_m
"""
trace = get_exception_traceback_str(e)
_LOGGER.critical(trace)
_LOGGER.critical(e.__str__())
_LOGGER.critical(err_msg)

grpc_status = get_grpc_status(err_msg)
grpc_status = get_grpc_status(err_msg, detail=trace)

context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(err_msg)
Expand Down
6 changes: 5 additions & 1 deletion packages/pynumaflow/pynumaflow/shared/synciter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class SyncIterator:
"""A Sync Interator backed by a queue"""
"""A Sync Iterator backed by a queue"""

__slots__ = "_queue"

Expand All @@ -21,3 +21,7 @@ def read_iterator(self):

def put(self, item):
self._queue.put(item)

def close(self):
"""Unblock any thread waiting on read_iterator() by injecting STREAM_EOF."""
self._queue.put(STREAM_EOF)
25 changes: 18 additions & 7 deletions packages/pynumaflow/pynumaflow/shared/thread_with_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class ThreadWithReturnValue(Thread):
"""
A custom Thread class that allows the target function to return a value.
This class extends the built-in threading.Thread class.
Exceptions raised by the target are captured and re-raised on join().
"""

def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbose=None):
Expand All @@ -23,32 +24,42 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbo
Thread.__init__(self, group, target, name, args, kwargs)
# Variable to store the return value of the target function
self._return = None
self._exception: BaseException | None = None

def run(self):
"""
Run the thread.

This method is overridden from the Thread class.
It calls the target function and saves the return value.
If the target raises, the exception is captured for re-raising on join().
"""
if self._target is not None:
# Execute target and store the result
self._return = self._target(*self._args, **self._kwargs)
try:
# Execute target and store the result
self._return = self._target(*self._args, **self._kwargs)
except BaseException as exc:
self._exception = exc

def join(self, *args):
def join(self, timeout=None):
"""
Wait for the thread to complete and return the result.

This method is overridden from the Thread class.
It calls the parent class's join() method and then returns the stored return value.
It calls the parent class's join() method, re-raises any captured
exception, and then returns the stored return value.

Parameters:
*args: Variable length argument list to pass to the join() method.
timeout: Seconds to wait (None means wait indefinitely).

Returns:
The return value from the target function.

Raises:
BaseException: If the target function raised during run().
"""
# Call the parent class's join() method to wait for the thread to finish
Thread.join(self, *args)
Thread.join(self, timeout=timeout)
if self._exception is not None:
raise self._exception
# Return the result of the target function
return self._return
7 changes: 7 additions & 0 deletions packages/pynumaflow/pynumaflow/sinker/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys

from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer
Expand Down Expand Up @@ -120,6 +121,7 @@ def start(self):
)
serv_info = ServerInfo.get_default_server_info()
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker]

# Start the server
sync_server_start(
servicer=self.servicer,
Expand All @@ -129,4 +131,9 @@ def start(self):
server_options=self._server_options,
udf_type=UDFType.Sink,
server_info=serv_info,
shutdown_event=self.servicer.shutdown_event,
)

if self.servicer.error:
_LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error)
sys.exit(1)
41 changes: 27 additions & 14 deletions packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import threading
from collections.abc import Iterator

import grpc

from pynumaflow._constants import _LOGGER, STREAM_EOF
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2
from pynumaflow.shared.server import exit_on_error
from pynumaflow.shared.server import update_context_err
from pynumaflow.shared.synciter import SyncIterator
from pynumaflow.shared.thread_with_return import ThreadWithReturnValue
from pynumaflow.sinker._dtypes import SinkSyncCallable
Expand All @@ -24,6 +26,8 @@ class SyncSinkServicer(sink_pb2_grpc.SinkServicer):

def __init__(self, handler: SinkSyncCallable):
self.handler: SinkSyncCallable = handler
self.shutdown_event: threading.Event = threading.Event()
self.error: BaseException | None = None

def SinkFn(
self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext
Expand All @@ -32,6 +36,7 @@ def SinkFn(
Applies a sink function to datum elements.
"""

req_queue = None
try:
# The first message to be received should be a valid handshake
req = next(request_iterator)
Expand Down Expand Up @@ -78,23 +83,31 @@ def SinkFn(
if cur_task:
cur_task.join()

except grpc.RpcError:
_LOGGER.warning("gRPC stream closed, shutting down the server.")
if req_queue is not None:
req_queue.close()
self.shutdown_event.set()
return

except BaseException as err:
# Handle exceptions
err_msg = f"UDSinkError: {repr(err)}"
err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
_LOGGER.critical(err_msg, exc_info=True)
exit_on_error(context, err_msg)
update_context_err(context, err, err_msg)
# Unblock the handler thread if it is waiting on queue.get()
# (e.g. gRPC stream broke while the handler was waiting for the next message).
# This lets it exit gracefully and release any user-held resources
# before the process shuts down.
if req_queue is not None:
req_queue.close()
self.error = err
self.shutdown_event.set()
return

def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerContext):
try:
# Invoke the handler function with the request queue
rspns = self.handler(request_queue.read_iterator())
return build_sink_resp_results(rspns)
except BaseException as err:
err_msg = f"UDSinkError: {repr(err)}"
_LOGGER.critical(err_msg, exc_info=True)
exit_on_error(context, err_msg)
raise err
# Invoke the handler function with the request queue
rspns = self.handler(request_queue.read_iterator())
return build_sink_resp_results(rspns)

def IsReady(self, request, context: NumaflowServicerContext) -> sink_pb2.ReadyResponse:
"""
Expand Down
Loading