diff --git a/README.md b/README.md index e8212a5..0355fd4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# FastAPI Throttle + # FastAPI Throttle [![PyPI](https://img.shields.io/pypi/v/fastapi-throttle?logo=pypi&label=PyPI)](https://pypi.org/project/fastapi-throttle/) [![CI](https://github.com/AliYmn/fastapi-throttle/actions/workflows/ci.yml/badge.svg)](https://github.com/AliYmn/fastapi-throttle/actions?query=workflow%3ACI) @@ -64,6 +64,9 @@ async def root(): - **Simple Configuration**: Just specify request count and time window - **Route-Level Control**: Apply different limits to different endpoints - **FastAPI Integration**: Works with FastAPI's dependency injection system +- **Custom Keying (optional)**: Provide a `key_func(Request) -> str` to limit by user, API key, path, etc. +- **Proxy-Aware (optional)**: `trust_proxy=True` to use `X-Forwarded-For` when behind proxies/CDNs +- **Rate-Limit Headers (optional)**: Add `X-RateLimit-*` and `Retry-After` for clients ## Usage Examples @@ -109,14 +112,67 @@ async def get_resource(): app.include_router(router) ``` +### Custom Key Function (limit by user or path) + +```python +from fastapi import FastAPI, Depends, Request +from fastapi_throttle import RateLimiter + +app = FastAPI() + +def user_key(req: Request) -> str: + # Example: extract user-id from header or auth (for demo only) + return req.headers.get("x-user-id", req.client.host or "unknown") + +limiter = RateLimiter(times=10, seconds=60, key_func=user_key) + +@app.get("/data", dependencies=[Depends(limiter)]) +async def data(): + return {"ok": True} +``` + +### Behind proxy/CDN (trust X-Forwarded-For) + +```python +from fastapi import FastAPI, Depends +from fastapi_throttle import RateLimiter + +app = FastAPI() + +proxy_limit = RateLimiter(times=5, seconds=30, trust_proxy=True) + +@app.get("/proxy", dependencies=[Depends(proxy_limit)]) +async def proxy_route(): + return {"ok": True} +``` + +### Standard rate-limit headers + +```python +from fastapi import FastAPI, Depends +from fastapi_throttle import RateLimiter + +app = FastAPI() + +headers_limit = RateLimiter(times=5, seconds=60, add_headers=True) + +@app.get("/limited", dependencies=[Depends(headers_limit)]) +async def limited(): + return {"message": "Check X-RateLimit-* headers"} +``` + ## Configuration -The `RateLimiter` class takes two parameters: +The `RateLimiter` class parameters: | Parameter | Type | Description | |-----------|------|-------------| | `times` | int | Maximum number of requests allowed in the time window | | `seconds` | int | Time window in seconds | +| `detail` | str | Optional custom detail message for 429 responses | +| `key_func` | `Callable[[Request], str]` | Optional custom function to compute the rate-limit key | +| `trust_proxy` | bool | If True, tries `X-Forwarded-For` for client identification (default False) | +| `add_headers` | bool | If True, adds `X-RateLimit-Limit`, `X-RateLimit-Remaining`, and `Retry-After` headers | ## How It Works @@ -127,13 +183,17 @@ The rate limiter: 4. Counts requests within the window 5. Returns HTTP 429 when limit is exceeded +## Notes + +- When `add_headers=True`, successful responses (2xx) include `X-RateLimit-Limit` and `X-RateLimit-Remaining`. On 429 responses, only `Retry-After` is included. +- Use `trust_proxy=True` only when running behind a trusted proxy/load balancer that correctly sets `X-Forwarded-For`. The first IP is treated as the client. + ## Limitations - **Memory Storage**: Data is lost when the application restarts -- **Single-Server Only**: Not designed for distributed environments -- **IP-Based Identification**: May not work well with shared IPs or proxies +- **Single-Server Only**: Intended for monoliths/single-worker setups (in-memory, no cross-process sync) +- **IP-Based Identification (default)**: With shared IPs/proxies, prefer `key_func` or `trust_proxy` - **Memory Usage**: Can grow with number of unique clients -- **No Rate Limit Headers**: Doesn't include standard rate limit headers in responses ## When to Use @@ -143,7 +203,7 @@ FastAPI Throttle is ideal for: - Projects where simplicity is valued over advanced features - Development and testing environments -For high-traffic production applications or distributed systems, consider a Redis-based solution. +For high-traffic production applications or distributed systems, prefer a distributed rate limiter (e.g., Redis-backed). This package intentionally avoids Redis and focuses on simplicity. ## Testing @@ -160,9 +220,8 @@ pytest --cov=fastapi_throttle -q ## Roadmap -- Add optional Redis backend for distributed environments -- Optional standard rate-limit headers in responses - Middleware variant in addition to dependency-based limiter +- (Maybe) Pluggable storage interface if a second backend is introduced later ## License diff --git a/fastapi_throttle/limiter.py b/fastapi_throttle/limiter.py index a382158..5d607ee 100644 --- a/fastapi_throttle/limiter.py +++ b/fastapi_throttle/limiter.py @@ -1,6 +1,6 @@ import time -from fastapi import Request, HTTPException -from typing import Dict, List, Optional +from fastapi import Request, HTTPException, Response +from typing import Callable, Dict, List, Optional class RateLimiter: @@ -13,11 +13,14 @@ class RateLimiter: Attributes: times (int): The maximum number of requests allowed per client within the specified period. seconds (int): The time window in seconds during which requests are counted. - requests (Dict[str, List[float]]): A dictionary storing request timestamps for each client IP. + requests (Dict[str, List[float]]): A dictionary storing request timestamps for each computed key. detail (str): The detail message to be returned to the client if the requests exceed the limit within the specified period. + key_func (Optional[Callable[[Request], str]]): Optional function to compute a custom rate-limit key. + trust_proxy (bool): When True, uses X-Forwarded-For to determine client IP (first hop). + add_headers (bool): When True, adds rate-limit headers to the response. """ - def __init__(self, times: int, seconds: int, detail : Optional[str] = None) -> None: + def __init__(self, times: int, seconds: int, detail: Optional[str] = None, *, key_func: Optional[Callable[[Request], str]] = None, trust_proxy: bool = False, add_headers: bool = False) -> None: """ Initializes the RateLimiter instance with the specified request limit and time period. @@ -25,15 +28,20 @@ def __init__(self, times: int, seconds: int, detail : Optional[str] = None) -> N times (int): The maximum number of requests allowed per client. seconds (int): The time period in seconds for rate limiting. detail (str): The detail message to be returned to the client if rate limit is exceeded. + key_func (Callable[[Request], str], optional): Custom key function. Defaults to None (client IP based). + trust_proxy (bool): If True, attempts to use X-Forwarded-For header for client identification. + add_headers (bool): If True, attaches standard rate limit headers to the response. """ self.times: int = times self.seconds: int = seconds self.requests: Dict[str, List[float]] = {} - self.detail: str = detail - if self.detail is None: - self.detail: str = "Too Many Requests" + # Ensure non-None detail without violating typing + self.detail: str = detail or "Too Many Requests" + self.key_func: Optional[Callable[[Request], str]] = key_func + self.trust_proxy: bool = trust_proxy + self.add_headers: bool = add_headers - async def __call__(self, request: Request) -> None: + async def __call__(self, request: Request, response: Response) -> None: """ Checks if the incoming request exceeds the allowed rate limit. @@ -43,27 +51,60 @@ async def __call__(self, request: Request) -> None: Args: request (Request): The incoming HTTP request object. + response (Response): The outgoing HTTP response object. Raises: HTTPException: If the request rate limit is exceeded, a 429 status code is returned. """ - client_ip: str = request.client.host - current_time: float = time.time() + client = request.client + # Compute key: custom key_func takes precedence + if self.key_func is not None: + try: + key: str = self.key_func(request) + except Exception: + # Fail-safe: do not break request handling if custom key function errors out + key = "unknown" + else: + # Default behavior: determine client IP, optionally trusting proxy headers + key = "unknown" + if self.trust_proxy: + xff = request.headers.get("x-forwarded-for") + if xff: + # First IP in X-Forwarded-For is the original client + key = xff.split(",")[0].strip() or key + if key == "unknown": + key = client.host if (client and getattr(client, "host", None)) else "unknown" + current_time: float = time.monotonic() + window_start: float = current_time - self.seconds - # Initialize the client's request history if not already present - if client_ip not in self.requests: - self.requests[client_ip] = [] + # Get and prune timestamps inside the window + existing = self.requests.get(key, []) + filtered: List[float] = [ts for ts in existing if ts > window_start] + if not filtered: + # Small hygiene: if list becomes empty, drop the key to avoid empty buckets + if key in self.requests: + del self.requests[key] + current_count = 0 + else: + self.requests[key] = filtered + current_count = len(filtered) - # Filter out timestamps that are outside of the rate limit period - self.requests[client_ip] = [ - timestamp - for timestamp in self.requests[client_ip] - if timestamp > current_time - self.seconds - ] - - # Check if the number of requests exceeds the allowed limit - if len(self.requests[client_ip]) >= self.times: - raise HTTPException(status_code=429, detail=self.detail) + if current_count >= self.times: + # Compute Retry-After: time until the oldest timestamp leaves the window + oldest = min(filtered) if filtered else current_time + retry_after = int(max(0.0, self.seconds - (current_time - oldest))) + headers = {"Retry-After": str(retry_after)} if retry_after > 0 else None + raise HTTPException(status_code=429, detail=self.detail, headers=headers) # Record the current request timestamp - self.requests[client_ip].append(current_time) + if filtered: + # Append to existing filtered list + self.requests[key].append(current_time) + else: + # Create a fresh list for this key + self.requests[key] = [current_time] + + # Optionally attach standard rate limit headers + if self.add_headers: + response.headers["X-RateLimit-Limit"] = str(self.times) + response.headers["X-RateLimit-Remaining"] = str(max(0, self.times - len(self.requests[key]))) diff --git a/index.rst b/index.rst index 03d324d..b195fdc 100644 --- a/index.rst +++ b/index.rst @@ -6,10 +6,13 @@ `fastapi-throttle` is a simple in-memory rate limiter for FastAPI applications. This package allows you to control the number of requests a client can make to your API within a specified time window without relying on external dependencies like Redis. It is ideal for lightweight applications where simplicity and speed are paramount. ## Features -- **Without Redis** : You don’t need to install or configure Redis. +- **Without Redis** : You don’t need to install or configure Redis (monolith/single-worker focus). - **In-Memory Rate Limiting**: No external dependencies required. Keeps everything in memory for fast and simple rate limiting. - **Flexible Configuration**: Easily configure rate limits per route or globally. - **Python Version Support**: Compatible with Python 3.8 up to 3.12. +- **Custom Keying (optional)**: Provide a key function to limit by user, API key, path, etc. +- **Proxy-Aware (optional)**: `trust_proxy=True` to read `X-Forwarded-For` behind proxies/CDNs. +- **Rate-Limit Headers (optional)**: Add `X-RateLimit-*` and `Retry-After` headers for clients. ## Installation @@ -57,6 +60,14 @@ async def route2(): ## Configuration - times: The maximum number of requests allowed per client within the specified period. - seconds: The time window in seconds within which the requests are counted. +- detail: Optional custom detail message for 429 responses. +- key_func: Optional callable `key_func(Request) -> str` to compute a custom key (e.g., user id). +- trust_proxy: If True, use the first IP from `X-Forwarded-For` when present (default False). +- add_headers: If True, add `X-RateLimit-Limit`, `X-RateLimit-Remaining`, and `Retry-After`. + +## Notes +- When `add_headers=True`, successful responses (200 range) include `X-RateLimit-Limit` and `X-RateLimit-Remaining`. On 429, only `Retry-After` is added. +- Use `trust_proxy=True` only when your app is behind a trusted proxy/load balancer that correctly sets `X-Forwarded-For`. The first IP is treated as the client. ## Example with Custom Configuration Here is an example where you use custom rate limiting per endpoint: @@ -71,3 +82,50 @@ app = FastAPI() async def custom(): return {"message": "This is a custom route with its own rate limit."} ``` + +## Advanced Examples + +### Custom key function (limit by user or path) +```python +from fastapi import FastAPI, Depends, Request +from fastapi_throttle import RateLimiter + +app = FastAPI() + +def user_key(req: Request) -> str: + return req.headers.get("x-user-id", req.client.host or "unknown") + +limiter = RateLimiter(times=10, seconds=60, key_func=user_key) + +@app.get("/data", dependencies=[Depends(limiter)]) +async def data(): + return {"ok": True} +``` + +### Behind proxy/CDN (trust X-Forwarded-For) +```python +from fastapi import FastAPI, Depends +from fastapi_throttle import RateLimiter + +app = FastAPI() + +proxy_limit = RateLimiter(times=5, seconds=30, trust_proxy=True) + +@app.get("/proxy", dependencies=[Depends(proxy_limit)]) +async def proxy_route(): + return {"ok": True} +``` + +### Standard rate-limit headers +```python +from fastapi import FastAPI, Depends +from fastapi_throttle import RateLimiter + +app = FastAPI() + +headers_limit = RateLimiter(times=5, seconds=60, add_headers=True) + +@app.get("/limited", dependencies=[Depends(headers_limit)]) +async def limited(): + return {"message": "Check X-RateLimit-* headers"} +``` diff --git a/requirements.txt b/requirements.txt index 9f2525e..014e672 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ fastapi==0.116.1 uvicorn==0.33.0 pytest==8.3.5 -pytest-cov==6.2.1 +pytest-cov==5.0.0 httpx==0.28.1 diff --git a/setup.py b/setup.py index 845dc8c..92f992a 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( name="fastapi-throttle", - version="0.1.7", + version="0.1.8", packages=find_packages(), install_requires=[ "fastapi", diff --git a/tests/test_limiter.py b/tests/test_limiter.py index a53c28d..4ca79b0 100644 --- a/tests/test_limiter.py +++ b/tests/test_limiter.py @@ -43,3 +43,94 @@ async def route2(): response = client.get("/route2") assert response.status_code == 200 # Limit should be reset + + +def test_add_headers_and_retry_after(): + app = FastAPI() + + limiter = RateLimiter(times=2, seconds=5, add_headers=True) + + @app.get("/limited", dependencies=[Depends(limiter)]) + async def limited(): + return {"ok": True} + + client = TestClient(app) + + # First request should include X-RateLimit headers + r1 = client.get("/limited") + assert r1.status_code == 200 + assert r1.headers.get("X-RateLimit-Limit") == "2" + assert r1.headers.get("X-RateLimit-Remaining") == "1" + + # Second request still 200, remaining becomes 0 + r2 = client.get("/limited") + assert r2.status_code == 200 + assert r2.headers.get("X-RateLimit-Limit") == "2" + assert r2.headers.get("X-RateLimit-Remaining") == "0" + + # Third request should hit 429 and include Retry-After + r3 = client.get("/limited") + assert r3.status_code == 429 + assert r3.headers.get("Retry-After") is not None + assert r3.headers["Retry-After"].isdigit() + + +def test_retry_after_without_add_headers(): + app = FastAPI() + + limiter = RateLimiter(times=1, seconds=5, add_headers=False) + + @app.get("/rl", dependencies=[Depends(limiter)]) + async def rl(): + return {"ok": True} + + client = TestClient(app) + assert client.get("/rl").status_code == 200 + r2 = client.get("/rl") + assert r2.status_code == 429 + assert r2.headers.get("Retry-After") is not None + + +def test_trust_proxy_uses_x_forwarded_for(): + app = FastAPI() + + limiter = RateLimiter(times=1, seconds=10, trust_proxy=True) + + @app.get("/p", dependencies=[Depends(limiter)]) + async def p(): + return {"ok": True} + + client = TestClient(app) + + # Same client but with same first IP should hit the limit on second call + headers_a = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"} + assert client.get("/p", headers=headers_a).status_code == 200 + assert client.get("/p", headers=headers_a).status_code == 429 + + # Different first IP should be treated as a different key + headers_b = {"X-Forwarded-For": "3.3.3.3, 4.4.4.4"} + assert client.get("/p", headers=headers_b).status_code == 200 + + +def test_custom_key_func_by_user(): + app = FastAPI() + + def user_key(req): + return req.headers.get("x-user-id", "anon") + + limiter = RateLimiter(times=1, seconds=10, key_func=user_key) + + @app.get("/u", dependencies=[Depends(limiter)]) + async def u(): + return {"ok": True} + + client = TestClient(app) + + # user A limited on second call + headers_a = {"x-user-id": "A"} + assert client.get("/u", headers=headers_a).status_code == 200 + assert client.get("/u", headers=headers_a).status_code == 429 + + # user B is a different key, first call is allowed + headers_b = {"x-user-id": "B"} + assert client.get("/u", headers=headers_b).status_code == 200