diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..4cc32cd
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,26 @@
+name: CI
+
+on: [push]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    env:
+      VIRTUAL_ENV: ignore
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.9
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -r development.txt
+    - name: Lint
+      run: |
+        pip install flake8
+        python -m flake8 pyqs tests
+    - name: Test
+      run: |
+        make test
diff --git a/.gitignore b/.gitignore
index 33d7c93..51be5a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ dist/*
 *.egg-info/*
 build/*
 htmlcov/*
+.idea/*
 *~
 *#
diff --git a/.travis.yml b/.travis.yml
index 67e8055..18963b2 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,6 @@
 language: python
 python:
-  - "2.7"
   - "3.6"
   - "3.7"
  - "3.8"
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 0400906..5850e26 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,25 @@
 Changelog
 ---------
 
+1.0.1
+~~~~~
+- Add MessageId for tracking task executions
+
+1.0.0
+~~~~~
+- Drop Py2 support
+- Add new SimpleProcessWorker (https://github.com/spulec/PyQS/pull/76)
+
+0.1.6
+~~~~~
+
+- Fix broken pickle of botocore clients.
+
+0.1.5
+~~~~~
+
+- Add event hooks for pre and post processors.
+
 0.1.4
 ~~~~~
diff --git a/Makefile b/Makefile
index 0da598a..68c12d5 100644
--- a/Makefile
+++ b/Makefile
@@ -41,26 +41,7 @@ clean:
 	@find . -name __pycache__ -delete
 	@rm -rf .coverage *.egg-info *.log build dist MANIFEST yc
 
-publish: clean tag
-	@if [ -e "$$HOME/.pypirc" ]; then \
-		echo "Uploading to '$(CUSTOM_PIP_INDEX)'"; \
-		python setup.py register -r "$(CUSTOM_PIP_INDEX)"; \
-		python setup.py sdist upload -r "$(CUSTOM_PIP_INDEX)"; \
-	else \
-		echo "You should create a file called '.pypirc' under your home dir.\n"; \
-		echo "That's the right place to configure 'pypi' repos.\n"; \
-		exit 1; \
-	fi
-
-tag:
-	@if [ $$(git rev-list $$(git describe --abbrev=0 --tags)..HEAD --count) -gt 0 ]; then \
-		if [ $$(git log -n 1 --oneline $$(git describe --abbrev=0 --tags)..HEAD CHANGELOG.rst | wc -l) -gt 0 ]; then \
-			git tag $$(python setup.py --version) && git push --tags || echo 'Version already released, update your version!'; \
-		else \
-			echo "CHANGELOG not updated since last release!"; \
-			exit 1; \
-		fi; \
-	else \
-		echo "No commits since last release!"; \
-		exit 1;\
-	fi
+publish: clean
+	rm -rf dist
+	python -m pep517.build --source --binary .
+	twine upload dist/*
diff --git a/README.rst b/README.rst
index f251a8d..94e954f 100644
--- a/README.rst
+++ b/README.rst
@@ -78,7 +78,7 @@ To read tasks we need to run PyQS. If the task is already in your
 
     $ pyqs email.tasks.send_email
 
-If we want want to run all tasks with a certain prefix. This is based on
+If we want to run all tasks with a certain prefix. This is based on
 Python's `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__.
 
 .. code:: bash
@@ -100,6 +100,45 @@ messages.
 
     $ pyqs send_email --concurrency 10
 
+Simple Process Worker
+~~~~~~~~~~~~~~~~~~~~~
+
+To use a simpler version of PyQS that deals with some of the edge cases in the original implementation, pass the ``simple-worker`` flag.
+
+.. code:: bash
+
+    $ pyqs send_email --simple-worker
+
+The Simple Process Worker differs in the following ways from the original implementation:
+
+* Does not use an internal queue and removes support for the ``prefetch-multiplier`` flag.
+  This helps simplify the mental model required, as messages are not on both the SQS queue and an internal queue.
+* When the ``simple-worker`` flag is passed, the default ``batchsize`` is 1 instead of 10. This is configurable.
+* Does not check the visibility timeout when reading or processing a message from SQS.
+
+  * Allowing the worker to process a message even past its visibility timeout means we solve the problem of never processing a message when ``max_receives=1`` and the visibility timeout is incorrectly set too short. Previously, such a message would have ended up in the DLQ, if one was configured, without ever being processed.
+  * It increases the probability that we process a message more than once, especially if ``batchsize > 1``, but this can be solved by the developer checking whether the message has already been processed (see the sketch below).
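Because duplicate delivery is possible, tasks run under the simple worker should be idempotent. Below is a minimal sketch of such a guard, assuming the publisher passes its own deduplication key as an ordinary keyword argument; the in-memory ``set`` is purely illustrative, since in practice every worker process would need to consult a shared store (Redis, DynamoDB, etc.):

.. code:: python

    from pyqs import task

    # Illustration only: every worker process holds its own copy of this
    # set, so a real deployment would use a shared store instead.
    _already_processed = set()

    @task(queue="email")
    def send_email(subject, dedupe_key=None):
        if dedupe_key is not None and dedupe_key in _already_processed:
            # Duplicate delivery of a message we already handled; skip it.
            return
        _already_processed.add(dedupe_key)
        # ... send the email here ...

The publisher supplies the key when enqueueing, e.g. ``send_email.delay(subject="Hi", dedupe_key="order-1234")``; the guard matters most when ``batchsize`` is raised above the simple worker's default of 1.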
+
+Hooks
+~~~~~
+
+PyQS has an event registry which can be used to run a function before or after every task runs.
+
+.. code:: python
+
+    from pyqs import task, events
+
+    def print_pre_process(context):
+        print({"pre_process": context})
+
+    def print_post_process(context):
+        print({"post_process": context})
+
+    events.register_event("pre_process", print_pre_process)
+    events.register_event("post_process", print_post_process)
+
+    @task(queue="my_queue")
+    def send_email(subject):
+        pass
+
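The context passed to each callback is the dictionary built in ``pyqs/worker.py`` below; after the task runs, the worker adds a ``status`` key ("success" or "exception") and, on failure, the formatted traceback under ``exception``. As a sketch of how that can be used, here is a ``post_process`` hook that reports failing tasks (the logger name is arbitrary):

.. code:: python

    import logging

    from pyqs import events

    logger = logging.getLogger("my_app.tasks")

    def log_task_failures(context):
        # "status" is set by the worker after the task has run.
        if context["status"] == "exception":
            logger.error(
                "Task %s failed: %s",
                context["full_task_path"],
                context["exception"],
            )

    events.register_event("post_process", log_task_failures)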
 Operational Notes
 ~~~~~~~~~~~~~~~~~
diff --git a/development.txt b/development.txt
index 71b8fc8..1f38b8b 100644
--- a/development.txt
+++ b/development.txt
@@ -2,7 +2,9 @@ coverage==4.4.1
 mock==1.0.1
 moto==1.3.13
 nose==1.3.0
+pep517==0.9.1
 pre-commit==0.7.6
 sure==1.2.2
+twine==3.3.0
 functools32;python_version=='2.7'
 pycodestyle==2.4.0
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..9787c3b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/pyqs/__init__.py b/pyqs/__init__.py
index bd44c8a..757f9cd 100644
--- a/pyqs/__init__.py
+++ b/pyqs/__init__.py
@@ -1,4 +1,4 @@
 from .decorator import task  # noqa
 
 __title__ = 'pyqs'
-__version__ = '0.1.4'
+__version__ = '1.0.1'
diff --git a/pyqs/events.py b/pyqs/events.py
new file mode 100644
index 0000000..e0d988c
--- /dev/null
+++ b/pyqs/events.py
@@ -0,0 +1,43 @@
+"""
+pyqs events registry: register callback functions on pyqs events
+
+Usage:
+from pyqs.events import register_event
+
+register_event("pre_process", lambda context: print(context))
+"""
+
+
+class Events:
+    def __init__(self):
+        self.pre_process = []
+        self.post_process = []
+
+    def clear(self):
+        self.pre_process = []
+        self.post_process = []
+
+
+# Global singleton
+_EVENTS = Events()
+
+
+class NoEventException(Exception):
+    pass
+
+
+def register_event(name, callback):
+    if hasattr(_EVENTS, name):
+        getattr(_EVENTS, name).append(callback)
+    else:
+        raise NoEventException(
+            "{name} is not a valid pyqs event.".format(name=name)
+        )
+
+
+def get_events():
+    return _EVENTS
+
+
+def clear_events():
+    _EVENTS.clear()
diff --git a/pyqs/main.py b/pyqs/main.py
index 33b55b5..f2fd0c6 100644
--- a/pyqs/main.py
+++ b/pyqs/main.py
@@ -7,11 +7,28 @@ import sys
 
 from argparse import ArgumentParser
 
-from .worker import ManagerWorker
+from .worker import ManagerWorker, SimpleManagerWorker
 from . import __version__
 
 logger = logging.getLogger("pyqs")
 
+SIMPLE_WORKER_DEFAULT_BATCH_SIZE = 1
+DEFAULT_BATCH_SIZE = 10
+
+
+def _set_batchsize(args):
+    batchsize = args.batchsize
+    if batchsize:
+        return batchsize
+
+    simple_worker = args.simple_worker
+    if simple_worker:
+        # Default batchsize for SimpleProcessWorker
+        return SIMPLE_WORKER_DEFAULT_BATCH_SIZE
+
+    # Default batchsize for ProcessWorker
+    return DEFAULT_BATCH_SIZE
+
 
 def main():
     parser = ArgumentParser(description="""
@@ -77,6 +94,15 @@ def main():
         action="store",
     )
 
+    parser.add_argument(
+        "--endpoint-url",
+        dest="endpoint_url",
+        type=str,
+        default=None,
+        help="AWS SQS endpoint url",
+        action="store",
+    )
+
     parser.add_argument(
         "--interval",
         dest="interval",
@@ -90,7 +116,7 @@ def main():
         "--batchsize",
         dest="batchsize",
         type=int,
-        default=10,
+        default=None,
         help='How many messages to download at a time from SQS.',
         action="store",
     )
@@ -107,6 +133,13 @@ def main():
         action="store",
     )
 
+    parser.add_argument(
+        '--simple-worker',
+        dest='simple_worker',
+        default=False,
+        action='store_true'
+    )
+
     args = parser.parse_args()
 
     _main(
@@ -117,8 +150,10 @@ def main():
         access_key_id=args.access_key_id,
         secret_access_key=args.secret_access_key,
         interval=args.interval,
-        batchsize=args.batchsize,
-        prefetch_multiplier=args.prefetch_multiplier
+        batchsize=_set_batchsize(args),
+        prefetch_multiplier=args.prefetch_multiplier,
+        simple_worker=args.simple_worker,
+        endpoint_url=args.endpoint_url,
     )
 
 
@@ -130,17 +165,29 @@ def _add_cwd_to_path():
 
 def _main(queue_prefixes, concurrency=5, logging_level="WARN",
           region=None, access_key_id=None, secret_access_key=None,
-          interval=1, batchsize=10, prefetch_multiplier=2):
+          interval=1, batchsize=DEFAULT_BATCH_SIZE, prefetch_multiplier=2,
+          simple_worker=False, endpoint_url=None):
     logging.basicConfig(
         format="[%(levelname)s]: %(message)s",
         level=getattr(logging, logging_level),
    )
     logger.info("Starting PyQS version {}".format(__version__))
-    manager = ManagerWorker(
-        queue_prefixes, concurrency, interval, batchsize,
-        prefetch_multiplier=prefetch_multiplier, region=region,
-        access_key_id=access_key_id, secret_access_key=secret_access_key,
-    )
+
+    if simple_worker:
+        manager = SimpleManagerWorker(
+            queue_prefixes, concurrency, interval, batchsize,
+            region=region, access_key_id=access_key_id,
+            secret_access_key=secret_access_key,
+            endpoint_url=endpoint_url,
+        )
+    else:
+        manager = ManagerWorker(
+            queue_prefixes, concurrency, interval, batchsize,
+            prefetch_multiplier=prefetch_multiplier, region=region,
+            access_key_id=access_key_id, secret_access_key=secret_access_key,
+            endpoint_url=endpoint_url,
+        )
+
     _add_cwd_to_path()
     manager.start()
     manager.sleep()
diff --git a/pyqs/worker.py b/pyqs/worker.py
index d1ba0b4..4e655ed 100644
--- a/pyqs/worker.py
+++ b/pyqs/worker.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import unicode_literals
 
+import copy
 import fnmatch
 import importlib
 import logging
@@ -19,29 +20,50 @@ import boto3
 
 from pyqs.utils import get_aws_region_name, decode_message
+from pyqs.events import get_events
 
 MESSAGE_DOWNLOAD_BATCH_SIZE = 10
 LONG_POLLING_INTERVAL = 20
 logger = logging.getLogger("pyqs")
 
 
-def get_conn(region=None, access_key_id=None, secret_access_key=None):
-    if not region:
-        region = get_aws_region_name()
+def get_conn(
+    region=None, access_key_id=None, secret_access_key=None, endpoint_url=None
+):
+    kwargs = {
+        "aws_access_key_id": access_key_id,
+        "aws_secret_access_key": secret_access_key,
+        "region_name": region,
+    }
+
+    if endpoint_url:
+        kwargs["endpoint_url"] = endpoint_url
+    if not kwargs["region_name"]:
+        kwargs["region_name"] = get_aws_region_name()
 
     return boto3.client(
         "sqs",
-        aws_access_key_id=access_key_id,
-        aws_secret_access_key=secret_access_key,
-        region_name=region,
+        **kwargs,
     )
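# Illustrative only: the new endpoint_url argument lets the worker target an
# SQS-compatible emulator such as LocalStack. The URL below is an assumed
# local default, not something this patch configures:
#
#     conn = get_conn(region="us-east-1",
#                     endpoint_url="http://localhost:4566")
#     conn.list_queues()
#
# The same option is exposed on the command line as ``--endpoint-url``.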
kwargs["endpoint_url"] = endpoint_url + if not kwargs["region_name"]: + kwargs["region_name"] = get_aws_region_name() return boto3.client( "sqs", - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, region_name=region, + **kwargs, ) class BaseWorker(Process): def __init__(self, *args, **kwargs): + self._connection = None self.parent_id = kwargs.pop('parent_id') super(BaseWorker, self).__init__(*args, **kwargs) self.should_exit = Event() + def _get_connection(self): + if self._connection: + return self._connection + + if self.connection_args is None: + self._connection = get_conn() + else: + self._connection = get_conn(**self.connection_args) + return self._connection + def shutdown(self): logger.info( "Received shutdown signal, shutting down PID {}!".format( @@ -65,10 +87,9 @@ def __init__(self, queue_url, internal_queue, batchsize, if connection_args is None: connection_args = {} self.connection_args = connection_args - self.conn = get_conn(**self.connection_args) self.queue_url = queue_url - sqs_queue = self.conn.get_queue_attributes( + sqs_queue = get_conn(**self.connection_args).get_queue_attributes( QueueUrl=queue_url, AttributeNames=['All'])['Attributes'] self.visibility_timeout = int(sqs_queue['VisibilityTimeout']) @@ -88,7 +109,7 @@ def run(self): self.internal_queue.cancel_join_thread() def read_message(self): - messages = self.conn.receive_message( + messages = self._get_connection().receive_message( QueueUrl=self.queue_url, MaxNumberOfMessages=self.batchsize, WaitTimeSeconds=LONG_POLLING_INTERVAL, @@ -136,34 +157,117 @@ def read_message(self): self.queue_url, message_body)) # noqa -class ProcessWorker(BaseWorker): +class BaseProcessWorker(BaseWorker): + def __init__(self, *args, **kwargs): + super(BaseProcessWorker, self).__init__(*args, **kwargs) + + def _run_hooks(self, hook_name, context): + hooks = getattr(get_events(), hook_name) + for hook in hooks: + hook(context) + + def _create_pre_process_context(self, packed_message): + message = packed_message['message'] + message_body = decode_message(message) + full_task_path = message_body['task'] + + pre_process_context = { + "message_id": message['MessageId'], + "task_name": full_task_path.split(".")[-1], + "args": message_body['args'], + "kwargs": message_body['kwargs'], + "full_task_path": full_task_path, + "fetch_time": packed_message['start_time'], + "queue_url": packed_message['queue'], + "timeout": packed_message['timeout'], + "receipt_handle": message['ReceiptHandle'] + } + + return pre_process_context + + def _get_task(self, full_task_path): + + task_name = full_task_path.split(".")[-1] + task_path = ".".join(full_task_path.split(".")[:-1]) + task_module = importlib.import_module(task_path) + task = getattr(task_module, task_name) + + return task + + def _process_task(self, pre_process_context): + task = self._get_task(pre_process_context["full_task_path"]) + + # Modify the contexts separately so the original + # context isn't modified by later processing + post_process_context = copy.copy(pre_process_context) + + start_time = time.time() + try: + self._run_hooks("pre_process", pre_process_context) + task(*pre_process_context["args"], **pre_process_context["kwargs"]) + except Exception: + end_time = time.time() + logger.exception( + "Task {} raised error in {:.4f} seconds: with args: {} " + "and kwargs: {}: {}".format( + pre_process_context["full_task_path"], + end_time - start_time, + pre_process_context["args"], + pre_process_context["kwargs"], + traceback.format_exc(), + ) + ) + 
post_process_context["status"] = "exception" + post_process_context["exception"] = traceback.format_exc() + self._run_hooks("post_process", post_process_context) + return True + else: + end_time = time.time() + self._get_connection().delete_message( + QueueUrl=pre_process_context["queue_url"], + ReceiptHandle=pre_process_context["receipt_handle"] + ) + logger.info( + "Processed task {} in {:.4f} seconds with args: {} " + "and kwargs: {}".format( + pre_process_context["full_task_path"], + end_time - start_time, + pre_process_context["args"], + pre_process_context["kwargs"], + ) + ) + post_process_context["status"] = "success" + self._run_hooks("post_process", post_process_context) + return True + + +class ProcessWorker(BaseProcessWorker): def __init__(self, internal_queue, interval, connection_args=None, *args, **kwargs): super(ProcessWorker, self).__init__(*args, **kwargs) - if connection_args is None: - self.conn = get_conn() - else: - self.conn = get_conn(**connection_args) + self.connection_args = connection_args self.internal_queue = internal_queue self.interval = interval self._messages_to_process_before_shutdown = 100 + self.messages_processed = 0 def run(self): # Set the child process to not receive any keyboard interrupts signal.signal(signal.SIGINT, signal.SIG_IGN) logger.info("Running ProcessWorker, pid: {}".format(os.getpid())) - messages_processed = 0 + while not self.should_exit.is_set() and self.parent_is_alive(): processed = self.process_message() if processed: - messages_processed += 1 + self.messages_processed += 1 time.sleep(self.interval) else: # If we have no messages wait a moment before rechecking. time.sleep(0.001) - if messages_processed >= self._messages_to_process_before_shutdown: + if self.messages_processed \ + >= self._messages_to_process_before_shutdown: self.shutdown() def process_message(self): @@ -172,93 +276,115 @@ def process_message(self): except Empty: # Return False if we did not attempt to process any messages return False - message = packed_message['message'] - queue_url = packed_message['queue'] - fetch_time = packed_message['start_time'] - timeout = packed_message['timeout'] - message_body = decode_message(message) - full_task_path = message_body['task'] - args = message_body['args'] - kwargs = message_body['kwargs'] - task_name = full_task_path.split(".")[-1] - task_path = ".".join(full_task_path.split(".")[:-1]) - - task_module = importlib.import_module(task_path) - - task = getattr(task_module, task_name) + pre_process_context = self._create_pre_process_context(packed_message) current_time = time.time() - if int(current_time - fetch_time) >= timeout: + if int(current_time - pre_process_context["fetch_time"]) \ + >= pre_process_context["timeout"]: logger.warning( "Discarding task {} with args: {} and kwargs: {} due to " "exceeding visibility timeout".format( # noqa - full_task_path, - repr(args), - repr(kwargs), - ) - ) - return True - try: - start_time = time.time() - task(*args, **kwargs) - except Exception: - end_time = time.time() - logger.exception( - "Task {} raised error in {:.4f} seconds: with args: {} " - "and kwargs: {}: {}".format( - full_task_path, - end_time - start_time, - args, - kwargs, - traceback.format_exc(), + pre_process_context["full_task_path"], + repr(pre_process_context["args"]), + repr(pre_process_context["kwargs"]), ) ) return True - else: - end_time = time.time() - self.conn.delete_message( - QueueUrl=queue_url, - ReceiptHandle=message['ReceiptHandle'] - ) - logger.info( - "Processed task {} in {:.4f} seconds with 
args: {} " - "and kwargs: {}".format( - full_task_path, - end_time - start_time, - repr(args), - repr(kwargs), - ) - ) - return True + return self._process_task(pre_process_context) -class ManagerWorker(object): - def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize, - prefetch_multiplier=2, region=None, access_key_id=None, - secret_access_key=None): +class SimpleProcessWorker(BaseProcessWorker): + + def __init__(self, queue_url, interval, batchsize, + connection_args=None, *args, **kwargs): + super(SimpleProcessWorker, self).__init__(*args, **kwargs) + if connection_args is None: + connection_args = {} + self.connection_args = connection_args + self.queue_url = queue_url + + sqs_queue = get_conn(**self.connection_args).get_queue_attributes( + QueueUrl=queue_url, AttributeNames=['All'])['Attributes'] + self.visibility_timeout = int(sqs_queue['VisibilityTimeout']) + + self.interval = interval + self.batchsize = batchsize + self._messages_to_process_before_shutdown = 100 + self.messages_processed = 0 + + def run(self): + # Set the child process to not receive any keyboard interrupts + signal.signal(signal.SIGINT, signal.SIG_IGN) + + logger.info( + "Running SimpleProcessWorker: {}, pid: {}".format( + self.queue_url, os.getpid())) + + while not self.should_exit.is_set() and self.parent_is_alive(): + messages = self.read_message() + start = time.time() + + for message in messages: + packed_message = { + "queue": self.queue_url, + "message": message, + "start_time": start, + "timeout": self.visibility_timeout, + } + + processed = self.process_message(packed_message) + + if processed: + self.messages_processed += 1 + time.sleep(self.interval) + + if self.messages_processed \ + >= self._messages_to_process_before_shutdown: + self.shutdown() + + def read_message(self): + messages = self._get_connection().receive_message( + QueueUrl=self.queue_url, + MaxNumberOfMessages=self.batchsize, + WaitTimeSeconds=LONG_POLLING_INTERVAL, + ).get('Messages', []) + + logger.debug( + "Successfully got {} messages from SQS queue {}".format( + len(messages), self.queue_url)) # noqa + + return messages + + def process_message(self, packed_message): + + pre_process_context = self._create_pre_process_context(packed_message) + + return self._process_task(pre_process_context) + + +class BaseManager(object): + + def __init__(self, queue_prefixes, interval, batchsize, + region=None, access_key_id=None, + secret_access_key=None, endpoint_url=None): self.connection_args = { "region": region, "access_key_id": access_key_id, "secret_access_key": secret_access_key, + "endpoint_url": endpoint_url, } + self.interval = interval self.batchsize = batchsize if batchsize > MESSAGE_DOWNLOAD_BATCH_SIZE: self.batchsize = MESSAGE_DOWNLOAD_BATCH_SIZE if batchsize <= 0: self.batchsize = 1 - self.interval = interval - self.prefetch_multiplier = prefetch_multiplier self.queue_prefixes = queue_prefixes self.queue_urls = self.get_queue_urls_from_queue_prefixes( self.queue_prefixes) - self.setup_internal_queue(worker_concurrency) - self.reader_children = [] - self.worker_children = [] self._pid = os.getpid() - self._initialize_reader_children() - self._initialize_worker_children(worker_concurrency) self._running = True self._register_signals() @@ -267,6 +393,154 @@ def _register_signals(self): signal.SIGHUP]: self.register_shutdown_signal(SIG) + def get_queue_urls_from_queue_prefixes(self, queue_prefixes): + conn = get_conn(**self.connection_args) + queue_urls = conn.list_queues().get('QueueUrls', []) + matching_urls = [] + 
+ logger.info("Loading Queues:") + for prefix in queue_prefixes: + logger.info("[Queue]\t{}".format(prefix)) + matching_urls.extend([ + queue_url for queue_url in queue_urls if + fnmatch.fnmatch(queue_url.rsplit("/", 1)[1], prefix) + ]) + logger.info("Found matching SQS Queues: {}".format(matching_urls)) + return matching_urls + + def check_for_new_queues(self): + raise NotImplementedError + + def start(self): + raise NotImplementedError + + def stop(self): + raise NotImplementedError + + def sleep(self): + counter = 0 + while self._running: + counter = counter + 1 + if counter % 1000 == 0: + self.process_counts() + self.replace_workers() + if counter % 30000 == 0: + counter = 0 + self.check_for_new_queues() + time.sleep(0.001) + self._exit() + + def register_shutdown_signal(self, SIG): + signal.signal(SIG, self._graceful_shutdown) + + def _graceful_shutdown(self, signum, frame): + logger.info('Received shutdown signal %s', signum) + self._running = False + + def _exit(self): + logger.info('Graceful shutdown. Sending shutdown signal to children.') + self.stop() + sys.exit(0) + + def process_counts(self): + raise NotImplementedError + + def replace_workers(self): + raise NotImplementedError + + +class SimpleManagerWorker(BaseManager): + WORKER_CHILDREN_CLASS = SimpleProcessWorker + + def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize, + region=None, access_key_id=None, secret_access_key=None, + endpoint_url=None): + + super(SimpleManagerWorker, self).__init__(queue_prefixes, interval, + batchsize, region, + access_key_id, + secret_access_key, + endpoint_url) + + self.worker_children = [] + self._initialize_worker_children(worker_concurrency) + + def _initialize_worker_children(self, number): + for queue_url in self.queue_urls: + for index in range(number): + self.worker_children.append( + self.WORKER_CHILDREN_CLASS( + queue_url, self.interval, self.batchsize, + connection_args=self.connection_args, + parent_id=self._pid, + ) + ) + + def check_for_new_queues(self): + queue_urls = self.get_queue_urls_from_queue_prefixes( + self.queue_prefixes) + new_queue_urls = set(queue_urls) - set(self.queue_urls) + for new_queue_url in new_queue_urls: + logger.info("Found new queue\t{}".format(new_queue_url)) + worker = self.WORKER_CHILDREN_CLASS( + new_queue_url, self.interval, self.batchsize, + connection_args=self.connection_args, + parent_id=self._pid, + ) + worker.start() + self.worker_children.append(worker) + + def start(self): + for child in self.worker_children: + child.start() + + def stop(self): + for child in self.worker_children: + child.shutdown() + for child in self.worker_children: + child.join() + + def process_counts(self): + worker_count = sum(map(lambda x: x.is_alive(), self.worker_children)) + logger.debug("Worker Processes: {}".format(worker_count)) + + def replace_workers(self): + for index, worker in enumerate(self.worker_children): + if not worker.is_alive(): + logger.info( + "Worker Process {} is no longer responding, " + "spawning a new worker.".format(worker.pid)) + self.worker_children.pop(index) + worker = self.WORKER_CHILDREN_CLASS( + worker.queue_url, self.interval, self.batchsize, + connection_args=self.connection_args, + parent_id=self._pid, + ) + worker.start() + self.worker_children.append(worker) + + +class ManagerWorker(BaseManager): + WORKER_CHILDREN_CLASS = ProcessWorker + + def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize, + prefetch_multiplier=2, region=None, access_key_id=None, + secret_access_key=None, 
endpoint_url=None): + + super(ManagerWorker, self).__init__(queue_prefixes, interval, + batchsize, region, + access_key_id, + secret_access_key, + endpoint_url) + + self.prefetch_multiplier = prefetch_multiplier + self.worker_children = [] + self.reader_children = [] + + self.setup_internal_queue(worker_concurrency) + self._initialize_reader_children() + self._initialize_worker_children(worker_concurrency) + def _initialize_reader_children(self): for queue_url in self.queue_urls: self.reader_children.append( @@ -280,28 +554,13 @@ def _initialize_reader_children(self): def _initialize_worker_children(self, number): for index in range(number): self.worker_children.append( - ProcessWorker( + self.WORKER_CHILDREN_CLASS( self.internal_queue, self.interval, connection_args=self.connection_args, parent_id=self._pid, ) ) - def get_queue_urls_from_queue_prefixes(self, queue_prefixes): - conn = get_conn(**self.connection_args) - queue_urls = conn.list_queues().get('QueueUrls', []) - matching_urls = [] - - logger.info("Loading Queues:") - for prefix in queue_prefixes: - logger.info("[Queue]\t{}".format(prefix)) - matching_urls.extend([ - queue_url for queue_url in queue_urls if - fnmatch.fnmatch(queue_url.rsplit("/", 1)[1], prefix) - ]) - logger.info("Found matching SQS Queues: {}".format(matching_urls)) - return matching_urls - def check_for_new_queues(self): queue_urls = self.get_queue_urls_from_queue_prefixes( self.queue_prefixes) @@ -337,31 +596,6 @@ def stop(self): for child in self.worker_children: child.join() - def sleep(self): - counter = 0 - while self._running: - counter = counter + 1 - if counter % 1000 == 0: - self.process_counts() - self.replace_workers() - if counter % 30000 == 0: - counter = 0 - self.check_for_new_queues() - time.sleep(0.001) - self._exit() - - def register_shutdown_signal(self, SIG): - signal.signal(SIG, self._graceful_shutdown) - - def _graceful_shutdown(self, signum, frame): - logger.info('Received shutdown signal %s', signum) - self._running = False - - def _exit(self): - logger.info('Graceful shutdown. 
Sending shutdown signal to children.') - self.stop() - sys.exit(0) - def process_counts(self): reader_count = sum(map(lambda x: x.is_alive(), self.reader_children)) worker_count = sum(map(lambda x: x.is_alive(), self.worker_children)) @@ -395,7 +629,7 @@ def _replace_worker_children(self): "Worker Process {} is no longer responding, " "spawning a new worker.".format(worker.pid)) self.worker_children.pop(index) - worker = ProcessWorker( + worker = self.WORKER_CHILDREN_CLASS( self.internal_queue, self.interval, connection_args=self.connection_args, parent_id=self._pid, diff --git a/tests/__init__.py b/tests/__init__.py index 387d5c9..41342ef 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -import sure # flake8: noqa +import sure # noqa: F401 diff --git a/tests/tasks.py b/tests/tasks.py index a92ddd7..3ebdd73 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -36,3 +36,8 @@ def delayed_task(): @task(custom_function_path="custom_function.path", queue="foobar") def custom_path_task(): pass + + +@task() +def exception_task(message, extra=None): + raise Exception('this task raises an exception!') diff --git a/tests/test_events.py b/tests/test_events.py new file mode 100644 index 0000000..7823b4e --- /dev/null +++ b/tests/test_events.py @@ -0,0 +1,75 @@ +from nose.tools import assert_raises +from pyqs import events +from tests.utils import clear_events_registry + + +@clear_events_registry +def test_register_event(): + def print_pre_process(context): + print(context) + + events.register_event("pre_process", print_pre_process) + events.get_events().pre_process.should.equal([print_pre_process]) + + +@clear_events_registry +def test_register_multiple_same_events(): + def print_pre_process(context): + print(context) + + def print_numbers(context): + print(1 + 2) + + events.register_event("pre_process", print_pre_process) + events.register_event("pre_process", print_numbers) + events.get_events().pre_process.should.equal([ + print_pre_process, print_numbers + ]) + + +@clear_events_registry +def test_register_different_events(): + def print_pre_process(context): + print(context) + + def print_post_process(context): + print(context) + + events.register_event("pre_process", print_pre_process) + events.register_event("post_process", print_post_process) + events.get_events().pre_process.should.equal([print_pre_process]) + events.get_events().post_process.should.equal([print_post_process]) + + +@clear_events_registry +def test_register_multiple_different_events(): + def print_pre_process(context): + print(context) + + def print_post_process(context): + print(context) + + def print_numbers(context): + print(1 + 2) + + events.register_event("pre_process", print_pre_process) + events.register_event("pre_process", print_numbers) + events.register_event("post_process", print_post_process) + events.register_event("post_process", print_numbers) + events.get_events().pre_process.should.equal([ + print_pre_process, print_numbers + ]) + events.get_events().post_process.should.equal([ + print_post_process, print_numbers + ]) + + +@clear_events_registry +def test_register_non_existent_event(): + non_existent_event = "non_existent_event" + assert_raises( + events.NoEventException, + events.register_event, + non_existent_event, + lambda x: x + ) diff --git a/tests/test_manager_worker.py b/tests/test_manager_worker.py index 3089195..d5a3289 100644 --- a/tests/test_manager_worker.py +++ b/tests/test_manager_worker.py @@ -97,6 +97,7 @@ def test_main_method(ManagerWorker): 
ManagerWorker.assert_called_once_with( ['email1', 'email2'], 2, 1, 10, prefetch_multiplier=2, region=None, secret_access_key=None, access_key_id=None, + endpoint_url=None, ) ManagerWorker.return_value.start.assert_called_once_with() @@ -109,16 +110,41 @@ def test_real_main_method(ArgumentParser, _main): Test parsing of arguments from main method """ ArgumentParser.return_value.parse_args.return_value = Mock( - concurrency=3, queues=["email1"], interval=1, batchsize=10, + concurrency=3, queues=["email1"], interval=1, batchsize=5, logging_level="WARN", region='us-east-1', prefetch_multiplier=2, - access_key_id=None, secret_access_key=None, + access_key_id=None, secret_access_key=None, simple_worker=False, + endpoint_url=None, + ) + main() + + _main.assert_called_once_with( + queue_prefixes=['email1'], concurrency=3, interval=1, batchsize=5, + logging_level="WARN", region='us-east-1', prefetch_multiplier=2, + access_key_id=None, secret_access_key=None, simple_worker=False, + endpoint_url=None, + ) + + +@patch("pyqs.main._main") +@patch("pyqs.main.ArgumentParser") +@mock_sqs +def test_real_main_method_default_batchsize(ArgumentParser, _main): + """ + Test parsing of arguments from main method batch default + """ + ArgumentParser.return_value.parse_args.return_value = Mock( + concurrency=3, queues=["email1"], interval=1, batchsize=None, + logging_level="WARN", region='us-east-1', prefetch_multiplier=2, + access_key_id=None, secret_access_key=None, simple_worker=False, + endpoint_url=None, ) main() _main.assert_called_once_with( queue_prefixes=['email1'], concurrency=3, interval=1, batchsize=10, logging_level="WARN", region='us-east-1', prefetch_multiplier=2, - access_key_id=None, secret_access_key=None, + access_key_id=None, secret_access_key=None, simple_worker=False, + endpoint_url=None, ) diff --git a/tests/test_simple_manager_worker.py b/tests/test_simple_manager_worker.py new file mode 100644 index 0000000..012faec --- /dev/null +++ b/tests/test_simple_manager_worker.py @@ -0,0 +1,377 @@ +import json +import logging +import os +import signal +import time + +import boto3 +from mock import patch, Mock, MagicMock +from moto import mock_sqs + +from pyqs.main import main, _main +from pyqs.worker import SimpleManagerWorker +from tests.utils import ( + MockLoggingHandler, ThreadWithReturnValue2, ThreadWithReturnValue3, +) + + +@mock_sqs +def test_simple_manager_worker_create_proper_children_workers(): + """ + Test simple managing process creates multiple child workers + """ + conn = boto3.client('sqs', region_name='us-east-1') + conn.create_queue(QueueName="email") + + manager = SimpleManagerWorker( + queue_prefixes=['email'], worker_concurrency=3, interval=2, + batchsize=10, + ) + + len(manager.worker_children).should.equal(3) + + +@mock_sqs +def test_simple_manager_worker_with_queue_prefix(): + """ + Test simple managing process can find queues by prefix + """ + conn = boto3.client('sqs', region_name='us-east-1') + conn.create_queue(QueueName="email.foobar") + conn.create_queue(QueueName="email.baz") + + manager = SimpleManagerWorker( + queue_prefixes=['email.*'], worker_concurrency=1, interval=1, + batchsize=10, + ) + + len(manager.worker_children).should.equal(2) + children = manager.worker_children + # Pull all the read children and sort by name to make testing easier + sorted_children = sorted(children, key=lambda child: child.queue_url) + + sorted_children[0].queue_url.should.equal( + "https://queue.amazonaws.com/123456789012/email.baz") + sorted_children[1].queue_url.should.equal( + 
"https://queue.amazonaws.com/123456789012/email.foobar") + + +@mock_sqs +def test_simple_manager_start_and_stop(): + """ + Test simple managing process can start and stop child processes + """ + conn = boto3.client('sqs', region_name='us-east-1') + conn.create_queue(QueueName="email") + + manager = SimpleManagerWorker( + queue_prefixes=['email'], worker_concurrency=2, interval=1, + batchsize=10, + ) + + len(manager.worker_children).should.equal(2) + + manager.worker_children[0].is_alive().should.equal(False) + manager.worker_children[1].is_alive().should.equal(False) + + manager.start() + + manager.worker_children[0].is_alive().should.equal(True) + manager.worker_children[1].is_alive().should.equal(True) + + manager.stop() + + manager.worker_children[0].is_alive().should.equal(False) + manager.worker_children[1].is_alive().should.equal(False) + + +@patch("pyqs.main.SimpleManagerWorker") +@mock_sqs +def test_main_method(SimpleManagerWorker): + """ + Test creation of simple manager process from _main method + """ + _main(["email1", "email2"], concurrency=2, simple_worker=True) + + SimpleManagerWorker.assert_called_once_with( + ['email1', 'email2'], 2, 1, 10, + region=None, secret_access_key=None, access_key_id=None, + endpoint_url=None, + ) + SimpleManagerWorker.return_value.start.assert_called_once_with() + + +@patch("pyqs.main._main") +@patch("pyqs.main.ArgumentParser") +@mock_sqs +def test_real_main_method(ArgumentParser, _main): + """ + Test parsing of arguments from main method + """ + ArgumentParser.return_value.parse_args.return_value = Mock( + concurrency=3, queues=["email1"], interval=1, batchsize=5, + logging_level="WARN", region='us-east-1', prefetch_multiplier=2, + access_key_id=None, secret_access_key=None, simple_worker=True, + endpoint_url=None, + ) + main() + + _main.assert_called_once_with( + queue_prefixes=['email1'], concurrency=3, interval=1, batchsize=5, + logging_level="WARN", region='us-east-1', prefetch_multiplier=2, + access_key_id=None, secret_access_key=None, simple_worker=True, + endpoint_url=None, + ) + + +@patch("pyqs.main._main") +@patch("pyqs.main.ArgumentParser") +@mock_sqs +def test_real_main_method_default_batchsize(ArgumentParser, _main): + """ + Test parsing of arguments from main method batch default + """ + ArgumentParser.return_value.parse_args.return_value = Mock( + concurrency=3, queues=["email1"], interval=1, batchsize=None, + logging_level="WARN", region='us-east-1', prefetch_multiplier=2, + access_key_id=None, secret_access_key=None, simple_worker=True, + endpoint_url=None, + ) + main() + + _main.assert_called_once_with( + queue_prefixes=['email1'], concurrency=3, interval=1, batchsize=1, + logging_level="WARN", region='us-east-1', prefetch_multiplier=2, + access_key_id=None, secret_access_key=None, simple_worker=True, + endpoint_url=None, + ) + + +@mock_sqs +def test_master_spawns_worker_processes(): + """ + Test simple managing process creates child workers + """ + + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + conn.create_queue(QueueName="tester") + + # Setup Manager + manager = SimpleManagerWorker(["tester"], 1, 1, 10) + manager.start() + + # Check Workers + len(manager.worker_children).should.equal(1) + + manager.worker_children[0].is_alive().should.be.true + + # Cleanup + manager.stop() + + +@mock_sqs +def test_master_counts_processes(): + """ + Test simple managing process counts child processes + """ + + # Setup Logging + logger = logging.getLogger("pyqs") + del logger.handlers[:] + 
logger.handlers.append(MockLoggingHandler())
+
+    # Setup SQS Queue
+    conn = boto3.client('sqs', region_name='us-east-1')
+    conn.create_queue(QueueName="tester")
+
+    # Setup Manager
+    manager = SimpleManagerWorker(["tester"], 2, 1, 10)
+    manager.start()
+
+    # Check Workers
+    manager.process_counts()
+
+    # Cleanup
+    manager.stop()
+
+    # Check messages
+    msg2 = "Worker Processes: 2"
+    logger.handlers[0].messages['debug'][-1].lower().should.contain(
+        msg2.lower())
+
+
+@mock_sqs
+def test_master_replaces_worker_processes():
+    """
+    Test simple managing process replaces worker processes
+    """
+    # Setup SQS Queue
+    conn = boto3.client('sqs', region_name='us-east-1')
+    conn.create_queue(QueueName="tester")
+
+    # Setup Manager
+    manager = SimpleManagerWorker(
+        queue_prefixes=["tester"], worker_concurrency=1, interval=1,
+        batchsize=10,
+    )
+    manager.start()
+
+    # Get Worker PID
+    pid = manager.worker_children[0].pid
+
+    # Kill Worker and wait to replace
+    manager.worker_children[0].shutdown()
+    time.sleep(0.1)
+    manager.replace_workers()
+
+    # Check Replacement
+    manager.worker_children[0].pid.shouldnt.equal(pid)
+
+    # Cleanup
+    manager.stop()
+
+
+@mock_sqs
+@patch("pyqs.worker.sys")
+def test_master_handles_signals(sys):
+    """
+    Test simple managing process handles OS signals
+    """
+
+    # Setup SQS Queue
+    conn = boto3.client('sqs', region_name='us-east-1')
+    conn.create_queue(QueueName="tester")
+
+    # Mock out sys.exit
+    sys.exit = Mock()
+
+    # Have our inner method send our signal
+    def process_counts():
+        os.kill(os.getpid(), signal.SIGTERM)
+
+    # Setup Manager
+    manager = SimpleManagerWorker(
+        queue_prefixes=["tester"], worker_concurrency=1, interval=1,
+        batchsize=10,
+    )
+    manager.process_counts = process_counts
+    manager._graceful_shutdown = MagicMock()
+
+    # When we start and trigger a signal
+    manager.start()
+    manager.sleep()
+
+    # Then we exit
+    sys.exit.assert_called_once_with(0)
+
+
+@mock_sqs
+def test_master_shuts_down_busy_process_workers():
+    """
+    Test simple managing process properly cleans up busy Process Workers
+    """
+    # For debugging test
+    import sys
+    logger = logging.getLogger("pyqs")
+    logger.setLevel(logging.DEBUG)
+    stdout_handler = logging.StreamHandler(sys.stdout)
+    logger.addHandler(stdout_handler)
+
+    # Setup SQS Queue
+    conn = boto3.client('sqs', region_name='us-east-1')
+    queue_url = conn.create_queue(QueueName="tester")['QueueUrl']
+
+    # Add Slow tasks
+    message = json.dumps({
+        'task': 'tests.tasks.sleeper',
+        'args': [],
+        'kwargs': {
+            'message': 5,
+        },
+    })
+
+    # Fill the queue (we need a lot of messages to trigger the bug)
+    for _ in range(20):
+        conn.send_message(QueueUrl=queue_url, MessageBody=message)
+
+    # Create function to watch and kill stuck processes
+    def sleep_and_kill(pid):
+        import os
+        import signal
+        import time
+        # This sleep time is long enough for 100 messages in the queue
+        time.sleep(5)
+        try:
+            os.kill(pid, signal.SIGKILL)
+        except OSError:
+            # Return that we didn't need to kill the process
+            return True
+        else:
+            # Return that we needed to kill the process
+            return False
+
+    # Setup Manager
+    manager = SimpleManagerWorker(
+        queue_prefixes=["tester"], worker_concurrency=1, interval=0.0,
+        batchsize=1,
+    )
+    manager.start()
+
+    # Give our processes a moment to start
+    time.sleep(1)
+
+    # Setup Threading watcher
+    try:
+        # Try Python 2 Style
+        thread = ThreadWithReturnValue2(
+            target=sleep_and_kill, args=(manager.worker_children[0].pid,))
+        thread.daemon = True
+    except TypeError:
+        # Use Python 3 Style
+        thread = 
ThreadWithReturnValue3( + target=sleep_and_kill, args=(manager.worker_children[0].pid,), + daemon=True, + ) + + thread.start() + + # Stop the Master Process + manager.stop() + + # Check if we had to kill the Process Worker or it exited gracefully + return_value = thread.join() + if not return_value: + raise Exception("Process Worker failed to quit!") + + +@mock_sqs +def test_manager_picks_up_new_queues(): + """ + Test that the simple manager will recognize new SQS queues have been added + """ + + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + + # Setup Manager + manager = SimpleManagerWorker( + queue_prefixes=["tester"], worker_concurrency=1, interval=1, + batchsize=10, + ) + manager.start() + + # No queues found + len(manager.worker_children).should.equal(0) + + # Create the queue + conn.create_queue(QueueName="tester") + manager.check_for_new_queues() + + # The manager should have seen the new queue was created and add a reader + len(manager.worker_children).should.equal(1) + manager.worker_children[0].queue_url.should.equal( + "https://queue.amazonaws.com/123456789012/tester") + + # Cleanup + manager.stop() diff --git a/tests/test_simple_worker.py b/tests/test_simple_worker.py new file mode 100644 index 0000000..1090b76 --- /dev/null +++ b/tests/test_simple_worker.py @@ -0,0 +1,651 @@ +import json +import logging +import time + +import boto3 +from botocore.exceptions import ClientError +from moto import mock_sqs +from mock import patch, Mock +from pyqs.worker import ( + SimpleManagerWorker, BaseProcessWorker, SimpleProcessWorker, + MESSAGE_DOWNLOAD_BATCH_SIZE +) +from pyqs.utils import decode_message +from pyqs.events import register_event +from tests.utils import MockLoggingHandler, clear_events_registry + +BATCHSIZE = 10 +INTERVAL = 0.1 + + +def _create_packed_message(task_name): + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + # Build the SQS message + message = { + 'Body': json.dumps({ + 'task': task_name, + 'args': [], + 'kwargs': { + 'message': 'Test message', + }, + }), + "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", + } + + packed_message = { + "queue": queue_url, + "message": message, + "start_time": time.time(), + "timeout": 30, + } + + return queue_url, packed_message + + +def _add_messages_to_sqs(task_name, num): + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + # Build the SQS message + message = json.dumps({ + 'task': task_name, + 'args': [], + 'kwargs': { + 'message': 'Test message', + }, + }) + + for i in range(num): + conn.send_message(QueueUrl=queue_url, MessageBody=message) + + return queue_url + + +@mock_sqs +def test_worker_reads_messages_from_sqs(): + """ + Test simple worker reads from sqs queue + """ + queue_url = _add_messages_to_sqs('tests.tasks.index_incrementer', 1) + + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + messages = worker.read_message() + + found_message_body = decode_message(messages[0]) + found_message_body.should.equal({ + 'task': 'tests.tasks.index_incrementer', + 'args': [], + 'kwargs': { + 'message': 'Test message', + }, + }) + + +@mock_sqs +def test_worker_throws_error_when_exceeding_max_number_of_messages_for_read(): + """ + Test simple worker reads from sqs queue and throws error when batchsize + greater than 10 + """ + queue_url = _add_messages_to_sqs('tests.tasks.index_incrementer', 
1) + + worker = SimpleProcessWorker(queue_url, INTERVAL, 20, parent_id=1) + + error_msg = "Value 20 for parameter MaxNumberOfMessages is invalid" + + try: + worker.read_message() + except ClientError as exc: + str(exc).should.contain(error_msg) + + +@mock_sqs +def test_worker_reads_max_messages_from_sqs(): + """ + Test simple worker reads at maximum 10 message from sqs queue + """ + _add_messages_to_sqs('tests.tasks.index_incrementer', 12) + + manager = SimpleManagerWorker( + queue_prefixes=['tester'], worker_concurrency=1, interval=INTERVAL, + batchsize=20, + ) + worker = manager.worker_children[0] + messages = worker.read_message() + + messages.should.have.length_of(BATCHSIZE) + + +@mock_sqs +def test_worker_fills_internal_queue_from_celery_task(): + """ + Test simple worker reads from sqs queue with celery tasks + """ + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + message = ( + '{"body": "KGRwMApTJ3Rhc2snCnAxClMndGVzdHMudGFza3MuaW5kZXhfa' + 'W5jcmVtZW50ZXInCnAyCnNTJ2Fy\\nZ3MnCnAzCihscDQKc1Mna3dhcmdzJw' + 'pwNQooZHA2ClMnbWVzc2FnZScKcDcKUydUZXN0IG1lc3Nh\\nZ2UyJwpwOAp' + 'zcy4=\\n", "some stuff": "asdfasf"}' + ) + conn.send_message(QueueUrl=queue_url, MessageBody=message) + + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + messages = worker.read_message() + + found_message_body = decode_message(messages[0]) + found_message_body.should.equal({ + 'task': 'tests.tasks.index_incrementer', + 'args': [], + 'kwargs': { + 'message': 'Test message2', + }, + }) + + +@mock_sqs +def test_worker_processes_tasks_and_logs_correctly(): + """ + Test simple worker processes logs INFO correctly + """ + # Setup logging + logger = logging.getLogger("pyqs") + del logger.handlers[:] + logger.handlers.append(MockLoggingHandler()) + + queue_url, packed_message = _create_packed_message( + 'tests.tasks.index_incrementer' + ) + + # Process message + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.process_message(packed_message) + + # Check output + kwargs = json.loads(packed_message["message"]['Body'])['kwargs'] + expected_result = ( + u"Processed task tests.tasks.index_incrementer in 0.0000 seconds " + "with args: [] and kwargs: {}".format(kwargs) + ) + logger.handlers[0].messages['info'].should.equal([expected_result]) + + +@mock_sqs +def test_worker_processes_tasks_and_logs_warning_correctly(): + """ + Test simple worker processes logs WARNING correctly + """ + # Setup logging + logger = logging.getLogger("pyqs") + del logger.handlers[:] + logger.handlers.append(MockLoggingHandler()) + + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + # Build the SQS Message + message = { + 'Body': json.dumps({ + 'task': 'tests.tasks.index_incrementer', + 'args': [], + 'kwargs': { + 'message': 23, + }, + }), + "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", + } + + packed_message = { + "queue": queue_url, + "message": message, + "start_time": time.time(), + "timeout": 30, + } + + # Process message + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.process_message(packed_message) + + # Check output + kwargs = json.loads(message['Body'])['kwargs'] + msg1 = ( + "Task tests.tasks.index_incrementer raised error in 0.0000 seconds: " + "with args: [] and kwargs: {}: " + "Traceback (most recent call last)".format(kwargs) + ) # noqa + 
logger.handlers[0].messages['error'][0].lower().should.contain( + msg1.lower()) + msg2 = ( + 'ValueError: Need to be given basestring, ' + 'was given 23' + ) # noqa + logger.handlers[0].messages['error'][0].lower().should.contain( + msg2.lower()) + + +@patch("pyqs.worker.os") +def test_parent_process_death(os): + """ + Test simple worker processes recognize parent process death + """ + os.getppid.return_value = 123 + + worker = BaseProcessWorker(parent_id=1) + worker.parent_is_alive().should.be.false + + +@patch("pyqs.worker.os") +def test_parent_process_alive(os): + """ + Test simple worker processes recognize when parent process is alive + """ + os.getppid.return_value = 1234 + + worker = BaseProcessWorker(parent_id=1234) + worker.parent_is_alive().should.be.true + + +@mock_sqs +@patch("pyqs.worker.os") +def test_read_message_with_parent_process_alive_and_should_not_exit(os): + """ + Test simple worker processes do not exit when parent is alive and shutdown + is not set when reading message + """ + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + # Setup PPID + os.getppid.return_value = 1 + + # Setup dummy read_message + def read_message(): + raise Exception("Called") + + # When I have a parent process, and shutdown is not set + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.read_message = read_message + + # Then read_message() is reached + worker.run.when.called_with().should.throw(Exception, "Called") + + +@mock_sqs +@patch("pyqs.worker.os") +def test_read_message_with_parent_process_alive_and_should_exit(os): + """ + Test simple worker processes exit when parent is alive and shutdown is set + when reading message + """ + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + # Setup PPID + os.getppid.return_value = 1234 + + # When I have a parent process, and shutdown is set + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.read_message = Mock() + worker.shutdown() + + # Then I return from run() + worker.run().should.be.none + + +@mock_sqs +@patch("pyqs.worker.os") +def test_read_message_with_parent_process_dead_and_should_not_exit(os): + """ + Test simple worker processes exit when parent is dead and shutdown is not + set when reading messages + """ + # Setup SQS Queue + conn = boto3.client('sqs', region_name='us-east-1') + queue_url = conn.create_queue(QueueName="tester")['QueueUrl'] + + # Setup PPID + os.getppid.return_value = 123 + + # When I have no parent process, and shutdown is not set + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.read_message = Mock() + + # Then I return from run() + worker.run().should.be.none + + +@mock_sqs +@patch("pyqs.worker.os") +def test_process_message_with_parent_process_alive_and_should_not_exit(os): + """ + Test simple worker processes do not exit when parent is alive and shutdown + is not set when processing message + """ + queue_url = _add_messages_to_sqs('tests.tasks.index_incrementer', 1) + + # Setup PPID + os.getppid.return_value = 1 + + # Setup dummy read_message + def process_message(packed_message): + raise Exception("Called") + + # When I have a parent process, and shutdown is not set + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.process_message = process_message + + # Then process_message() is reached + 
worker.run.when.called_with().should.throw(Exception, "Called")
+
+
+@mock_sqs
+@patch("pyqs.worker.os")
+def test_process_message_with_parent_process_dead_and_should_not_exit(os):
+    """
+    Test simple worker processes exit when parent is dead and shutdown is not
+    set when processing message
+    """
+
+    queue_url = _add_messages_to_sqs('tests.tasks.index_incrementer', 1)
+
+    # Setup PPID
+    os.getppid.return_value = 123
+
+    # When I have no parent process, and shutdown is not set
+    worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1)
+    worker.process_message = Mock()
+
+    # Then I return from run()
+    worker.run().should.be.none
+
+
+@mock_sqs
+@patch("pyqs.worker.os")
+def test_process_message_with_parent_process_alive_and_should_exit(os):
+    """
+    Test simple worker processes exit when parent is alive and shutdown is
+    set when processing message
+    """
+
+    queue_url = _add_messages_to_sqs('tests.tasks.index_incrementer', 1)
+
+    # Setup PPID
+    os.getppid.return_value = 1
+
+    # When I have a parent process, and shutdown is set
+    worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1)
+    worker.process_message = Mock()
+    worker.shutdown()
+
+    # Then I return from run()
+    worker.run().should.be.none
+
+
+@mock_sqs
+@patch("pyqs.worker.os")
+def test_worker_processes_shuts_down_after_processing_its_max_number_of_msgs(
+        os):
+    """
+    Test simple worker processes shutdown after processing maximum number
+    of messages
+    """
+    os.getppid.return_value = 1
+
+    queue_url = _add_messages_to_sqs('tests.tasks.index_incrementer', 2)
+
+    # When I process messages
+    worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1)
+    worker._messages_to_process_before_shutdown = 2
+
+    # Then I return from run()
+    worker.run().should.be.none
+
+
+@mock_sqs
+def test_worker_negative_batch_size():
+    """
+    Test simple workers with negative batch sizes
+    """
+    BATCHSIZE = -1
+    CONCURRENCY = 1
+    QUEUE_PREFIX = "tester"
+    INTERVAL = 0.0
+    conn = boto3.client('sqs', region_name='us-east-1')
+    conn.create_queue(QueueName="tester")['QueueUrl']
+
+    worker = SimpleManagerWorker(
+        QUEUE_PREFIX,
+        CONCURRENCY,
+        INTERVAL,
+        BATCHSIZE
+    )
+    worker.batchsize.should.equal(1)
+
+
+@mock_sqs
+def test_worker_to_large_batch_size():
+    """
+    Test simple workers with too large of a batch size
+    """
+    BATCHSIZE = 10000
+    CONCURRENCY = 1
+    QUEUE_PREFIX = "tester"
+    INTERVAL = 0.0
+    conn = boto3.client('sqs', region_name='us-east-1')
+    conn.create_queue(QueueName="tester")['QueueUrl']
+
+    worker = SimpleManagerWorker(
+        QUEUE_PREFIX,
+        CONCURRENCY,
+        INTERVAL,
+        BATCHSIZE
+    )
+    worker.batchsize.should.equal(MESSAGE_DOWNLOAD_BATCH_SIZE)
+
+
+@clear_events_registry
+@mock_sqs
+def test_worker_processes_tasks_with_pre_process_callback():
+    """
+    Test simple worker runs registered callbacks when processing a message
+    """
+
+    queue_url, packed_message = _create_packed_message(
+        'tests.tasks.index_incrementer'
+    )
+
+    # Declare this so it can be checked as a side effect
+    # to pre_process_with_side_effect
+    contexts = []
+
+    def pre_process_with_side_effect(context):
+        contexts.append(context)
+
+    # When we have a registered pre_process callback
+    register_event("pre_process", pre_process_with_side_effect)
+
+    worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1)
+    worker.process_message(packed_message)
+
+    pre_process_context = contexts[0]
+
+    # We should run the callback with the task context
pre_process_context['task_name'].should.equal('index_incrementer') + pre_process_context['args'].should.equal([]) + pre_process_context['kwargs'].should.equal({'message': 'Test message'}) + pre_process_context['full_task_path'].should.equal( + 'tests.tasks.index_incrementer' + ) + pre_process_context['queue_url'].should.equal( + 'https://queue.amazonaws.com/123456789012/tester' + ) + pre_process_context['timeout'].should.equal(30) + + assert 'fetch_time' in pre_process_context + assert 'receipt_handle' in pre_process_context + assert 'status' not in pre_process_context + + +@clear_events_registry +@mock_sqs +def test_worker_processes_tasks_with_post_process_callback_success(): + """ + Test simple worker runs registered callbacks when + processing a message and it succeeds + """ + + queue_url, packed_message = _create_packed_message( + 'tests.tasks.index_incrementer' + ) + + # Declare this so it can be checked as a side effect + # to post_process_with_side_effect + contexts = [] + + def post_process_with_side_effect(context): + contexts.append(context) + + # When we have a registered post_process callback + register_event("post_process", post_process_with_side_effect) + + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.process_message(packed_message) + + post_process_context = contexts[0] + + # We should run the callback with the task context + post_process_context['task_name'].should.equal('index_incrementer') + post_process_context['args'].should.equal([]) + post_process_context['kwargs'].should.equal({'message': 'Test message'}) + post_process_context['full_task_path'].should.equal( + 'tests.tasks.index_incrementer' + ) + post_process_context['queue_url'].should.equal( + 'https://queue.amazonaws.com/123456789012/tester' + ) + post_process_context['timeout'].should.equal(30) + post_process_context['status'].should.equal('success') + + assert 'fetch_time' in post_process_context + assert 'receipt_handle' in post_process_context + assert 'exception' not in post_process_context + + +@clear_events_registry +@mock_sqs +def test_worker_processes_tasks_with_post_process_callback_exception(): + """ + Test simple worker runs registered callbacks when processing + a message and it fails + """ + + queue_url, packed_message = _create_packed_message( + 'tests.tasks.exception_task' + ) + + # Declare this so it can be checked as a side effect + # to post_process_with_side_effect + contexts = [] + + def post_process_with_side_effect(context): + contexts.append(context) + + # When we have a registered post_process callback + register_event("post_process", post_process_with_side_effect) + + worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1) + worker.process_message(packed_message) + + post_process_context = contexts[0] + + # We should run the callback with the task context + post_process_context['task_name'].should.equal('exception_task') + post_process_context['args'].should.equal([]) + post_process_context['kwargs'].should.equal({'message': 'Test message'}) + post_process_context['full_task_path'].should.equal( + 'tests.tasks.exception_task' + ) + post_process_context['queue_url'].should.equal( + 'https://queue.amazonaws.com/123456789012/tester' + ) + post_process_context['timeout'].should.equal(30) + post_process_context['status'].should.equal('exception') + + assert 'fetch_time' in post_process_context + assert 'receipt_handle' in post_process_context + assert 'exception' in post_process_context + + +@clear_events_registry +@mock_sqs +def 
+def test_worker_processes_tasks_with_pre_and_post_process():
+    """
+    Test simple worker runs registered callbacks when processing a message
+    """
+
+    queue_url, packed_message = _create_packed_message(
+        'tests.tasks.index_incrementer'
+    )
+
+    # Declare these so they can be checked as a side effect to the callbacks
+    contexts = []
+
+    def pre_process_with_side_effect(context):
+        contexts.append(context)
+
+    def post_process_with_side_effect(context):
+        contexts.append(context)
+
+    # When we have a registered pre_process and post_process callback
+    register_event("pre_process", pre_process_with_side_effect)
+    register_event("post_process", post_process_with_side_effect)
+
+    worker = SimpleProcessWorker(queue_url, INTERVAL, BATCHSIZE, parent_id=1)
+    worker.process_message(packed_message)
+
+    pre_process_context = contexts[0]
+
+    # We should run the callbacks with the right task contexts
+    pre_process_context['task_name'].should.equal('index_incrementer')
+    pre_process_context['args'].should.equal([])
+    pre_process_context['kwargs'].should.equal({'message': 'Test message'})
+    pre_process_context['full_task_path'].should.equal(
+        'tests.tasks.index_incrementer'
+    )
+    pre_process_context['queue_url'].should.equal(
+        'https://queue.amazonaws.com/123456789012/tester'
+    )
+    pre_process_context['timeout'].should.equal(30)
+
+    assert 'fetch_time' in pre_process_context
+    assert 'receipt_handle' in pre_process_context
+    assert 'status' not in pre_process_context
+
+    post_process_context = contexts[1]
+
+    post_process_context['task_name'].should.equal('index_incrementer')
+    post_process_context['args'].should.equal([])
+    post_process_context['kwargs'].should.equal({'message': 'Test message'})
+    post_process_context['full_task_path'].should.equal(
+        'tests.tasks.index_incrementer'
+    )
+    post_process_context['queue_url'].should.equal(
+        'https://queue.amazonaws.com/123456789012/tester'
+    )
+    post_process_context['timeout'].should.equal(30)
+    post_process_context['status'].should.equal('success')
+
+    assert 'fetch_time' in post_process_context
+    assert 'receipt_handle' in post_process_context
+    assert 'exception' not in post_process_context
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 996eb39..459a835 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -17,13 +17,53 @@
     MESSAGE_DOWNLOAD_BATCH_SIZE,
 )
 from pyqs.utils import decode_message
+from pyqs.events import register_event
 from tests.tasks import task_results
-from tests.utils import MockLoggingHandler
+from tests.utils import MockLoggingHandler, clear_events_registry
 
 BATCHSIZE = 10
 INTERVAL = 0.1
 
 
+def _add_message_to_internal_queue(task_name):
+    # Setup SQS Queue
+    conn = boto3.client('sqs', region_name='us-east-1')
+    queue_url = conn.create_queue(QueueName="tester")['QueueUrl']
+
+    # Build the SQS message
+    message = {
+        'Body': json.dumps({
+            'task': task_name,
+            'args': [],
+            'kwargs': {
+                'message': 'Test message',
+            },
+        }),
+        "ReceiptHandle": "receipt-1234",
+        "MessageId": "message-id-1",
+    }
+    # Add message to the internal queue
+    internal_queue = Queue()
+    internal_queue.put(
+        {
+            "message": message,
+            "queue": queue_url,
+            "start_time": time.time(),
+            "timeout": 30,
+        }
+    )
+    return internal_queue
+
+
+def _check_internal_queue_is_empty(internal_queue):
+    try:
+        internal_queue.get(timeout=1)
+    except Empty:
+        pass
+    else:
+        raise AssertionError("The internal queue should be empty")
+
+
 @mock_sqs
 def test_worker_fills_internal_queue():
     """
@@ -145,6 +185,7 @@ def test_worker_processes_tasks_from_internal_queue():
             },
         }),
"ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", } # Add message to queue @@ -237,6 +278,7 @@ def test_worker_processes_tasks_and_logs_correctly(): }, }), "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", } # Add message to internal queue @@ -287,6 +329,7 @@ def test_worker_processes_tasks_and_logs_warning_correctly(): }, }), "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", } # Add message to internal queue @@ -508,6 +551,7 @@ def test_worker_processes_shuts_down_after_processing_its_max_number_of_msgs( }, }), "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", } # Add message to internal queue @@ -573,6 +617,7 @@ def test_worker_processes_discard_tasks_that_exceed_their_visibility_timeout(): }, }), "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", } # Add message to internal queue with timeout of 0 that started long ago @@ -621,6 +666,7 @@ def test_worker_processes_only_incr_processed_counter_if_a_msg_was_processed(): }, }), "ReceiptHandle": "receipt-1234", + "MessageId": "message-id-1", } # Add message to internal queue @@ -688,3 +734,207 @@ def test_worker_to_large_batch_size(): worker = ManagerWorker(QUEUE_PREFIX, CONCURRENCY, INTERVAL, BATCHSIZE) worker.batchsize.should.equal(MESSAGE_DOWNLOAD_BATCH_SIZE) + + +@clear_events_registry +@mock_sqs +def test_worker_processes_tasks_with_pre_process_callback(): + """ + Test worker runs registered callbacks when processing a message + """ + + # Declare this so it can be checked as a side effect + # to pre_process_with_side_effect + contexts = [] + + def pre_process_with_side_effect(context): + contexts.append(context) + + # When we have a registered pre_process callback + register_event("pre_process", pre_process_with_side_effect) + + # And we process a message + internal_queue = _add_message_to_internal_queue( + 'tests.tasks.index_incrementer' + ) + worker = ProcessWorker(internal_queue, INTERVAL, parent_id=1) + worker.process_message() + + pre_process_context = contexts[0] + + # We should run the callback with the task context + pre_process_context['task_name'].should.equal('index_incrementer') + pre_process_context['args'].should.equal([]) + pre_process_context['kwargs'].should.equal({'message': 'Test message'}) + pre_process_context['full_task_path'].should.equal( + 'tests.tasks.index_incrementer' + ) + pre_process_context['queue_url'].should.equal( + 'https://queue.amazonaws.com/123456789012/tester' + ) + pre_process_context['timeout'].should.equal(30) + + assert 'fetch_time' in pre_process_context + assert 'status' not in pre_process_context + + # And the internal queue should be empty + _check_internal_queue_is_empty(internal_queue) + + +@clear_events_registry +@mock_sqs +def test_worker_processes_tasks_with_post_process_callback_success(): + """ + Test worker runs registered callbacks when + processing a message and it succeeds + """ + + # Declare this so it can be checked as a side effect + # to post_process_with_side_effect + contexts = [] + + def post_process_with_side_effect(context): + contexts.append(context) + + # When we have a registered post_process callback + register_event("post_process", post_process_with_side_effect) + + # And we process a message + internal_queue = _add_message_to_internal_queue( + 'tests.tasks.index_incrementer' + ) + worker = ProcessWorker(internal_queue, INTERVAL, parent_id=1) + worker.process_message() + + post_process_context = contexts[0] + + # We should run the callback with the task context + 
+    post_process_context['task_name'].should.equal('index_incrementer')
+    post_process_context['args'].should.equal([])
+    post_process_context['kwargs'].should.equal({'message': 'Test message'})
+    post_process_context['full_task_path'].should.equal(
+        'tests.tasks.index_incrementer'
+    )
+    post_process_context['queue_url'].should.equal(
+        'https://queue.amazonaws.com/123456789012/tester'
+    )
+    post_process_context['timeout'].should.equal(30)
+    post_process_context['status'].should.equal('success')
+
+    assert 'fetch_time' in post_process_context
+    assert 'exception' not in post_process_context
+
+    # And the internal queue should be empty
+    _check_internal_queue_is_empty(internal_queue)
+
+
+@clear_events_registry
+@mock_sqs
+def test_worker_processes_tasks_with_post_process_callback_exception():
+    """
+    Test worker runs registered callbacks when processing
+    a message and it fails
+    """
+
+    # Declare this so it can be checked as a side effect
+    # to post_process_with_side_effect
+    contexts = []
+
+    def post_process_with_side_effect(context):
+        contexts.append(context)
+
+    # When we have a registered post_process callback
+    register_event("post_process", post_process_with_side_effect)
+
+    # And we process a message
+    internal_queue = _add_message_to_internal_queue(
+        'tests.tasks.exception_task'
+    )
+    worker = ProcessWorker(internal_queue, INTERVAL, parent_id=1)
+    worker.process_message()
+
+    post_process_context = contexts[0]
+
+    # We should run the callback with the task context
+    post_process_context['task_name'].should.equal('exception_task')
+    post_process_context['args'].should.equal([])
+    post_process_context['kwargs'].should.equal({'message': 'Test message'})
+    post_process_context['full_task_path'].should.equal(
+        'tests.tasks.exception_task'
+    )
+    post_process_context['queue_url'].should.equal(
+        'https://queue.amazonaws.com/123456789012/tester'
+    )
+    post_process_context['timeout'].should.equal(30)
+    post_process_context['status'].should.equal('exception')
+
+    assert 'fetch_time' in post_process_context
+    assert 'exception' in post_process_context
+
+    # And the internal queue should be empty
+    _check_internal_queue_is_empty(internal_queue)
+
+
+@clear_events_registry
+@mock_sqs
+def test_worker_processes_tasks_with_pre_and_post_process():
+    """
+    Test worker runs registered callbacks when processing a message
+    """
+
+    # Declare these so they can be checked as a side effect to the callbacks
+    contexts = []
+
+    def pre_process_with_side_effect(context):
+        contexts.append(context)
+
+    def post_process_with_side_effect(context):
+        contexts.append(context)
+
+    # When we have a registered pre_process and post_process callback
+    register_event("pre_process", pre_process_with_side_effect)
+    register_event("post_process", post_process_with_side_effect)
+
+    # And we process a message
+    internal_queue = _add_message_to_internal_queue(
+        'tests.tasks.index_incrementer'
+    )
+    worker = ProcessWorker(internal_queue, INTERVAL, parent_id=1)
+    worker.process_message()
+
+    pre_process_context = contexts[0]
+
+    # We should run the callbacks with the right task contexts
+    pre_process_context['task_name'].should.equal('index_incrementer')
+    pre_process_context['args'].should.equal([])
+    pre_process_context['kwargs'].should.equal({'message': 'Test message'})
+    pre_process_context['full_task_path'].should.equal(
+        'tests.tasks.index_incrementer'
+    )
+    pre_process_context['queue_url'].should.equal(
+        'https://queue.amazonaws.com/123456789012/tester'
+    )
+    pre_process_context['timeout'].should.equal(30)
+
+    assert 'fetch_time' in pre_process_context
+    assert 'status' not in pre_process_context
+
+    post_process_context = contexts[1]
+
+    post_process_context['task_name'].should.equal('index_incrementer')
+    post_process_context['args'].should.equal([])
+    post_process_context['kwargs'].should.equal({'message': 'Test message'})
+    post_process_context['full_task_path'].should.equal(
+        'tests.tasks.index_incrementer'
+    )
+    post_process_context['queue_url'].should.equal(
+        'https://queue.amazonaws.com/123456789012/tester'
+    )
+    post_process_context['timeout'].should.equal(30)
+    post_process_context['status'].should.equal('success')
+
+    assert 'fetch_time' in post_process_context
+    assert 'exception' not in post_process_context
+
+    # And the internal queue should be empty
+    _check_internal_queue_is_empty(internal_queue)
diff --git a/tests/utils.py b/tests/utils.py
index 3e9a9b4..0243beb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,9 +2,12 @@
 from __future__ import unicode_literals
 
 import logging
+from functools import wraps
 from threading import Thread
 
+from pyqs import events
+
 
 class MockLoggingHandler(logging.Handler):
     """Mock logging handler to check for expected logs."""
@@ -55,3 +58,13 @@ def run(self):
     def join(self):
         Thread.join(self)
         return self._return
+
+
+def clear_events_registry(fn):
+    """Clear the global events registry before each test."""
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        events.clear_events()
+        return fn(*args, **kwargs)
+    return wrapper