From c45e3fac421bdaaafec98a3a87273daf9e4798ce Mon Sep 17 00:00:00 2001 From: Martin <15851033+malvex@users.noreply.github.com> Date: Tue, 10 Feb 2026 20:27:03 +0100 Subject: [PATCH 1/4] feat: implement workflow support --- docker-compose.test.yml | 3 +- examples/workflows/fan_in_fan_out.py | 58 +++++++ examples/workflows/simple.py | 33 ++++ src/sheppy/__init__.py | 2 + src/sheppy/_localkv/client.py | 4 + src/sheppy/_localkv/server.py | 19 +- src/sheppy/_utils/functions.py | 16 ++ src/sheppy/_workflow.py | 249 +++++++++++++++++++++++++++ src/sheppy/backend/base.py | 24 +++ src/sheppy/backend/local.py | 82 +++++++++ src/sheppy/backend/memory.py | 60 +++++++ src/sheppy/backend/redis.py | 229 ++++++++++++++++++------ src/sheppy/models.py | 5 + src/sheppy/queue.py | 92 ++++++++++ src/sheppy/task_factory.py | 30 ++-- src/sheppy/testqueue.py | 40 +++++ src/sheppy/worker.py | 35 ++++ tests/contract/test_backend.py | 107 ++++++++++++ 18 files changed, 1018 insertions(+), 70 deletions(-) create mode 100644 examples/workflows/fan_in_fan_out.py create mode 100644 examples/workflows/simple.py create mode 100644 src/sheppy/_workflow.py diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 050e633..999b7d9 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -1,6 +1,7 @@ services: redis: - image: redis:6.2-bookworm + #image: redis:6.2-bookworm + image: redis:7.4-bookworm ports: - "6379:6379" volumes: diff --git a/examples/workflows/fan_in_fan_out.py b/examples/workflows/fan_in_fan_out.py new file mode 100644 index 0000000..7e1a111 --- /dev/null +++ b/examples/workflows/fan_in_fan_out.py @@ -0,0 +1,58 @@ +import asyncio + +from sheppy import Queue, RedisBackend, task +from sheppy._workflow import workflow + +ADMIN_EMAILS = [ + "admin1@example.com", + "admin2@example.com", + "admin3@example.com", +] + +@task +async def cleanup_old_data(days: int = 7): + if days > 7: # deterministic "random" failure + raise Exception("some random failure happened") + + return "everything ok!" + +@task +async def some_cleanup_at_the_end(): + return True + +@task +async def rollback_changes(): + return True + +@task +async def send_notification(to: str, subject: str): + print(f"Sending email to {to}, subject {subject}") + + +@workflow +def daily_cleanup(days_to_clean: int): + result_task = yield cleanup_old_data(days=days_to_clean) + + if result_task.error: + yield rollback_changes() + yield [send_notification(email, "Oh no, daily cleanup failed!") for email in ADMIN_EMAILS] + + raise Exception("Cleanup failed, notifications were sent") # fail the workflow + + if result_task.status == 'completed': + yield some_cleanup_at_the_end() + yield send_notification("devteam@example.com", "Cleanup finished") + + return "Daily cleanup finished successfully" + + raise Exception("not sure what happened!") + + +async def main(): + queue = Queue(RedisBackend()) + # await queue.add_workflow(daily_cleanup(7)) + await queue.add_workflow(daily_cleanup(30)) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/workflows/simple.py b/examples/workflows/simple.py new file mode 100644 index 0000000..9e4f2e8 --- /dev/null +++ b/examples/workflows/simple.py @@ -0,0 +1,33 @@ +import asyncio + +from sheppy import Queue, task +from sheppy._workflow import workflow + +ADMIN_EMAILS = [ + "admin1@example.com", + "admin2@example.com", + "admin3@example.com", +] + +@task +async def say_hello(name: str) -> str: + return f"Hello, {name}!" 
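# A workflow body is a plain generator: each `yield` enqueues a Task (or a
# list of Tasks, as in the fan-out below) and, once a worker has finished
# them, the completed Task objects are sent back in - which is why fields
# like `t1.result` are readable on the next line. Illustrative replay sketch
# (assumes a worker is consuming the queue):
#
#   wf = example_workflow(["Alex", "John"])  # builds a frozen Workflow; runs nothing yet
#   await queue.add_workflow(wf)             # replays up to the first yield, enqueues say_hello("Alice")
#   stored = await queue.get_workflow(wf)    # later: inspect stored.completed / stored.final_result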
+ + +@workflow +def example_workflow(names: list[str]): + t1 = yield say_hello("Alice") + t2 = yield say_hello("Bob") + tx = yield [say_hello(name) for name in names] # fan-out style + + return "\n".join([t1.result, t2.result] + [t.result for t in tx]) + + +async def main(): + queue = Queue("redis://") + wf = example_workflow(["Alex", "John"]) + await queue.add_workflow(wf) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/sheppy/__init__.py b/src/sheppy/__init__.py index f796f02..3adc91b 100644 --- a/src/sheppy/__init__.py +++ b/src/sheppy/__init__.py @@ -1,4 +1,6 @@ from ._utils.fastapi import Depends as Depends +from ._workflow import Workflow as Workflow +from ._workflow import workflow as workflow from .backend import Backend as Backend from .backend import BackendError as BackendError from .backend import LocalBackend as LocalBackend diff --git a/src/sheppy/_localkv/client.py b/src/sheppy/_localkv/client.py index a067245..92ecc14 100644 --- a/src/sheppy/_localkv/client.py +++ b/src/sheppy/_localkv/client.py @@ -102,6 +102,10 @@ async def list_len(self, key: str) -> int: r = await self._call("list_len", key=key) return r["count"] # type:ignore[no-any-return] + async def list_remove(self, key: str, value: str) -> bool: + r = await self._call("list_remove", key=key, value=value) + return r["removed"] # type:ignore[no-any-return] + # sorted list async def sorted_push(self, key: str, position: float, value: str) -> None: await self._call("sorted_push", key=key, position=position, value=value) diff --git a/src/sheppy/_localkv/server.py b/src/sheppy/_localkv/server.py index 0d00668..f5c6b47 100644 --- a/src/sheppy/_localkv/server.py +++ b/src/sheppy/_localkv/server.py @@ -66,7 +66,16 @@ def handle_command(cmd: str, args: dict[str, Any]) -> dict[str, Any]: return {"ok": True, "created": True} case "delete": - count = sum(1 for k in args["keys"] if store.kv.pop(k, None) is not None) + count = 0 + for k in args["keys"]: + if store.kv.pop(k, None) is not None: + count += 1 + if k in store.lists: + del store.lists[k] + count += 1 + if k in store.sorted_list: + del store.sorted_list[k] + count += 1 return {"ok": True, "count": count} case "keys": @@ -98,6 +107,14 @@ def handle_command(cmd: str, args: dict[str, Any]) -> dict[str, Any]: case "list_len": return {"ok": True, "count": len(store.lists[args["key"]])} + case "list_remove": + lst = store.lists[args["key"]] + value = args["value"] + if value in lst: + lst.remove(value) + return {"ok": True, "removed": True} + return {"ok": True, "removed": False} + # sorted list case "sorted_push": bisect.insort(store.sorted_list[args["key"]], SortedItem(args["position"], args["value"])) diff --git a/src/sheppy/_utils/functions.py b/src/sheppy/_utils/functions.py index e199431..61671a4 100644 --- a/src/sheppy/_utils/functions.py +++ b/src/sheppy/_utils/functions.py @@ -13,6 +13,22 @@ cache_main_module: str | None = None +def stringify_function(func: Callable[..., Any]) -> str: + _module = func.__module__ + # special case if the task is in the main python file that is executed + if _module == "__main__": + global cache_main_module + if not cache_main_module: + # this handles "python -m app.main" because with "-m" sys.argv[0] is absolute path + _main_path = os.path.relpath(sys.argv[0])[:-3] + # replace handles situations when user runs "python app/main.py" + cache_main_module = _main_path.replace(os.sep, ".") + + _module = cache_main_module + + return f"{_module}:{func.__name__}" + + def resolve_function(func: str, wrapped: bool = True) 
-> Callable[..., Any]: module_name = None function_name = None diff --git a/src/sheppy/_workflow.py b/src/sheppy/_workflow.py new file mode 100644 index 0000000..ec3c2d1 --- /dev/null +++ b/src/sheppy/_workflow.py @@ -0,0 +1,249 @@ +from collections.abc import Callable, Generator +from contextvars import ContextVar +from dataclasses import dataclass, field +from datetime import datetime, timezone +from functools import wraps +from typing import Any, NamedTuple, ParamSpec, TypeVar, overload +from uuid import UUID, uuid4 + +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field + +from ._utils.functions import resolve_function, stringify_function +from .models import Task + +P = ParamSpec('P') +R = TypeVar('R') + + +@dataclass +class WorkflowContext: + workflow_id: UUID + step_counter: int = 0 + stored_tasks: dict[int, UUID] = field(default_factory=dict) + task_order: list[str] = field(default_factory=list) + + def next_task_id(self) -> UUID: + step_idx = self.step_counter + self.step_counter += 1 + + if step_idx in self.stored_tasks: + return self.stored_tasks[step_idx] + + task_id = uuid4() + self.task_order.append(str(task_id)) + + return task_id + + +_workflow_context: ContextVar[WorkflowContext | None] = ContextVar('workflow_context', default=None) + + +def get_workflow_context() -> WorkflowContext | None: + return _workflow_context.get() + + +def set_workflow_context(ctx: WorkflowContext | None) -> None: + _workflow_context.set(ctx) + + +class Workflow(BaseModel): + model_config = ConfigDict(frozen=True, extra="forbid") + + id: UUID = Field(default_factory=uuid4) + func: str + args: tuple[Any, ...] = Field(default_factory=tuple) + kwargs: dict[str, Any] = Field(default_factory=dict) + task_order: list[str] = Field(default_factory=list) + processed_tasks: dict[str, Task] = Field(default_factory=dict) + pending_task_ids: list[str] = Field(default_factory=list) + completed: bool = False + final_result: Any = None + error: str | None = None + created_at: AwareDatetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + finished_at: AwareDatetime | None = None + + def get_stored_task_ids(self) -> dict[int, UUID]: + return {idx: UUID(tid) for idx, tid in enumerate(self.task_order)} + + def get_incomplete_task_ids(self) -> list[UUID]: + return [UUID(tid) for tid in self.pending_task_ids] + + def __repr__(self) -> str: + return ( + f"Workflow(id={self.id!r}, func={self.func!r}, " + f"completed={self.completed}, tasks={len(self.processed_tasks)})" + ) + + +class WorkflowResult(NamedTuple): + workflow: Workflow + pending_tasks: list[Task] + + +@overload +def workflow( + func: Callable[P, Generator[Task | list[Task], Task | list[Task], R]], / +) -> Callable[P, Workflow]: + ... + + +@overload +def workflow() -> Callable[ + [Callable[P, Generator[Task | list[Task], Task | list[Task], R]]], + Callable[P, Workflow] +]: + ... 
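# Both decorator spellings are accepted thanks to the overloads above - a
# short sketch of the two equivalent forms:
#
#   @workflow
#   def wf_a(): ...
#
#   @workflow()
#   def wf_b(): ...
#
# In either case, calling the decorated function does not execute the
# generator; it returns a frozen Workflow capturing func/args/kwargs for
# WorkflowRunner to replay later.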
+ + +def workflow( + func: Callable[P, Generator[Task | list[Task], Task | list[Task], R]] | None = None, +) -> ( + Callable[P, Workflow] | + Callable[ + [Callable[P, Generator[Task | list[Task], Task | list[Task], R]]], + Callable[P, Workflow] + ] +): + def decorator( + fn: Callable[P, Generator[Task | list[Task], Task | list[Task], R]] + ) -> Callable[P, Workflow]: + func_string = stringify_function(fn) + + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Workflow: + return Workflow( + func=func_string, + args=args, + kwargs=kwargs, + ) + + return wrapper + + if func is not None: + return decorator(func) + + return decorator + + +class WorkflowRunner: + def __init__(self, workflow: Workflow, task_results: dict[UUID, Task] | None = None): + self.workflow = workflow + self.task_results = task_results or {} + + def run(self) -> WorkflowResult: + ctx = WorkflowContext( + workflow_id=self.workflow.id, + stored_tasks=self.workflow.get_stored_task_ids(), + task_order=list(self.workflow.task_order), + ) + _workflow_context.set(ctx) + + try: + gen = resolve_function(self.workflow.func)(*self.workflow.args, **self.workflow.kwargs) + + try: + result, pending = self._run_generator(gen) + except StopIteration as e: + result, pending = e.value, [] + + processed_tasks = dict(self.workflow.processed_tasks) + for task_id, task in self.task_results.items(): + if task.status in ['completed', 'failed']: + processed_tasks[str(task_id)] = task + + if pending: + return WorkflowResult( + workflow=self._make_workflow(ctx, processed_tasks, pending_task_ids=[str(t.id) for t in pending]), + pending_tasks=pending, + ) + + return WorkflowResult( + workflow=self._make_workflow(ctx, processed_tasks, completed=True, final_result=result), + pending_tasks=[], + ) + + except Exception as e: + return WorkflowResult( + workflow=self._make_workflow(ctx, error=f"{type(e).__name__}: {e}"), + pending_tasks=[], + ) + finally: + _workflow_context.set(None) + + def _run_generator(self, gen: Generator[Any, Any, Any]) -> tuple[Any, list[Task]]: + result_to_send: Any = None + + while True: + yielded = gen.send(result_to_send) + result_to_send, pending = self._handle_yield(yielded) + + if pending: + return None, pending + + def _handle_yield(self, yielded: Any) -> tuple[Any, list[Task]]: + items, is_list = (yielded, True) if isinstance(yielded, list) else ([yielded], False) + + if items and isinstance(items[0], Workflow): + results = [] + for wf in items: + result, pending = self._run_nested_workflow(wf) + if pending: + return None, pending + results.append(result) + return (results if is_list else results[0]), [] + + results = [] + pending = [] + for task in items: + result, is_pending = self._process_task(task) + results.append(result) + if is_pending: + pending.append(task) + + if pending: + return None, pending + return (results if is_list else results[0]), [] + + def _run_nested_workflow(self, wf: Workflow) -> tuple[Any, list[Task]]: + gen = resolve_function(wf.func)(*wf.args, **wf.kwargs) + try: + return self._run_generator(gen) + except StopIteration as e: + return e.value, [] + + def _process_task(self, task: Task) -> tuple[Task, bool]: + task_id = str(task.id) + cached = self.workflow.processed_tasks.get(task_id) + + if cached and (cached.status == 'completed' or cached.error): + return cached, False + + if task.id in self.task_results: + return self.task_results[task.id], False + + return task, True + + def _make_workflow( + self, + ctx: WorkflowContext, + processed_tasks: dict[str, Task] | None = None, + 
pending_task_ids: list[str] | None = None, + completed: bool = False, + final_result: Any = None, + error: str | None = None, + ) -> Workflow: + data = self.workflow.model_dump() + data['task_order'] = ctx.task_order + data['pending_task_ids'] = pending_task_ids or [] + data['completed'] = completed + + if processed_tasks is not None: + data['processed_tasks'] = {k: v.model_dump() for k, v in processed_tasks.items()} + if final_result is not None: + data['final_result'] = final_result + if error is not None: + data['error'] = error + if completed or error: + data['finished_at'] = datetime.now(timezone.utc) + + return Workflow.model_validate(data) diff --git a/src/sheppy/backend/base.py b/src/sheppy/backend/base.py index fae8371..9330234 100644 --- a/src/sheppy/backend/base.py +++ b/src/sheppy/backend/base.py @@ -89,3 +89,27 @@ async def delete_cron(self, queue_name: str, deterministic_id: str) -> bool: @abstractmethod async def get_crons(self, queue_name: str) -> list[dict[str, Any]]: pass + + @abstractmethod + async def store_workflow(self, queue_name: str, workflow_data: dict[str, Any]) -> bool: + pass + + @abstractmethod + async def get_workflows(self, queue_name: str, workflow_ids: list[str]) -> dict[str, dict[str, Any]]: + pass + + @abstractmethod + async def get_all_workflows(self, queue_name: str) -> list[dict[str, Any]]: + pass + + @abstractmethod + async def get_pending_workflows(self, queue_name: str) -> list[dict[str, Any]]: + pass + + @abstractmethod + async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool: + pass + + @abstractmethod + async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int: + pass diff --git a/src/sheppy/backend/local.py b/src/sheppy/backend/local.py index 3db5ab0..d94b85f 100644 --- a/src/sheppy/backend/local.py +++ b/src/sheppy/backend/local.py @@ -82,6 +82,15 @@ def _cron_prefix(self, queue_name: str) -> str: def _queue_prefix(self, queue_name: str) -> str: return f"sheppy:{queue_name}:" + def _workflow_key(self, queue_name: str, workflow_id: str) -> str: + return f"sheppy:{queue_name}:workflow:{workflow_id}" + + def _workflow_prefix(self, queue_name: str) -> str: + return f"sheppy:{queue_name}:workflow:" + + def _workflow_pending_key(self, queue_name: str, workflow_id: str) -> str: + return f"sheppy:{queue_name}:workflow:{workflow_id}:pending" + async def append(self, queue_name: str, tasks: list[dict[str, Any]], unique: bool = True) -> list[bool]: success = [] @@ -292,3 +301,76 @@ async def get_crons(self, queue_name: str) -> list[dict[str, Any]]: values = await self.client.get(cron_keys) return [json.loads(v) for v in values.values() if v] + + async def store_workflow(self, queue_name: str, workflow_data: dict[str, Any]) -> bool: + workflow_id = workflow_data["id"] + key = self._workflow_key(queue_name, workflow_id) + pending_key = self._workflow_pending_key(queue_name, workflow_id) + + await self.client.set({key: json.dumps(workflow_data)}) + + await self.client.delete([pending_key]) + pending_ids = workflow_data.get("pending_task_ids", []) + for task_id in pending_ids: + await self.client.list_push(pending_key, task_id) + + return True + + async def get_workflows(self, queue_name: str, workflow_ids: list[str]) -> dict[str, dict[str, Any]]: + if not workflow_ids: + return {} + + keys = [self._workflow_key(queue_name, wf_id) for wf_id in workflow_ids] + values = await self.client.get(keys) + + result = {} + for wf_json in values.values(): + if wf_json: + wf = json.loads(wf_json) + 
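                    # re-key by the workflow id embedded in the payload;
                    # client.get() returns values keyed by the full storage key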
result[wf["id"]] = wf + return result + + async def get_all_workflows(self, queue_name: str) -> list[dict[str, Any]]: + workflow_keys = await self.client.keys(self._workflow_prefix(queue_name)) + if not workflow_keys: + return [] + + workflow_keys = [k for k in workflow_keys if ":pending" not in k] + if not workflow_keys: + return [] + + values = await self.client.get(workflow_keys) + return [json.loads(v) for v in values.values() if v] + + async def get_pending_workflows(self, queue_name: str) -> list[dict[str, Any]]: + all_workflows = await self.get_all_workflows(queue_name) + return [wf for wf in all_workflows if not wf.get("completed") and not wf.get("error")] + + async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool: + key = self._workflow_key(queue_name, workflow_id) + pending_key = self._workflow_pending_key(queue_name, workflow_id) + count = await self.client.delete([key, pending_key]) + return count > 0 + + async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int: + pending_key = self._workflow_pending_key(queue_name, workflow_id) + + pending_ids = await self.client.list_get(pending_key) + + if task_id not in pending_ids: + return -1 + + await self.client.list_remove(pending_key, task_id) + + remaining = await self.client.list_len(pending_key) + + workflow_key = self._workflow_key(queue_name, workflow_id) + workflow_values = await self.client.get([workflow_key]) + if workflow_key in workflow_values and workflow_values[workflow_key] is not None: + mypy_pls = workflow_values[workflow_key] + assert mypy_pls is not None + workflow = json.loads(mypy_pls) + workflow["pending_task_ids"] = [tid for tid in workflow.get("pending_task_ids", []) if tid != task_id] + await self.client.set({workflow_key: json.dumps(workflow)}) + + return remaining diff --git a/src/sheppy/backend/memory.py b/src/sheppy/backend/memory.py index afebb8e..252c88b 100644 --- a/src/sheppy/backend/memory.py +++ b/src/sheppy/backend/memory.py @@ -30,6 +30,7 @@ def __init__(self, self._pending: dict[str, list[str]] = defaultdict(list) self._scheduled: dict[str, list[ScheduledTask]] = defaultdict(list) self._crons: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) + self._workflows: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) # {QUEUE_NAME: {WORKFLOW_ID: workflow_data}} self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) # for thread-safety self._connected = False @@ -307,6 +308,65 @@ async def get_crons(self, queue_name: str) -> list[dict[str, Any]]: async with self._locks[queue_name]: return list(self._crons[queue_name].values()) + async def store_workflow(self, queue_name: str, workflow_data: dict[str, Any]) -> bool: + self._check_connected() + + async with self._locks[queue_name]: + self._workflows[queue_name][workflow_data["id"]] = workflow_data + return True + + async def get_workflows(self, queue_name: str, workflow_ids: list[str]) -> dict[str, dict[str, Any]]: + self._check_connected() + + async with self._locks[queue_name]: + results = {} + for wf_id in workflow_ids: + result = self._workflows[queue_name].get(wf_id) + if result: + results[wf_id] = result + return results + + async def get_all_workflows(self, queue_name: str) -> list[dict[str, Any]]: + self._check_connected() + + async with self._locks[queue_name]: + return list(self._workflows[queue_name].values()) + + async def get_pending_workflows(self, queue_name: str) -> list[dict[str, Any]]: + self._check_connected() + + async with self._locks[queue_name]: + 
return [ + wf for wf in self._workflows[queue_name].values() + if not wf.get("completed") and not wf.get("error") + ] + + async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool: + self._check_connected() + + async with self._locks[queue_name]: + if workflow_id in self._workflows[queue_name]: + del self._workflows[queue_name][workflow_id] + return True + return False + + async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int: + self._check_connected() + + async with self._locks[queue_name]: + workflow = self._workflows[queue_name].get(workflow_id) + if not workflow: + return -1 + + pending_ids = workflow.get("pending_task_ids", []) + if task_id not in pending_ids: + return -1 + + pending_ids = [tid for tid in pending_ids if tid != task_id] + workflow["pending_task_ids"] = pending_ids + + return len(pending_ids) + def _check_connected(self) -> None: if not self.is_connected: raise BackendError("Not connected") diff --git a/src/sheppy/backend/redis.py b/src/sheppy/backend/redis.py index 93ab3e8..464d720 100644 --- a/src/sheppy/backend/redis.py +++ b/src/sheppy/backend/redis.py @@ -23,7 +23,7 @@ def __init__( self, url: str = "redis://127.0.0.1:6379", consumer_group: str = "workers", - ttl: int | None = 24 * 60 * 60, # 24 hours + ttl: int | None = 30 * 24 * 60 * 60, # 30 days **kwargs: Any ): self.url = url @@ -75,9 +75,25 @@ def _scheduled_tasks_key(self, queue_name: str) -> str: return f"sheppy:scheduled:{queue_name}" def _cron_tasks_key(self, queue_name: str) -> str: - """Cron tasks (key prefix)""" + """Cron tasks (hset)""" return f"sheppy:cron:{queue_name}" + def _workflows_key(self, queue_name: str) -> str: + """Workflow metadata (hset)""" + return f"sheppy:workflows:{queue_name}" + + def _workflow_pending_key(self, queue_name: str, workflow_id: str) -> str: + """Pending task IDs for a workflow (set)""" + return f"sheppy:workflows:{queue_name}:{workflow_id}:pending" + + def _workflow_pending_index_key(self, queue_name: str) -> str: + """Index of pending workflow IDs (set)""" + return f"sheppy:workflows:{queue_name}:_pending" + + def _queues_registry_key(self) -> str: + """Registry of all queue names and their metadata (hset)""" + return "sheppy:_queues" + def _pending_tasks_key(self, queue_name: str) -> str: """Queued tasks to be processed (stream)""" return f"sheppy:pending:{queue_name}" @@ -86,6 +102,10 @@ def _finished_tasks_key(self, queue_name: str) -> str: """Notifications about finished tasks (stream)""" return f"sheppy:finished:{queue_name}" + def _completed_counter_key(self, queue_name: str) -> str: + """Cumulative count of completed tasks (string/integer)""" + return f"sheppy:completed:{queue_name}" + def _worker_metadata_key(self, queue_name: str) -> str: """Worker Metadata (key prefix)""" return f"sheppy:workers:{queue_name}" @@ -125,6 +145,8 @@ async def append(self, queue_name: str, tasks: list[dict[str, Any]], unique: boo try: async with self.client.pipeline(transaction=False) as pipe: + pipe.hsetnx(self._queues_registry_key(), queue_name, "{}") + for t in to_queue: _task_data = json.dumps(t) @@ -199,17 +221,16 @@ async def clear(self, queue_name: str) -> int: await self._ensure_consumer_group(pending_tasks_key) - keys = await self.client.keys(f"{tasks_metadata_key}:*") - if not keys: - return 0 - - count = await self.client.delete(*keys) + count = 0 + async for key in self.client.scan_iter(match=f"{tasks_metadata_key}:*", count=10000): + await self.client.delete(key) + count += 1 await 
self.client.xtrim(pending_tasks_key, maxlen=0) await self.client.delete(scheduled_key) - # await self.client.delete(tasks_metadata_key) + await self.client.hdel(self._queues_registry_key(), queue_name) # type: ignore[misc] - return int(count) + return count async def get_tasks(self, queue_name: str, task_ids: list[str]) -> dict[str,dict[str, Any]]: tasks_metadata_key = self._tasks_metadata_key(queue_name) @@ -232,14 +253,14 @@ async def schedule(self, queue_name: str, task_data: dict[str, Any], at: datetim return False try: - _task_data = json.dumps(task_data) - if not unique: - await self.client.set(f"{tasks_metadata_key}:{task_data['id']}", _task_data) + await self.client.set(f"{tasks_metadata_key}:{task_data['id']}", json.dumps(task_data)) - # add to sorted set with timestamp as score score = at.timestamp() - await self.client.zadd(scheduled_key, {_task_data: score}) + async with self.client.pipeline(transaction=False) as pipe: + pipe.zadd(scheduled_key, {task_data['id']: score}) + pipe.hsetnx(self._queues_registry_key(), queue_name, "{}") + await pipe.execute() return True except Exception as e: @@ -247,22 +268,28 @@ async def schedule(self, queue_name: str, task_data: dict[str, Any], at: datetim async def pop_scheduled(self, queue_name: str, now: datetime | None = None) -> list[dict[str, Any]]: scheduled_key = self._scheduled_tasks_key(queue_name) + tasks_metadata_key = self._tasks_metadata_key(queue_name) score = now.timestamp() if now else time() - task_jsons = await self.client.zrangebyscore(scheduled_key, 0, score) + task_id_entries = await self.client.zrangebyscore(scheduled_key, 0, score) - tasks = [] - for task_json in task_jsons: - removed = await self.client.zrem(scheduled_key, task_json) + claimed_ids = [] + for entry in task_id_entries: + removed = await self.client.zrem(scheduled_key, entry) if removed <= 0: # some other worker already got this task at the same time, skip continue - tasks.append(json.loads(task_json)) + task_id = entry.decode() if isinstance(entry, bytes) else entry + claimed_ids.append(task_id) - return tasks + if not claimed_ids: + return [] + + task_jsons = await self.client.mget([f"{tasks_metadata_key}:{tid}" for tid in claimed_ids]) + return [json.loads(tj) for tj in task_jsons if tj] async def store_result(self, queue_name: str, task_data: dict[str, Any]) -> bool: tasks_metadata_key = self._tasks_metadata_key(queue_name) @@ -288,13 +315,15 @@ async def store_result(self, queue_name: str, task_data: dict[str, Any]) -> bool # add to finished stream for get_result notifications if task_data["finished_at"] is not None: # only send notification on finished task (for retriable tasks we continue to wait) pipe.xadd(finished_tasks_key, {"task_id": task_data["id"]}, minid=min_id) + pipe.incr(self._completed_counter_key(queue_name)) # ack and delete the task from the stream (cleanup) if message_id: pipe.xack(pending_tasks_key, self.consumer_group, message_id) pipe.xdel(pending_tasks_key, message_id) - await (pipe.execute()) + await pipe.execute() + self._pending_messages.pop(task_data["id"], None) return True except Exception as e: raise BackendError(f"Failed to store task result: {e}") from e @@ -302,27 +331,28 @@ async def store_result(self, queue_name: str, task_data: dict[str, Any]) -> bool async def get_stats(self, queue_name: str) -> dict[str, int]: scheduled_tasks_key = self._scheduled_tasks_key(queue_name) pending_tasks_key = self._pending_tasks_key(queue_name) - finished_tasks_key = self._finished_tasks_key(queue_name) pending = await 
self.client.xlen(pending_tasks_key) - completed = await self.client.xlen(finished_tasks_key) + completed = await self.client.get(self._completed_counter_key(queue_name)) return { "pending": pending, - "completed": completed, + "completed": int(completed) if completed else 0, "scheduled": await self.client.zcard(scheduled_tasks_key), } async def get_all_tasks(self, queue_name: str) -> list[dict[str, Any]]: tasks_metadata_key = self._tasks_metadata_key(queue_name) - keys = await self.client.keys(f"{tasks_metadata_key}:*") + keys = [] + async for key in self.client.scan_iter(match=f"{tasks_metadata_key}:*", count=10000): + keys.append(key) + if not keys: return [] all_tasks_data = await self.client.mget(keys) - - return [json.loads(task_json) for task_json in all_tasks_data] + return [json.loads(task_json) for task_json in all_tasks_data if task_json] async def get_results(self, queue_name: str, task_ids: list[str], timeout: float | None = None) -> dict[str,dict[str, Any]]: tasks_metadata_key = self._tasks_metadata_key(queue_name) @@ -356,11 +386,11 @@ async def get_results(self, queue_name: str, task_ids: list[str], timeout: float return results # endless wait if timeout == 0 - deadline = None if timeout == 0 else asyncio.get_event_loop().time() + timeout + deadline = None if timeout == 0 else asyncio.get_running_loop().time() + timeout while True: if deadline: - remaining = deadline - asyncio.get_event_loop().time() + remaining = deadline - asyncio.get_running_loop().time() if remaining <= 0: raise TimeoutError(f"Did not complete within {timeout} seconds") else: @@ -369,7 +399,7 @@ async def get_results(self, queue_name: str, task_ids: list[str], timeout: float messages = await self.client.xread( {finished_tasks_key: last_id}, block=int(remaining * 1000), - count=100 + count=1000 ) if not messages: @@ -407,14 +437,11 @@ async def _ensure_consumer_group(self, stream_key: str) -> None: pass async def list_queues(self) -> dict[str, int]: - - queue_names = set() - - for key in await self.client.keys("sheppy:*:*"): - queue_names.add(key.decode().split(":")[2]) + queue_names = await self.client.hkeys(self._queues_registry_key()) # type: ignore[misc] queues = {} - for queue_name in sorted(queue_names): + for raw_name in sorted(queue_names): + queue_name = raw_name.decode() if isinstance(raw_name, bytes) else raw_name try: pending_count = await self.client.xlen(self._pending_tasks_key(queue_name)) queues[queue_name] = int(pending_count) @@ -425,28 +452,130 @@ async def list_queues(self) -> dict[str, int]: async def get_scheduled(self, queue_name: str) -> list[dict[str, Any]]: scheduled_key = self._scheduled_tasks_key(queue_name) + tasks_metadata_key = self._tasks_metadata_key(queue_name) - task_jsons = await self.client.zrange(scheduled_key, 0, -1, withscores=True) + task_ids = await self.client.zrange(scheduled_key, 0, -1) - tasks = [] - for task_json, _score in task_jsons: - tasks.append(json.loads(task_json)) + if not task_ids: + return [] - return tasks + keys = [ + f"{tasks_metadata_key}:{(tid.decode() if isinstance(tid, bytes) else tid)}" + for tid in task_ids + ] + task_jsons = await self.client.mget(keys) + return [json.loads(tj) for tj in task_jsons if tj] async def add_cron(self, queue_name: str, deterministic_id: str, task_cron: dict[str, Any]) -> bool: - cron_tasks_key = self._cron_tasks_key(queue_name) - return bool(await self.client.set(f"{cron_tasks_key}:{deterministic_id}", json.dumps(task_cron), nx=True)) + cron_key = self._cron_tasks_key(queue_name) + return bool(await 
self.client.hsetnx(cron_key, deterministic_id, json.dumps(task_cron))) # type: ignore[misc] async def delete_cron(self, queue_name: str, deterministic_id: str) -> bool: - cron_tasks_key = self._cron_tasks_key(queue_name) - return bool(await self.client.delete(f"{cron_tasks_key}:{deterministic_id}")) + cron_key = self._cron_tasks_key(queue_name) + return bool(await self.client.hdel(cron_key, deterministic_id)) # type: ignore[misc] async def get_crons(self, queue_name: str) -> list[dict[str, Any]]: - cron_tasks_key = self._cron_tasks_key(queue_name) - cron_tasks = await self.client.keys(f"{cron_tasks_key}:*") + cron_key = self._cron_tasks_key(queue_name) + cron_data = await self.client.hvals(cron_key) # type: ignore[misc] + return [json.loads(d) for d in cron_data] - if not cron_tasks: - return [] + async def store_workflow(self, queue_name: str, workflow_data: dict[str, Any]) -> bool: + workflows_key = self._workflows_key(queue_name) + workflow_id = workflow_data['id'] + pending_key = self._workflow_pending_key(queue_name, workflow_id) + pending_index_key = self._workflow_pending_index_key(queue_name) + pending_ids = workflow_data.get('pending_task_ids', []) + + try: + async with self.client.pipeline(transaction=True) as pipe: + pipe.hset(workflows_key, workflow_id, json.dumps(workflow_data)) + if self.ttl: + pipe.hexpire(workflows_key, self.ttl, workflow_id) + + if workflow_data.get('completed') or workflow_data.get('error'): + pipe.delete(pending_key) + pipe.srem(pending_index_key, workflow_id) + elif pending_ids: + pipe.sadd(pending_key, *pending_ids) + pipe.sadd(pending_index_key, workflow_id) + if self.ttl: + pipe.expire(pending_key, self.ttl) + + await pipe.execute() + return True + except Exception as e: + raise BackendError(f"Failed to store workflow: {e}") from e + + async def get_workflows(self, queue_name: str, workflow_ids: list[str]) -> dict[str, dict[str, Any]]: + workflows_key = self._workflows_key(queue_name) + + if not workflow_ids: + return {} + + try: + data = await self.client.hmget(workflows_key, workflow_ids) # type: ignore[misc] + result = {} + for wf_json in data: + if wf_json: + wf = json.loads(wf_json) + result[wf["id"]] = wf + return result + except Exception as e: + raise BackendError(f"Failed to get workflows: {e}") from e - return [json.loads(d) for d in await self.client.mget(cron_tasks) if d is not None] + async def get_all_workflows(self, queue_name: str) -> list[dict[str, Any]]: + workflows_key = self._workflows_key(queue_name) + + try: + all_data = await self.client.hvals(workflows_key) # type: ignore[misc] + return [json.loads(wf_json) for wf_json in all_data if wf_json] + except Exception as e: + raise BackendError(f"Failed to get all workflows: {e}") from e + + async def get_pending_workflows(self, queue_name: str) -> list[dict[str, Any]]: + workflows_key = self._workflows_key(queue_name) + pending_index_key = self._workflow_pending_index_key(queue_name) + + try: + workflow_ids = await self.client.smembers(pending_index_key) # type: ignore[misc] + if not workflow_ids: + return [] + + ids = [wid.decode() if isinstance(wid, bytes) else wid for wid in workflow_ids] + data = await self.client.hmget(workflows_key, ids) # type: ignore[misc] + return [json.loads(wf_json) for wf_json in data if wf_json] + except Exception as e: + raise BackendError(f"Failed to get pending workflows: {e}") from e + + async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool: + workflows_key = self._workflows_key(queue_name) + pending_key = 
self._workflow_pending_key(queue_name, workflow_id) + pending_index_key = self._workflow_pending_index_key(queue_name) + try: + async with self.client.pipeline(transaction=True) as pipe: + pipe.hdel(workflows_key, workflow_id) + pipe.delete(pending_key) + pipe.srem(pending_index_key, workflow_id) + results = await pipe.execute() + return int(results[0]) > 0 + except Exception as e: + raise BackendError(f"Failed to delete workflow: {e}") from e + + async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int: + pending_key = self._workflow_pending_key(queue_name, workflow_id) + + try: + async with self.client.pipeline() as pipe: + pipe.srem(pending_key, task_id) + pipe.scard(pending_key) + results = await pipe.execute() + + removed_count = results[0] # 1 if removed, 0 if not found + remaining_count = results[1] + + if removed_count == 0: + return -1 # task not in pending set + + return int(remaining_count) + except Exception as e: + raise BackendError(f"Failed to mark workflow task complete: {e}") from e diff --git a/src/sheppy/models.py b/src/sheppy/models.py index d2fd088..cae3d69 100644 --- a/src/sheppy/models.py +++ b/src/sheppy/models.py @@ -144,6 +144,7 @@ class Task(BaseModel): next_retry_at: Timestamp when the task is scheduled to be retried next. None if the task is not scheduled for retry. is_retriable: Returns True if the task is configured to be retriable. should_retry: Returns True if the task should be retried based on its retry configuration and current retry count. + workflow_id: ID of the workflow this task belongs to (if created within a workflow). Note: - You should not create Task instances directly. Instead, use the `@task` decorator to define a task function, and then call that function to create a Task instance. @@ -193,6 +194,10 @@ def add(x: int, y: int) -> int: """datetime|None: Timestamp when the task was last retried. None if the task has never been retried.""" next_retry_at: AwareDatetime | None = None """datetime|None: Timestamp when the task is scheduled to be retried next. 
None if the task is not scheduled for retry."""
+
+    workflow_id: UUID | None = None
+    """UUID|None: ID of the workflow this task belongs to (if created within a workflow)."""
+
     # caller: str | None = None
     # worker: str | None = None
diff --git a/src/sheppy/queue.py b/src/sheppy/queue.py
index 44f70d5..751f00e 100644
--- a/src/sheppy/queue.py
+++ b/src/sheppy/queue.py
@@ -4,6 +4,7 @@
 from uuid import UUID
 
 from ._config import config
+from ._workflow import Workflow, WorkflowResult, WorkflowRunner
 from .backend.base import Backend
 from .models import Task, TaskCron
 from .task_factory import TaskFactory
@@ -491,3 +492,94 @@ def _get_task_ids(self, task: list[Task | UUID | str] | Task | UUID | str) -> tu
             task_ids = list({str(t.id if isinstance(t, Task) else t) for t in task})
 
         return task_ids, batch_mode
+
+    async def add_workflow(self, workflow: Workflow) -> WorkflowResult:
+        await self.__ensure_backend_is_connected()
+
+        runner = WorkflowRunner(workflow)
+        result = runner.run()
+
+        await self.backend.store_workflow(self.name, result.workflow.model_dump(mode='json'))
+
+        if result.pending_tasks:
+            await self.add(result.pending_tasks)
+
+        return result
+
+    async def resume_workflow(self, workflow: Workflow | UUID | str, task_results: dict[UUID, Task] | None = None) -> WorkflowResult:
+        await self.__ensure_backend_is_connected()
+
+        if isinstance(workflow, (UUID, str)):
+            w_id = str(workflow)
+            wf_data = await self.backend.get_workflows(self.name, [w_id])
+            if w_id not in wf_data:
+                raise ValueError(f"Workflow not found: {workflow}")
+            workflow = Workflow.model_validate(wf_data[w_id])
+
+        if task_results is None:
+            incomplete_ids = workflow.get_incomplete_task_ids()
+            if incomplete_ids:
+                task_results_dict = await self.get_task(incomplete_ids)  # type: ignore
+                task_results = {
+                    tid: task for tid, task in task_results_dict.items()
+                    # skip tasks that errored but still have retries left
+                    if task.status == 'completed' or (task.error and not task.should_retry)
+                }
+            else:
+                task_results = {}
+
+        runner = WorkflowRunner(workflow, task_results=task_results)
+        result = runner.run()
+
+        await self.backend.store_workflow(self.name, result.workflow.model_dump(mode='json'))
+
+        if result.pending_tasks:
+            await self.add(result.pending_tasks)
+
+        return result
+
+    @overload
+    async def get_workflow(self, workflow: Workflow | UUID | str) -> Workflow | None: ...
+
+    @overload
+    async def get_workflow(self, workflow: list[Workflow | UUID | str]) -> dict[UUID, Workflow]: ...
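+    # mirrors get_task's overloads: a single workflow/id resolves to
+    # Workflow | None, a list resolves to a dict keyed by UUID, e.g.
+    #   wf = await queue.get_workflow(workflow_id)      # Workflow | None
+    #   wfs = await queue.get_workflow([id_a, id_b])    # dict[UUID, Workflow]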
+ + async def get_workflow(self, workflow: Workflow | UUID | str | list[Workflow | UUID | str]) -> dict[UUID, Workflow] | Workflow | None: + await self.__ensure_backend_is_connected() + + batch_mode = isinstance(workflow, list) + if not batch_mode: + workflow = [workflow] # type: ignore + + workflow_ids = [str(w.id if isinstance(w, Workflow) else w) for w in workflow] # type: ignore + + results = await self.backend.get_workflows(self.name, workflow_ids) + + if batch_mode: + return {UUID(wf_id): Workflow.model_validate(wf) for wf_id, wf in results.items()} + + if workflow_ids[0] in results: + return Workflow.model_validate(results[workflow_ids[0]]) + return None + + async def get_all_workflows(self) -> list[Workflow]: + await self.__ensure_backend_is_connected() + + workflows_data = await self.backend.get_all_workflows(self.name) + return [Workflow.model_validate(wf) for wf in workflows_data] + + async def get_pending_workflows(self) -> list[Workflow]: + await self.__ensure_backend_is_connected() + + workflows_data = await self.backend.get_pending_workflows(self.name) + return [Workflow.model_validate(wf) for wf in workflows_data] + + async def delete_workflow(self, workflow: Workflow | UUID | str) -> bool: + await self.__ensure_backend_is_connected() + + workflow_id = str(workflow.id if isinstance(workflow, Workflow) else workflow) + return await self.backend.delete_workflow(self.name, workflow_id) + + async def _mark_workflow_task_complete(self, workflow_id: UUID | str, task_id: UUID | str) -> int: + await self.__ensure_backend_is_connected() + + return await self.backend.mark_workflow_task_complete(self.name, str(workflow_id), str(task_id)) diff --git a/src/sheppy/task_factory.py b/src/sheppy/task_factory.py index da1a8f0..2705654 100644 --- a/src/sheppy/task_factory.py +++ b/src/sheppy/task_factory.py @@ -9,7 +9,9 @@ overload, ) +from ._utils.functions import stringify_function from ._utils.validation import validate_input +from ._workflow import get_workflow_context from .models import Task, TaskConfig, TaskCron, TaskSpec P = ParamSpec('P') @@ -24,22 +26,6 @@ class TaskFactory: def __init__(self) -> None: pass - @staticmethod - def _stringify_function(func: Callable[..., Any]) -> str: - _module = func.__module__ - # special case if the task is in the main python file that is executed - if _module == "__main__": - global cache_main_module - if not cache_main_module: - # this handles "python -m app.main" because with "-m" sys.argv[0] is absolute path - _main_path = os.path.relpath(sys.argv[0])[:-3] - # replace handles situations when user runs "python app/main.py" - cache_main_module = _main_path.replace(os.sep, ".") - - _module = cache_main_module - - return f"{_module}:{func.__name__}" - @staticmethod def create_task(func: Callable[..., Any], args: tuple[Any, ...], @@ -62,7 +48,7 @@ def create_task(func: Callable[..., Any], if retry_on_timeout is not None: task_config["retry_on_timeout"] = retry_on_timeout - func_string = TaskFactory._stringify_function(func) + func_string = stringify_function(func) args, kwargs = validate_input(func, tuple(args or ()), dict(kwargs or {})) @@ -70,9 +56,17 @@ def create_task(func: Callable[..., Any], if middleware: for m in middleware: # todo: should probably also validate them here - stringified_middlewares.append(TaskFactory._stringify_function(m)) + stringified_middlewares.append(stringify_function(m)) + + task_kwargs: dict[str, Any] = {} + + ctx = get_workflow_context() + if ctx is not None: + task_kwargs["id"] = ctx.next_task_id() + 
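+            # next_task_id() is deterministic on replay: it hands back the
+            # UUID recorded for this step index, so the re-created Task
+            # lines up with its stored result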
task_kwargs["workflow_id"] = ctx.workflow_id _task = Task( + **task_kwargs, spec=TaskSpec( func=func_string, args=args, diff --git a/src/sheppy/testqueue.py b/src/sheppy/testqueue.py index fd3b172..540ed44 100644 --- a/src/sheppy/testqueue.py +++ b/src/sheppy/testqueue.py @@ -5,6 +5,7 @@ from uuid import UUID from ._utils.task_execution import TaskProcessor +from ._workflow import Workflow, WorkflowResult from .backend.memory import MemoryBackend from .models import Task, TaskCron from .queue import Queue @@ -425,3 +426,42 @@ async def _process_task(self, task: Task) -> Task: return Task.model_validate(stored_task_data) return task + + def add_workflow(self, workflow: Workflow) -> WorkflowResult: + return asyncio.run(self._queue.add_workflow(workflow)) + + def resume_workflow(self, workflow: Workflow | UUID | str, task_results: dict[UUID, Task] | None = None) -> WorkflowResult: + return asyncio.run(self._queue.resume_workflow(workflow, task_results)) + + @overload + def get_workflow(self, workflow: Workflow | UUID | str) -> Workflow | None: ... + + @overload + def get_workflow(self, workflow: list[Workflow | UUID | str]) -> dict[UUID, Workflow]: ... + + def get_workflow(self, workflow: Workflow | UUID | str | list[Workflow | UUID | str]) -> Workflow | None | dict[UUID, Workflow]: + return asyncio.run(self._queue.get_workflow(workflow)) + + def get_all_workflows(self) -> list[Workflow]: + return asyncio.run(self._queue.get_all_workflows()) + + def get_pending_workflows(self) -> list[Workflow]: + return asyncio.run(self._queue.get_pending_workflows()) + + def delete_workflow(self, workflow: Workflow | UUID | str) -> bool: + return asyncio.run(self._queue.delete_workflow(workflow)) + + def process_workflow(self, workflow: Workflow) -> WorkflowResult: + result = self.add_workflow(workflow) + + while not result.workflow.completed and result.workflow.error is None: + if not result.pending_tasks: + break + + # Process all pending tasks + self.process_all() + + # Resume workflow + result = self.resume_workflow(result.workflow) + + return result diff --git a/src/sheppy/worker.py b/src/sheppy/worker.py index 44f47d6..6748af3 100644 --- a/src/sheppy/worker.py +++ b/src/sheppy/worker.py @@ -375,8 +375,43 @@ async def process_task(self, task: Task, queue: Queue) -> Task: if task.error and task.should_retry and task.next_retry_at is not None: await queue.retry(task, task.next_retry_at) + if (task.status == 'completed' or task.error) and task.workflow_id: + # workflow processing - resume workflow if this task belongs to one + await self._resume_workflow_for_task(queue, task) + return task + async def _resume_workflow_for_task(self, queue: Queue, task: Task) -> None: + if task.workflow_id is None: # for mypy + return + + try: + remaining = await queue._mark_workflow_task_complete(task.workflow_id, task.id) + + if remaining < 0: + # logger.debug(WORKER_PREFIX + f"Workflow {workflow_id} not found or already processed") + return + + if remaining > 0: + logger.debug(WORKER_PREFIX + f"Workflow {task.workflow_id} has {remaining} tasks remaining") + return + + logger.info(WORKER_PREFIX + f"All tasks complete for workflow {task.workflow_id}, resuming") + + result = await queue.resume_workflow(task.workflow_id) + + if result.workflow.completed: + logger.info(WORKER_PREFIX + f"Workflow {task.workflow_id} completed with result: {result.workflow.final_result}") + + elif result.workflow.error: + logger.error(WORKER_PREFIX + f"Workflow {task.workflow_id} failed: {result.workflow.error}") + + elif result.pending_tasks: 
+ logger.info(WORKER_PREFIX + f"Workflow {task.workflow_id} waiting for {len(result.pending_tasks)} more tasks") + + except Exception as e: + logger.exception(WORKER_PREFIX + f"Failed to resume workflow for task {task.id}: {e}") + def __register_signal_handlers(self, loop: asyncio.AbstractEventLoop) -> None: CTRL_C_THRESHOLD = 3 def signal_handler(sig: signal.Signals) -> None: diff --git a/tests/contract/test_backend.py b/tests/contract/test_backend.py index 9091697..c21577e 100644 --- a/tests/contract/test_backend.py +++ b/tests/contract/test_backend.py @@ -950,3 +950,110 @@ async def test_cron(backend: Backend): assert await backend.get_crons(Q) == [] assert await backend.get_crons("different-queue") == [] + + +async def test_store_and_get_workflow(backend: Backend): + wf = {"id": "wf-1", "pending_task_ids": ["t1", "t2"], "completed": False, "error": None} + + await backend.connect() + assert await backend.store_workflow(Q, wf) is True + + result = await backend.get_workflows(Q, ["wf-1"]) + assert result == {"wf-1": wf} + + assert await backend.get_workflows(Q, ["nonexistent"]) == {} + assert await backend.get_workflows("different-queue", ["wf-1"]) == {} + assert await backend.get_workflows(Q, []) == {} + + +async def test_get_all_workflows(backend: Backend): + wf1 = {"id": "wf-1", "pending_task_ids": ["t1"], "completed": False, "error": None} + wf2 = {"id": "wf-2", "pending_task_ids": ["t2"], "completed": False, "error": None} + wf3 = {"id": "wf-3", "pending_task_ids": [], "completed": True, "error": None} + + await backend.connect() + assert await backend.store_workflow(Q, wf1) is True + assert await backend.store_workflow(Q, wf2) is True + assert await backend.store_workflow(Q, wf3) is True + + workflows = await backend.get_all_workflows(Q) + assert len(workflows) == 3 + assert wf1 in workflows + assert wf2 in workflows + assert wf3 in workflows + + assert await backend.get_all_workflows("different-queue") == [] + + +async def test_get_pending_workflows(backend: Backend): + wf_pending = {"id": "wf-1", "pending_task_ids": ["t1"], "completed": False, "error": None} + wf_completed = {"id": "wf-2", "pending_task_ids": [], "completed": True, "error": None} + wf_error = {"id": "wf-3", "pending_task_ids": [], "completed": False, "error": "something broke"} + + await backend.connect() + assert await backend.store_workflow(Q, wf_pending) is True + assert await backend.store_workflow(Q, wf_completed) is True + assert await backend.store_workflow(Q, wf_error) is True + + pending = await backend.get_pending_workflows(Q) + assert len(pending) == 1 + assert pending[0] == wf_pending + + assert await backend.get_pending_workflows("different-queue") == [] + + +async def test_delete_workflow(backend: Backend): + wf = {"id": "wf-1", "pending_task_ids": ["t1"], "completed": False, "error": None} + + await backend.connect() + assert await backend.store_workflow(Q, wf) is True + + assert await backend.delete_workflow(Q, "wf-1") is True + assert await backend.delete_workflow(Q, "wf-1") is False + assert await backend.delete_workflow("different-queue", "wf-1") is False + + assert await backend.get_workflows(Q, ["wf-1"]) == {} + assert await backend.get_all_workflows(Q) == [] + + +async def test_mark_workflow_task_complete(backend: Backend): + wf = {"id": "wf-1", "pending_task_ids": ["t1", "t2", "t3"], "completed": False, "error": None} + + await backend.connect() + assert await backend.store_workflow(Q, wf) is True + + assert await backend.mark_workflow_task_complete(Q, "wf-1", "t1") == 2 + assert await 
backend.mark_workflow_task_complete(Q, "wf-1", "t2") == 1 + assert await backend.mark_workflow_task_complete(Q, "wf-1", "t3") == 0 + + # already removed + assert await backend.mark_workflow_task_complete(Q, "wf-1", "t1") == -1 + # nonexistent workflow + assert await backend.mark_workflow_task_complete(Q, "nonexistent", "t1") == -1 + + +async def test_store_workflow_overwrites(backend: Backend): + wf = {"id": "wf-1", "pending_task_ids": ["t1", "t2"], "completed": False, "error": None} + wf_updated = {"id": "wf-1", "pending_task_ids": ["t1"], "completed": False, "error": None} + + await backend.connect() + assert await backend.store_workflow(Q, wf) is True + assert await backend.store_workflow(Q, wf_updated) is True + + result = await backend.get_workflows(Q, ["wf-1"]) + assert result == {"wf-1": wf_updated} + + +async def test_get_workflows_batch(backend: Backend): + wf1 = {"id": "wf-1", "pending_task_ids": ["t1"], "completed": False, "error": None} + wf2 = {"id": "wf-2", "pending_task_ids": ["t2"], "completed": False, "error": None} + + await backend.connect() + assert await backend.store_workflow(Q, wf1) is True + assert await backend.store_workflow(Q, wf2) is True + + result = await backend.get_workflows(Q, ["wf-1", "wf-2"]) + assert result == {"wf-1": wf1, "wf-2": wf2} + + result = await backend.get_workflows(Q, ["wf-1", "nonexistent"]) + assert result == {"wf-1": wf1} From f8a2c703d2998812aec99055aa4b3b0d3da78600 Mon Sep 17 00:00:00 2001 From: Martin <15851033+malvex@users.noreply.github.com> Date: Sat, 14 Feb 2026 02:51:17 +0100 Subject: [PATCH 2/4] linting --- src/sheppy/task_factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sheppy/task_factory.py b/src/sheppy/task_factory.py index 2705654..e06d399 100644 --- a/src/sheppy/task_factory.py +++ b/src/sheppy/task_factory.py @@ -1,5 +1,3 @@ -import os -import sys from collections.abc import Callable from functools import wraps from typing import ( From 733db1d54ef0db49e99e06a2cdece06cf30002b0 Mon Sep 17 00:00:00 2001 From: Martin <15851033+malvex@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:49:32 +0100 Subject: [PATCH 3/4] fix --- docker-compose.test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 999b7d9..d07688f 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -1,7 +1,7 @@ services: redis: - #image: redis:6.2-bookworm - image: redis:7.4-bookworm + image: redis:6.2-bookworm + #image: redis:7.4-bookworm ports: - "6379:6379" volumes: From db37d6721f83086b565c5f495c400769d377edd5 Mon Sep 17 00:00:00 2001 From: Martin <15851033+malvex@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:53:16 +0100 Subject: [PATCH 4/4] fix --- .github/workflows/test.yml | 4 ++-- docker-compose.test.yml | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f159dbb..3009d64 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: services: redis: - image: redis:6 + image: redis:7.4-bookworm ports: - 6379:6379 options: >- @@ -47,7 +47,7 @@ jobs: services: redis: - image: redis:8 + image: redis:7.4-bookworm ports: - 6379:6379 options: >- diff --git a/docker-compose.test.yml b/docker-compose.test.yml index d07688f..497e3b3 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -1,7 +1,6 @@ services: redis: - image: redis:6.2-bookworm - #image: redis:7.4-bookworm + image: redis:7.4-bookworm ports: - "6379:6379" 
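      # redis 7.4+ is required here: RedisBackend stores workflow state with
      # HEXPIRE (per-field hash TTL), a command first available in Redis 7.4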
volumes: