diff --git a/examples/rate_limits.py b/examples/rate_limits.py
new file mode 100644
index 0000000..3805f83
--- /dev/null
+++ b/examples/rate_limits.py
@@ -0,0 +1,38 @@
+import asyncio
+from datetime import datetime
+from sheppy import Queue, task
+
+queue = Queue("redis://127.0.0.1:6379")
+
+
+@task(rate_limit={"max_rate": 2, "rate_period": 5})
+async def do_work(queued_time: datetime) -> int:
+    execution_time = datetime.now()
+    delay = (execution_time - queued_time).total_seconds()
+    return int(delay)
+
+
+async def main():
+    t1 = do_work(datetime.now())
+    t2 = do_work(datetime.now())
+    t3 = do_work(datetime.now())
+    await queue.add([t1, t2, t3])
+
+    # await the task completion
+    update_tasks = await queue.wait_for([t1, t2, t3])
+    t1 = update_tasks[t1.id]
+    t2 = update_tasks[t2.id]
+    t3 = update_tasks[t3.id]
+
+    assert all([t1.status == 'completed', t2.status == 'completed', t3.status == 'completed'])
+
+    # two tasks will be executed immediately and result will be 0
+    # one task will be executed after 5 seconds because of rate limit and returns 5
+    # the order is not guaranteed, any task might run first
+    print("t1.result:", t1.result)
+    print("t2.result:", t2.result)
+    print("t3.result:", t3.result)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/sheppy/backend/base.py b/src/sheppy/backend/base.py
index 9330234..40742a3 100644
--- a/src/sheppy/backend/base.py
+++ b/src/sheppy/backend/base.py
@@ -113,3 +113,7 @@ async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool:
     @abstractmethod
     async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int:
         pass
+
+    @abstractmethod
+    async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
+        pass
diff --git a/src/sheppy/backend/local.py b/src/sheppy/backend/local.py
index d94b85f..272a9dc 100644
--- a/src/sheppy/backend/local.py
+++ b/src/sheppy/backend/local.py
@@ -1,7 +1,9 @@
 import asyncio
 import contextlib
 import json
+from collections import defaultdict
 from datetime import datetime, timezone
+from time import time
 from typing import Any
 
 from .._localkv.client import KVClient
@@ -17,6 +19,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 17420, *, embedded: bool
         self._embedded = embedded
         self._server: asyncio.Server | None = None
         self._server_task: asyncio.Task[None] | None = None
+        self._rl_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
         self._client = KVClient(host, port)
 
     async def connect(self) -> None:
@@ -91,6 +94,9 @@ def _workflow_prefix(self, queue_name: str) -> str:
     def _workflow_pending_key(self, queue_name: str, workflow_id: str) -> str:
         return f"sheppy:{queue_name}:workflow:{workflow_id}:pending"
 
+    def _rate_limit_key(self, queue_name: str, key: str) -> str:
+        return f"sheppy:{queue_name}:ratelimit:{key}"
+
     async def append(self, queue_name: str, tasks: list[dict[str, Any]], unique: bool = True) -> list[bool]:
         success = []
 
@@ -352,6 +358,57 @@ async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool:
         count = await self.client.delete([key, pending_key])
         return count > 0
 
+    async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
+        if strategy == "fixed_window":
+            return await self._acquire_fixed_window(queue_name, key, max_rate, rate_period, task_id)
+
+        return await self._acquire_sliding_window(queue_name, key, max_rate, rate_period, task_id)
+
+    async def _acquire_sliding_window(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str) -> float | None:
+        rl_key = self._rate_limit_key(queue_name, key)
+
+        async with self._rl_locks[rl_key]:
+            now = time()
+            cutoff = now - rate_period
+
+            # prune expired entries
+            await self.client.sorted_pop(rl_key, cutoff)
+
+            count = await self.client.sorted_len(rl_key)
+
+            if count < max_rate:
+                await self.client.sorted_push(rl_key, now, task_id)
+                return None
+
+            # over limit - find oldest to calculate wait time
+            entries = await self.client.sorted_get(rl_key)
+            if entries:
+                oldest_score = entries[0][0]
+                return max(oldest_score + rate_period - now, 0.01)
+
+            return rate_period
+
+    async def _acquire_fixed_window(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str) -> float | None:
+        rl_key = self._rate_limit_key(queue_name, key)
+        fw_key = f"{rl_key}:fw"
+
+        async with self._rl_locks[rl_key]:
+            now = time()
+
+            entries = await self.client.sorted_get(fw_key)
+
+            # window expired - reset
+            if entries and now - entries[0][0] >= rate_period:
+                await self.client.sorted_pop(fw_key, float('inf'))
+                entries = []
+
+            if len(entries) < max_rate:
+                await self.client.sorted_push(fw_key, now, task_id)
+                return None
+
+            remaining = entries[0][0] + rate_period - now
+            return max(remaining, 0.01)
+
     async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int:
         pending_key = self._workflow_pending_key(queue_name, workflow_id)
 
diff --git a/src/sheppy/backend/memory.py b/src/sheppy/backend/memory.py
index 252c88b..da1e6f7 100644
--- a/src/sheppy/backend/memory.py
+++ b/src/sheppy/backend/memory.py
@@ -4,6 +4,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from time import time
 from typing import Any
 
 from .._utils.task_execution import TaskProcessor
@@ -32,6 +33,7 @@ def __init__(self,
         self._crons: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict)
         self._workflows: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict)  # {QUEUE_NAME: {WORKFLOW_ID: workflow_data}}
+        self._rate_limits: dict[str, list[float]] = defaultdict(list)
         self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)  # for thread-safety
 
         self._connected = False
 
@@ -147,6 +149,10 @@ async def clear(self, queue_name: str) -> int:
         self._scheduled[queue_name].clear()
         self._crons[queue_name].clear()
 
+        rl_keys = [k for k in self._rate_limits if k.startswith(f"{queue_name}:")]
+        for k in rl_keys:
+            del self._rate_limits[k]
+
         return queue_size + queue_cron_size
 
     async def get_tasks(self, queue_name: str, task_ids: list[str]) -> dict[str,dict[str, Any]]:
@@ -221,7 +227,7 @@ async def get_results(self, queue_name: str, task_ids: list[str], timeout: float
 
         while True:
             async with self._locks[queue_name]:
-                for task_id in task_ids:
+                for task_id in remaining_ids[:]:
                     task_data = self._task_metadata[queue_name].get(task_id, {})
 
                     if task_data.get("finished_at"):
@@ -367,6 +373,50 @@ async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, t
 
         return len(pending_ids)
 
+    async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
+        self._check_connected()
+
+        if strategy == "fixed_window":
+            return await self._acquire_fixed_window(queue_name, key, max_rate, rate_period)
+
+        return await self._acquire_sliding_window(queue_name, key, max_rate, rate_period)
+
+    async def _acquire_sliding_window(self, queue_name: str, key: str, max_rate: int, rate_period: float) -> float | None:
+        bucket_key = f"{queue_name}:{key}"
+
+        async with self._locks[bucket_key]:
+            now = time()
+            cutoff = now - rate_period
+            entries = self._rate_limits[bucket_key]
+
+            self._rate_limits[bucket_key] = [t for t in entries if t > cutoff]
+            entries = self._rate_limits[bucket_key]
+
+            if len(entries) < max_rate:
+                entries.append(now)
+                return None
+
+            oldest = min(entries)
+            return max(oldest + rate_period - now, 0.01)
+
+    async def _acquire_fixed_window(self, queue_name: str, key: str, max_rate: int, rate_period: float) -> float | None:
+        bucket_key = f"{queue_name}:{key}:fw"
+
+        async with self._locks[bucket_key]:
+            now = time()
+            entries = self._rate_limits[bucket_key]
+
+            # window expired - reset
+            if entries and now - entries[0] >= rate_period:
+                entries.clear()
+
+            if len(entries) < max_rate:
+                entries.append(now)
+                return None
+
+            remaining = entries[0] + rate_period - now
+            return max(remaining, 0.01)
+
     def _check_connected(self) -> None:
         if not self.is_connected:
             raise BackendError("Not connected")
diff --git a/src/sheppy/backend/redis.py b/src/sheppy/backend/redis.py
index 464d720..81dc7fd 100644
--- a/src/sheppy/backend/redis.py
+++ b/src/sheppy/backend/redis.py
@@ -110,6 +110,12 @@ def _worker_metadata_key(self, queue_name: str) -> str:
         """Worker Metadata (key prefix)"""
         return f"sheppy:workers:{queue_name}"
 
+    def _rate_limit_key(self, queue_name: str) -> str:
+        return f"sheppy:ratelimit:{queue_name}"
+
+    def _sliding_window_key(self, queue_name: str, key: str) -> str:
+        return f"sheppy:ratelimit:{queue_name}:sw:{key}"
+
     @property
     def client(self) -> redis.Redis:
         if not self._client:
@@ -229,6 +235,10 @@ async def clear(self, queue_name: str) -> int:
         await self.client.xtrim(pending_tasks_key, maxlen=0)
         await self.client.delete(scheduled_key)
         await self.client.hdel(self._queues_registry_key(), queue_name)  # type: ignore[misc]
+        await self.client.delete(self._rate_limit_key(queue_name))
+        sw_keys = [key async for key in self.client.scan_iter(match=self._sliding_window_key(queue_name, '*'), count=10000)]
+        if sw_keys:
+            await self.client.delete(*sw_keys)
 
         return count
 
@@ -423,6 +433,60 @@ async def get_results(self, queue_name: str, task_ids: list[str], timeout: float
                 if not remaining_ids:
                     return results
 
+    async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
+        if strategy == "fixed_window":
+            return await self._acquire_fixed_window(queue_name, key, max_rate, rate_period)
+
+        return await self._acquire_sliding_window(queue_name, key, max_rate, rate_period, task_id)
+
+    async def _acquire_fixed_window(self, queue_name: str, key: str, max_rate: int, rate_period: float) -> float | None:
+        rl_key = self._rate_limit_key(queue_name)
+        ttl_ms = int(rate_period * 1000)
+
+        async with self.client.pipeline(transaction=False) as pipe:
+            pipe.hincrby(rl_key, key, 1)
+            pipe.hpexpire(rl_key, ttl_ms, key, nx=True)
+            pipe.hpttl(rl_key, key)
+            results = await pipe.execute()
+
+        count = results[0]
+
+        if count <= max_rate:
+            return None
+
+        # over limit - undo increment, return remaining TTL
+        await self.client.hincrby(rl_key, key, -1)  # type: ignore[misc]
+        pttl_result = results[2]
+        remaining_ms = pttl_result[0] if pttl_result and pttl_result[0] > 0 else ttl_ms
+        return remaining_ms / 1000.0
+
+    async def _acquire_sliding_window(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str) -> float | None:
+        rl_key = self._sliding_window_key(queue_name, key)
+        now = time()
+        window_start = now - rate_period
+
+        async with self.client.pipeline(transaction=False) as pipe:
+            pipe.zremrangebyscore(rl_key, 0, window_start)
+            pipe.zcard(rl_key)
+            pipe.zadd(rl_key, {task_id: now})
+            pipe.expire(rl_key, int(rate_period) + 1)
+            results = await pipe.execute()
+
+        current_count = results[1]
+
+        if current_count < max_rate:
+            return None
+
+        # over limit - remove the entry we just added
+        await self.client.zrem(rl_key, task_id)
+
+        # calculate wait time from oldest entry in the window
+        oldest = await self.client.zrange(rl_key, 0, 0, withscores=True)
+        if oldest:
+            wait = float(oldest[0][1]) + rate_period - now
+            return max(wait, 0.01)
+
+        return rate_period
 
     async def _ensure_consumer_group(self, stream_key: str) -> None:
         if stream_key in self._initialized_groups:
diff --git a/src/sheppy/models.py b/src/sheppy/models.py
index cae3d69..4bbf53d 100644
--- a/src/sheppy/models.py
+++ b/src/sheppy/models.py
@@ -18,6 +18,7 @@
     field_validator,
     model_validator,
 )
+from typing_extensions import NotRequired, TypedDict
 
 from ._utils.functions import reconstruct_result
 
@@ -42,6 +43,16 @@ def cron_expression_validator(value: str) -> str:
     'completed',
     'failed',
     'cancelled',
     'unknown']
+
+RateLimitStrategy = Literal["sliding_window", "fixed_window"]
+
+
+class RateLimit(TypedDict):
+    max_rate: int
+    rate_period: float  # time window for the rate limit in seconds
+    key: NotRequired[str]  # defaults to task name, can be set to group rate limits across tasks
+    strategy: NotRequired[RateLimitStrategy]  # defaults to "sliding_window"
+
 
 class TaskSpec(BaseModel):
     """Task specification.
@@ -116,7 +127,7 @@ def my_task():
     timeout: float | None = None  # seconds
     retry_on_timeout: bool = False
 
-    # tags: dict[str, str] = Field(default_factory=dict)
+    rate_limit: RateLimit | None = None
 
     @field_validator('retry_delay')
     @classmethod
diff --git a/src/sheppy/queue.py b/src/sheppy/queue.py
index 751f00e..844e7dc 100644
--- a/src/sheppy/queue.py
+++ b/src/sheppy/queue.py
@@ -443,7 +443,34 @@ async def _pop_pending(self, limit: int = 1, timeout: float | None = None) -> li
             raise ValueError("Pop limit must be greater than zero.")
 
         tasks_data = await self.backend.pop(self.name, limit, timeout)
-        return [Task.model_validate(t) for t in tasks_data]
+        tasks = []
+
+        for task_data in tasks_data:
+            task = Task.model_validate(task_data)
+
+            if task.config.rate_limit:
+                rl = task.config.rate_limit
+                key = rl.get("key", task.spec.func)
+
+                wait = await self.backend.acquire_rate_limit(
+                    self.name,
+                    key,
+                    rl["max_rate"],
+                    rl["rate_period"],
+                    strategy=rl.get("strategy", "sliding_window"),
+                    task_id=str(task.id),
+                )
+
+                if wait is not None:
+                    scheduled_at = datetime.now(timezone.utc) + timedelta(seconds=wait)
+                    task.__dict__["status"] = "scheduled"
+                    task.__dict__["scheduled_at"] = scheduled_at
+                    await self.backend.schedule(self.name, task.model_dump(mode='json'), scheduled_at, unique=False)
+                    continue
+
+            tasks.append(task)
+
+        return tasks
 
     async def _store_result(self, task: Task) -> bool:
         """Store task result. Internal method used by workers.
diff --git a/src/sheppy/task_factory.py b/src/sheppy/task_factory.py
index e06d399..ede68fd 100644
--- a/src/sheppy/task_factory.py
+++ b/src/sheppy/task_factory.py
@@ -10,7 +10,7 @@
 from ._utils.functions import stringify_function
 from ._utils.validation import validate_input
 from ._workflow import get_workflow_context
-from .models import Task, TaskConfig, TaskCron, TaskSpec
+from .models import RateLimit, Task, TaskConfig, TaskCron, TaskSpec
 
 P = ParamSpec('P')
 R = TypeVar('R')
@@ -33,6 +33,7 @@ def create_task(func: Callable[..., Any],
                 middleware: list[Callable[..., Any]] | None,
                 timeout: float | None,
                 retry_on_timeout: bool | None,
+                rate_limit: RateLimit | None = None,
                 ) -> Task:
 
     task_config: dict[str, Any] = {
@@ -46,6 +47,9 @@
     if retry_on_timeout is not None:
         task_config["retry_on_timeout"] = retry_on_timeout
 
+    if rate_limit is not None:
+        task_config["rate_limit"] = dict(rate_limit)
+
     func_string = stringify_function(func)
     args, kwargs = validate_input(func, tuple(args or ()), dict(kwargs or {}))
 
@@ -94,6 +98,7 @@ def task(
     middleware: list[Callable[..., Any]] | None = None,
     timeout: float | None = None,
     retry_on_timeout: bool | None = None,
+    rate_limit: RateLimit | None = None,
 ) -> Callable[[Callable[P, R]], Callable[P, Task]]:
     ...
 
@@ -110,12 +115,13 @@ def task(
     middleware: list[Callable[..., Any]] | None = None,
     timeout: float | None = None,
     retry_on_timeout: bool | None = None,
+    rate_limit: RateLimit | None = None,
 ) -> Callable[[Callable[P, R]], Callable[P, Task]] | Callable[P, Task]:
 
     def decorator(func: Callable[P, R]) -> Callable[P, Task]:
         @wraps(func)
         def wrapper(*args: P.args, **kwargs: P.kwargs) -> Task:
-            return TaskFactory.create_task(func, tuple(args), kwargs, retry, retry_delay, middleware, timeout, retry_on_timeout)
+            return TaskFactory.create_task(func, tuple(args), kwargs, retry, retry_delay, middleware, timeout, retry_on_timeout, rate_limit)
         return wrapper
diff --git a/tests/conftest.py b/tests/conftest.py
index e8eb674..c320ee9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -57,7 +57,9 @@ async def worker_backend(backend: Backend) -> AsyncGenerator[Backend, None]:
         await worker_backend.disconnect()
 
     elif isinstance(backend, LocalBackend):
-        # embedded backend: same instance (server is embedded)
+        # ensure embedded server is running before worker tries to connect
+        if not backend.is_connected:
+            await backend.connect()
         worker_backend = LocalBackend(port=17421)
         yield worker_backend
         await worker_backend.disconnect()
diff --git a/tests/contract/test_rate_limit.py b/tests/contract/test_rate_limit.py
new file mode 100644
index 0000000..04ccdd5
--- /dev/null
+++ b/tests/contract/test_rate_limit.py
@@ -0,0 +1,123 @@
+import asyncio
+from datetime import datetime
+
+from sheppy import Queue, Worker, task
+
+
+@task(rate_limit={"max_rate": 2, "rate_period": 1, "strategy": "fixed_window"})
+async def rate_limited_fixed_window(created_time: datetime) -> int:
+    return int((datetime.now() - created_time).total_seconds()*10)
+
+@task(rate_limit={"max_rate": 2, "rate_period": 1, "strategy": "sliding_window"})
+async def rate_limited_sliding_window(created_time: datetime) -> int:
+    return int((datetime.now() - created_time).total_seconds()*10)
+
+class TestRateLimitFixedWindow:
+
+    async def test_rate_limit(self, queue: Queue, worker: Worker):
+        t1 = rate_limited_fixed_window(datetime.now())
+        t2 = rate_limited_fixed_window(datetime.now())
+        t3 = rate_limited_fixed_window(datetime.now())
+
+        await queue.add([t1, t2, t3])
+
+        worker_task = asyncio.create_task(worker.work(max_tasks=3))
+
+        processed = await queue.wait_for([t1, t2, t3], timeout=3)
+
+        timer = {}
+        for task_id, t in processed.items():
+            if t.result not in timer:
+                timer[t.result] = []
+            timer[t.result].append(task_id)
+
+        assert 0 in timer and len(timer[0]) == 2
+        assert 10 in timer and len(timer[10]) == 1
+
+        await worker_task
+
+    async def test_rate_limit_with_delay(self, queue: Queue, worker: Worker):
+        worker_task = asyncio.create_task(worker.work(max_tasks=3))
+
+        t1 = rate_limited_fixed_window(datetime.now())
+        await queue.add(t1)
+
+        await asyncio.sleep(0.2)
+        t2 = rate_limited_fixed_window(datetime.now())
+        await queue.add(t2)
+
+        await asyncio.sleep(0.2)
+        t3 = rate_limited_fixed_window(datetime.now())
+        await queue.add(t3)
+
+        processed = await queue.wait_for([t1, t2, t3], timeout=3)
+
+        timer = {}
+        for task_id, t in processed.items():
+            if t.result not in timer:
+                timer[t.result] = []
+            timer[t.result].append(task_id)
+
+        assert 0 in timer and len(timer[0]) == 2
+        assert 6 in timer and len(timer[6]) == 1
+
+        await worker_task
+
+
+class TestRateLimitSlidingWindow:
+
+    async def test_rate_limit(self, queue: Queue, worker: Worker):
+        worker_task = asyncio.create_task(worker.work(max_tasks=4))
+
+        t1 = rate_limited_sliding_window(datetime.now())
+        t2 = rate_limited_sliding_window(datetime.now())
+        t3 = rate_limited_sliding_window(datetime.now())
+        t4 = rate_limited_sliding_window(datetime.now())
+
+        await queue.add([t1, t2, t3, t4])
+
+
+        processed = await queue.wait_for([t1, t2, t3, t4], timeout=10)
+
+        timer = {}
+        for task_id, t in processed.items():
+            if t.result not in timer:
+                timer[t.result] = []
+            timer[t.result].append(task_id)
+
+        assert 0 in timer and len(timer[0]) == 2
+        # assert 10 in timer and len(timer[10]) == 1
+        assert 10 in timer and len(timer[10]) == 2
+
+        await worker_task
+
+    async def test_rate_limit_with_delay(self, queue: Queue, worker: Worker):
+        worker_task = asyncio.create_task(worker.work(max_tasks=6))
+
+        t1 = rate_limited_sliding_window(datetime.now())
+        await queue.add(t1)
+
+        await asyncio.sleep(0.2)
+        t2 = rate_limited_sliding_window(datetime.now())
+        t3 = rate_limited_sliding_window(datetime.now())
+        t4 = rate_limited_sliding_window(datetime.now())
+        t5 = rate_limited_sliding_window(datetime.now())
+        t6 = rate_limited_sliding_window(datetime.now())
+        await queue.add([t2, t3, t4, t5, t6])
+
+
+        processed = await queue.wait_for([t1, t2, t3, t4, t5, t6], timeout=5)
+
+        timer = {}
+        for task_id, t in processed.items():
+            if t.result not in timer:
+                timer[t.result] = []
+            timer[t.result].append(task_id)
+
+        assert 0 in timer and len(timer[0]) == 2
+        assert 8 in timer and len(timer[8]) == 1
+        assert 10 in timer and len(timer[10]) == 1
+        assert 18 in timer and len(timer[18]) == 1
+        assert 20 in timer and len(timer[20]) == 1
+
+        await worker_task