Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/rate_limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
from datetime import datetime
from sheppy import Queue, task

queue = Queue("redis://127.0.0.1:6379")


@task(rate_limit={"max_rate": 2, "rate_period": 5})
async def do_work(queued_time: datetime) -> int:
    """Return the whole number of seconds between enqueue and execution."""
    started = datetime.now()
    waited = started - queued_time
    return int(waited.total_seconds())


async def main():
    t1, t2, t3 = (do_work(datetime.now()) for _ in range(3))
    await queue.add([t1, t2, t3])

    # block until every task has finished, then pick up the fresh copies
    update_tasks = await queue.wait_for([t1, t2, t3])
    t1, t2, t3 = (update_tasks[t.id] for t in (t1, t2, t3))

    assert all(t.status == 'completed' for t in (t1, t2, t3))

    # with max_rate=2 per 5s two tasks run immediately (result 0) while the
    # third is delayed by the limiter (result 5); completion order may vary
    print("t1.result:", t1.result)
    print("t2.result:", t2.result)
    print("t3.result:", t3.result)


if __name__ == "__main__":
    asyncio.run(main())
4 changes: 4 additions & 0 deletions src/sheppy/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,7 @@ async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool:
@abstractmethod
async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int:
    """Mark `task_id` as done within the given workflow.

    Returns the number of workflow tasks still pending
    (0 means the workflow has fully completed).
    """
    pass

@abstractmethod
async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
    """Try to claim one rate-limit slot for `key` on queue `queue_name`.

    At most `max_rate` acquisitions are allowed per `rate_period` seconds.
    Returns None when the slot was acquired (the task may run now), or the
    number of seconds to wait before retrying when the limit is exhausted.
    `strategy` selects the algorithm: "sliding_window" (default) or
    "fixed_window".
    """
    pass
57 changes: 57 additions & 0 deletions src/sheppy/backend/local.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import contextlib
import json
from collections import defaultdict
from datetime import datetime, timezone
from time import time
from typing import Any

from .._localkv.client import KVClient
Expand All @@ -17,6 +19,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 17420, *, embedded: bool
self._embedded = embedded
self._server: asyncio.Server | None = None
self._server_task: asyncio.Task[None] | None = None
self._rl_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._client = KVClient(host, port)

async def connect(self) -> None:
Expand Down Expand Up @@ -91,6 +94,9 @@ def _workflow_prefix(self, queue_name: str) -> str:
def _workflow_pending_key(self, queue_name: str, workflow_id: str) -> str:
return f"sheppy:{queue_name}:workflow:{workflow_id}:pending"

def _rate_limit_key(self, queue_name: str, key: str) -> str:
    # storage key for one rate-limit bucket within a queue
    return f"sheppy:{queue_name}:ratelimit:{key}"

async def append(self, queue_name: str, tasks: list[dict[str, Any]], unique: bool = True) -> list[bool]:
success = []

Expand Down Expand Up @@ -352,6 +358,57 @@ async def delete_workflow(self, queue_name: str, workflow_id: str) -> bool:
count = await self.client.delete([key, pending_key])
return count > 0

async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
if strategy == "fixed_window":
return await self._acquire_fixed_window(queue_name, key, max_rate, rate_period, task_id)

return await self._acquire_sliding_window(queue_name, key, max_rate, rate_period, task_id)

async def _acquire_sliding_window(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str) -> float | None:
    """Sliding-window limiter: admit when fewer than `max_rate` entries
    survive in the last `rate_period` seconds, else return seconds to wait."""
    bucket = self._rate_limit_key(queue_name, key)

    async with self._rl_locks[bucket]:
        moment = time()

        # drop entries that have slid out of the window
        await self.client.sorted_pop(bucket, moment - rate_period)

        if await self.client.sorted_len(bucket) < max_rate:
            # a slot is free: record this task's timestamp and admit it
            await self.client.sorted_push(bucket, moment, task_id)
            return None

        # saturated: the oldest surviving entry determines the wait time
        remaining = await self.client.sorted_get(bucket)
        if not remaining:
            return rate_period
        return max(remaining[0][0] + rate_period - moment, 0.01)

async def _acquire_fixed_window(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str) -> float | None:
    """Fixed-window limiter: the first recorded entry opens a window of
    `rate_period` seconds; up to `max_rate` entries are admitted per window."""
    bucket = self._rate_limit_key(queue_name, key)
    window_key = f"{bucket}:fw"

    async with self._rl_locks[bucket]:
        moment = time()

        stamps = await self.client.sorted_get(window_key)

        # the first timestamp marks the window start; reset once it has elapsed
        if stamps and moment - stamps[0][0] >= rate_period:
            await self.client.sorted_pop(window_key, float('inf'))
            stamps = []

        if len(stamps) < max_rate:
            await self.client.sorted_push(window_key, moment, task_id)
            return None

        # window is full: caller must wait until it expires
        return max(stamps[0][0] + rate_period - moment, 0.01)

async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, task_id: str) -> int:
pending_key = self._workflow_pending_key(queue_name, workflow_id)

Expand Down
52 changes: 51 additions & 1 deletion src/sheppy/backend/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from time import time
from typing import Any

from .._utils.task_execution import TaskProcessor
Expand Down Expand Up @@ -32,6 +33,7 @@ def __init__(self,
self._crons: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict)
self._workflows: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) # {QUEUE_NAME: {WORKFLOW_ID: workflow_data}}

self._rate_limits: dict[str, list[float]] = defaultdict(list)
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) # for thread-safety
self._connected = False

Expand Down Expand Up @@ -147,6 +149,10 @@ async def clear(self, queue_name: str) -> int:
self._scheduled[queue_name].clear()
self._crons[queue_name].clear()

rl_keys = [k for k in self._rate_limits if k.startswith(f"{queue_name}:")]
for k in rl_keys:
del self._rate_limits[k]

return queue_size + queue_cron_size

async def get_tasks(self, queue_name: str, task_ids: list[str]) -> dict[str,dict[str, Any]]:
Expand Down Expand Up @@ -221,7 +227,7 @@ async def get_results(self, queue_name: str, task_ids: list[str], timeout: float

while True:
async with self._locks[queue_name]:
for task_id in task_ids:
for task_id in remaining_ids[:]:
task_data = self._task_metadata[queue_name].get(task_id, {})

if task_data.get("finished_at"):
Expand Down Expand Up @@ -367,6 +373,50 @@ async def mark_workflow_task_complete(self, queue_name: str, workflow_id: str, t

return len(pending_ids)

async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
self._check_connected()

if strategy == "fixed_window":
return await self._acquire_fixed_window(queue_name, key, max_rate, rate_period)

return await self._acquire_sliding_window(queue_name, key, max_rate, rate_period)

async def _acquire_sliding_window(self, queue_name: str, key: str, max_rate: int, rate_period: float) -> float | None:
bucket_key = f"{queue_name}:{key}"

async with self._locks[bucket_key]:
now = time()
cutoff = now - rate_period
entries = self._rate_limits[bucket_key]

self._rate_limits[bucket_key] = [t for t in entries if t > cutoff]
entries = self._rate_limits[bucket_key]

if len(entries) < max_rate:
entries.append(now)
return None

oldest = min(entries)
return max(oldest + rate_period - now, 0.01)

async def _acquire_fixed_window(self, queue_name: str, key: str, max_rate: int, rate_period: float) -> float | None:
bucket_key = f"{queue_name}:{key}:fw"

async with self._locks[bucket_key]:
now = time()
entries = self._rate_limits[bucket_key]

# window expired - reset
if entries and now - entries[0] >= rate_period:
entries.clear()

if len(entries) < max_rate:
entries.append(now)
return None

remaining = entries[0] + rate_period - now
return max(remaining, 0.01)

def _check_connected(self) -> None:
if not self.is_connected:
raise BackendError("Not connected")
Expand Down
64 changes: 64 additions & 0 deletions src/sheppy/backend/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ def _worker_metadata_key(self, queue_name: str) -> str:
"""Worker Metadata (key prefix)"""
return f"sheppy:workers:{queue_name}"

def _rate_limit_key(self, queue_name: str) -> str:
return f"sheppy:ratelimit:{queue_name}"

def _sliding_window_key(self, queue_name: str, key: str) -> str:
return f"sheppy:ratelimit:{queue_name}:sw:{key}"

@property
def client(self) -> redis.Redis:
if not self._client:
Expand Down Expand Up @@ -229,6 +235,10 @@ async def clear(self, queue_name: str) -> int:
await self.client.xtrim(pending_tasks_key, maxlen=0)
await self.client.delete(scheduled_key)
await self.client.hdel(self._queues_registry_key(), queue_name) # type: ignore[misc]
await self.client.delete(self._rate_limit_key(queue_name))
sw_keys = [key async for key in self.client.scan_iter(match=self._sliding_window_key(queue_name, '*'), count=10000)]
if sw_keys:
await self.client.delete(*sw_keys)

return count

Expand Down Expand Up @@ -423,6 +433,60 @@ async def get_results(self, queue_name: str, task_ids: list[str], timeout: float
if not remaining_ids:
return results

async def acquire_rate_limit(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str, strategy: str = "sliding_window") -> float | None:
if strategy == "fixed_window":
return await self._acquire_fixed_window(queue_name, key, max_rate, rate_period)

return await self._acquire_sliding_window(queue_name, key, max_rate, rate_period, task_id)

async def _acquire_fixed_window(self, queue_name: str, key: str, max_rate: int, rate_period: float) -> float | None:
    """Fixed-window limiter backed by a Redis hash of per-key counters.

    Each hash field counts acquisitions; the field's TTL (set once with
    nx=True) defines the window, so the window starts at the first
    acquisition and resets when the field expires.
    Returns None when a slot was acquired, otherwise seconds to wait.
    """
    rl_key = self._rate_limit_key(queue_name)
    ttl_ms = int(rate_period * 1000)

    # non-transactional pipeline: one round trip, commands run in order
    async with self.client.pipeline(transaction=False) as pipe:
        pipe.hincrby(rl_key, key, 1)
        # nx=True: only set the TTL when the field has none yet,
        # i.e. the first increment of a fresh window
        pipe.hpexpire(rl_key, ttl_ms, key, nx=True)
        pipe.hpttl(rl_key, key)
        results = await pipe.execute()

    count = results[0]

    if count <= max_rate:
        return None

    # over limit - undo increment, return remaining TTL
    # NOTE(review): the undo is a separate round trip, so a concurrent
    # worker may briefly observe the inflated count - confirm acceptable
    await self.client.hincrby(rl_key, key, -1)  # type: ignore[misc]
    # hpttl returns a list of per-field TTLs; negative values mean
    # no TTL/missing field, in which case fall back to a full window
    pttl_result = results[2]
    remaining_ms = pttl_result[0] if pttl_result and pttl_result[0] > 0 else ttl_ms
    return remaining_ms / 1000.0

async def _acquire_sliding_window(self, queue_name: str, key: str, max_rate: int, rate_period: float, task_id: str) -> float | None:
    """Sliding-window limiter backed by a Redis sorted set.

    Members are task ids scored by their acquisition timestamp; entries
    older than `rate_period` seconds are pruned before counting.
    Returns None when a slot was acquired, otherwise seconds to wait.
    """
    rl_key = self._sliding_window_key(queue_name, key)
    now = time()
    window_start = now - rate_period

    # non-transactional pipeline: one round trip, commands run in order
    async with self.client.pipeline(transaction=False) as pipe:
        # prune entries that have slid out of the window
        pipe.zremrangebyscore(rl_key, 0, window_start)
        # count BEFORE our own zadd below, so the check excludes ourselves
        pipe.zcard(rl_key)
        pipe.zadd(rl_key, {task_id: now})
        # safety TTL so abandoned buckets don't linger in Redis
        pipe.expire(rl_key, int(rate_period) + 1)
        results = await pipe.execute()

    current_count = results[1]

    if current_count < max_rate:
        return None

    # over limit - remove the entry we just added
    # NOTE(review): add-then-remove means concurrent workers can briefly
    # see an inflated count - confirm this over-throttling is acceptable
    await self.client.zrem(rl_key, task_id)

    # calculate wait time from oldest entry in the window
    oldest = await self.client.zrange(rl_key, 0, 0, withscores=True)
    if oldest:
        wait = float(oldest[0][1]) + rate_period - now
        return max(wait, 0.01)

    return rate_period

async def _ensure_consumer_group(self, stream_key: str) -> None:
if stream_key in self._initialized_groups:
Expand Down
13 changes: 12 additions & 1 deletion src/sheppy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
field_validator,
model_validator,
)
from typing_extensions import NotRequired, TypedDict

from ._utils.functions import reconstruct_result

Expand All @@ -42,6 +43,16 @@ def cron_expression_validator(value: str) -> str:
'completed', 'failed', 'cancelled', 'unknown']


# limiter algorithms supported by the backends
RateLimitStrategy = Literal["sliding_window", "fixed_window"]


class RateLimit(TypedDict):
    """Rate-limit configuration attached to a task via `@task(rate_limit=...)`."""

    max_rate: int  # maximum number of executions allowed per rate_period
    rate_period: float  # time window for the rate limit in seconds
    key: NotRequired[str]  # defaults to task name, can be set to group rate limits across tasks
    strategy: NotRequired[RateLimitStrategy]  # defaults to "sliding_window"


class TaskSpec(BaseModel):
"""Task specification.

Expand Down Expand Up @@ -116,7 +127,7 @@ def my_task():

timeout: float | None = None # seconds
retry_on_timeout: bool = False
# tags: dict[str, str] = Field(default_factory=dict)
rate_limit: RateLimit | None = None

@field_validator('retry_delay')
@classmethod
Expand Down
29 changes: 28 additions & 1 deletion src/sheppy/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,34 @@ async def _pop_pending(self, limit: int = 1, timeout: float | None = None) -> li
raise ValueError("Pop limit must be greater than zero.")

tasks_data = await self.backend.pop(self.name, limit, timeout)
return [Task.model_validate(t) for t in tasks_data]
tasks = []

for task_data in tasks_data:
task = Task.model_validate(task_data)

if task.config.rate_limit:
rl = task.config.rate_limit
key = rl.get("key", task.spec.func)

wait = await self.backend.acquire_rate_limit(
self.name,
key,
rl["max_rate"],
rl["rate_period"],
strategy=rl.get("strategy", "sliding_window"),
task_id=str(task.id),
)

if wait is not None:
scheduled_at = datetime.now(timezone.utc) + timedelta(seconds=wait)
task.__dict__["status"] = "scheduled"
task.__dict__["scheduled_at"] = scheduled_at
await self.backend.schedule(self.name, task.model_dump(mode='json'), scheduled_at, unique=False)
continue

tasks.append(task)

return tasks

async def _store_result(self, task: Task) -> bool:
"""Store task result. Internal method used by workers.
Expand Down
Loading