From 2448993b516e0359cd4a7c6e31a9c49129689e35 Mon Sep 17 00:00:00 2001 From: BarnabasG Date: Sun, 12 Apr 2026 23:11:15 +0100 Subject: [PATCH 1/2] 1.3.5 Performance enhancements, cleanups --- README.md | 4 +- pyproject.toml | 3 +- src/pytest_api_cov/cli.py | 29 +-- src/pytest_api_cov/config.py | 64 +++--- src/pytest_api_cov/frameworks.py | 72 +++--- src/pytest_api_cov/models.py | 84 ++++--- src/pytest_api_cov/openapi.py | 7 +- src/pytest_api_cov/plugin.py | 247 ++++++--------------- src/pytest_api_cov/report.py | 79 +++---- tests/unit/test_cli.py | 31 ++- tests/unit/test_config.py | 54 +++-- tests/unit/test_create_coverage_fixture.py | 12 +- tests/unit/test_frameworks.py | 82 ++++--- tests/unit/test_models.py | 84 +++---- tests/unit/test_openapi.py | 22 +- tests/unit/test_plugin.py | 91 ++++---- tests/unit/test_report.py | 110 ++++----- uv.lock | 13 +- 18 files changed, 465 insertions(+), 623 deletions(-) diff --git a/README.md b/README.md index 53e633c..43afd1a 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,7 @@ Or use the CLI flag multiple times: pytest --api-cov-report --api-cov-client-fixture-names=my_custom_client --api-cov-client-fixture-names=another_fixture ``` -If the configured fixture(s) are not found, the plugin will try to use an `app` fixture (if present) to create a tracked client. If neither is available or the plugin cannot extract the app from a discovered client fixture, the tests will still run — coverage will simply be unavailable and a warning will be logged. +If the configured fixture(s) are not found, the plugin will try to use an `app` fixture (if present) to create a tracked client. If neither is available or the plugin cannot extract the app from a discovered client fixture, the tests will still run - coverage will simply be unavailable and a warning will be logged. #### Option 2: Helper Function @@ -514,7 +514,7 @@ If coverage is not running because the plugin could not locate an app, check the - Ensure you are running pytest with `--api-cov-report` enabled. - Confirm you have a test client fixture (e.g. `client`, `test_client`, `api_client`) or an `app` fixture in your test suite. - If you use a custom client fixture, add its name to `client_fixture_names` in `pyproject.toml` or pass it via the CLI using `--api-cov-client-fixture-names` (repeatable) so the plugin can find and wrap it. -- If the plugin finds the client fixture but cannot extract the underlying app (for example the client type is not supported or wrapped in an unexpected way), you will see a message like "Could not extract app from client" — in that case either provide an `app` fixture directly or wrap your existing client using `create_coverage_fixture`. +- If the plugin finds the client fixture but cannot extract the underlying app (for example the client type is not supported or wrapped in an unexpected way), you will see a message like "Could not extract app from client" - in that case either provide an `app` fixture directly or wrap your existing client using `create_coverage_fixture`. ### No endpoints Discovered diff --git a/pyproject.toml b/pyproject.toml index 4e08255..62fa1da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pytest-api-cov" -version = "1.3.4" +version = "1.3.5" description = "Pytest Plugin to provide API Coverage statistics for Python Web Frameworks" readme = "README.md" authors = [{ name = "Barnaby Gill", email = "barnabasgill@gmail.com" }] @@ -39,6 +39,7 @@ dev = [ "flask>=2.0.0", "httpx>=0.20.0", "starlette>=0.14.0", + "pybencher>=2.1.0", ] # API COVERAGE diff --git a/src/pytest_api_cov/cli.py b/src/pytest_api_cov/cli.py index 3430486..7484e91 100644 --- a/src/pytest_api_cov/cli.py +++ b/src/pytest_api_cov/cli.py @@ -1,16 +1,10 @@ """CLI commands for setup and configuration.""" import argparse -import sys def generate_conftest_content(framework: str, file_path: str, app_variable: str) -> str: - """Generate conftest.py content based on provided framework/module/app variable. - - This is a non-interactive helper that returns example content — the project - no longer performs automatic file-scanning. Use this helper to bootstrap a - `conftest.py` if desired. - """ + """Generate example conftest.py content for a given framework.""" module_path = file_path.replace("/", ".").replace("\\", ".").replace(".py", "") if framework == "FastAPI": @@ -23,7 +17,7 @@ def generate_conftest_content(framework: str, file_path: str, app_variable: str) test_client_import = "" client_creation = "# Create and return a test client for your framework" - return f'''"""conftest.py - Example generated by pytest-api-cov CLI (non-interactive)""" + return f'''"""conftest.py - Example generated by pytest-api-cov CLI""" import pytest {test_client_import} @@ -33,11 +27,7 @@ def generate_conftest_content(framework: str, file_path: str, app_variable: str) @pytest.fixture def client(): - """Standard test client fixture for {framework}. - - The pytest-api-cov plugin can extract the app from your client fixture - and wrap it with coverage tracking when enabled. - """ + """Test client fixture for {framework}.""" {client_creation} @@ -49,7 +39,7 @@ def client(): def generate_pyproject_config() -> str: - """Generate pyproject.toml configuration section.""" + """Generate example pyproject.toml configuration section.""" return """ # pytest-api-cov configuration [tool.pytest_api_cov] @@ -83,18 +73,11 @@ def generate_pyproject_config() -> str: def main() -> int: - """Run the main CLI entry point. - - Note: the previous interactive "init" wizard was removed. This CLI - provides programmatic helpers to generate example `conftest.py` and - `pyproject.toml` content; use those functions or create a manual - `conftest.py`/`pyproject.toml` as described in the README. - """ + """CLI entry point.""" parser = argparse.ArgumentParser(prog="pytest-api-cov", description="pytest API coverage plugin CLI tools") subparsers = parser.add_subparsers(dest="command", help="Available commands") - # Keep a non-interactive 'show-config' command for convenience subparsers.add_parser("show-pyproject", help="Print example pyproject.toml configuration") show_conftest = subparsers.add_parser("show-conftest", help="Print example conftest content") show_conftest.add_argument("framework", nargs=1, help="Framework name (FastAPI|Flask)") @@ -119,4 +102,6 @@ def main() -> int: if __name__ == "__main__": + import sys + sys.exit(main()) diff --git a/src/pytest_api_cov/config.py b/src/pytest_api_cov/config.py index a4c00fa..43285d3 100644 --- a/src/pytest_api_cov/config.py +++ b/src/pytest_api_cov/config.py @@ -1,8 +1,10 @@ """Configuration handling for the API coverage report.""" +from __future__ import annotations + import sys from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import tomli from pydantic import BaseModel, ConfigDict, Field @@ -13,22 +15,22 @@ class ApiCoverageReportConfig(BaseModel): model_config = ConfigDict(populate_by_name=True) - fail_under: Optional[float] = Field(None, alias="api-cov-fail-under") + fail_under: float | None = Field(None, alias="api-cov-fail-under") show_uncovered_endpoints: bool = Field(default=True, alias="api-cov-show-uncovered-endpoints") show_covered_endpoints: bool = Field(default=False, alias="api-cov-show-covered-endpoints") show_excluded_endpoints: bool = Field(default=False, alias="api-cov-show-excluded-endpoints") - exclusion_patterns: List[str] = Field(default=[], alias="api-cov-exclusion-patterns") - report_path: Optional[str] = Field(None, alias="api-cov-report-path") + exclusion_patterns: list[str] = Field(default=[], alias="api-cov-exclusion-patterns") + report_path: str | None = Field(None, alias="api-cov-report-path") force_sugar: bool = Field(default=False, alias="api-cov-force-sugar") force_sugar_disabled: bool = Field(default=False, alias="api-cov-force-sugar-disabled") - client_fixture_names: List[str] = Field( + client_fixture_names: list[str] = Field( ["client", "test_client", "api_client", "app_client"], alias="api-cov-client-fixture-names" ) group_methods_by_endpoint: bool = Field(default=False, alias="api-cov-group-methods-by-endpoint") - openapi_spec: Optional[str] = Field(None, alias="api-cov-openapi-spec") + openapi_spec: str | None = Field(None, alias="api-cov-openapi-spec") -def read_toml_config() -> Dict[str, Any]: +def read_toml_config() -> dict[str, Any]: """Read the [tool.pytest_api_cov] section from pyproject.toml.""" try: with Path("pyproject.toml").open("rb") as f: @@ -38,28 +40,31 @@ def read_toml_config() -> Dict[str, Any]: return {} -def read_session_config(session_config: Any) -> Dict[str, Any]: +_CLI_OPTIONS = { + "api-cov-fail-under": "fail_under", + "api-cov-show-uncovered-endpoints": "show_uncovered_endpoints", + "api-cov-show-covered-endpoints": "show_covered_endpoints", + "api-cov-show-excluded-endpoints": "show_excluded_endpoints", + "api-cov-exclusion-patterns": "exclusion_patterns", + "api-cov-report-path": "report_path", + "api-cov-force-sugar": "force_sugar", + "api-cov-force-sugar-disabled": "force_sugar_disabled", + "api-cov-client-fixture-names": "client_fixture_names", + "api-cov-group-methods-by-endpoint": "group_methods_by_endpoint", + "api-cov-openapi-spec": "openapi_spec", +} + +_UNSET = (None, [], False) + + +def read_session_config(session_config: Any) -> dict[str, Any]: """Read configuration from pytest session config (command-line flags).""" - cli_options = { - "api-cov-fail-under": "fail_under", - "api-cov-show-uncovered-endpoints": "show_uncovered_endpoints", - "api-cov-show-covered-endpoints": "show_covered_endpoints", - "api-cov-show-excluded-endpoints": "show_excluded_endpoints", - "api-cov-exclusion-patterns": "exclusion_patterns", - "api-cov-report-path": "report_path", - "api-cov-force-sugar": "force_sugar", - "api-cov-force-sugar-disabled": "force_sugar_disabled", - "api-cov-client-fixture-names": "client_fixture_names", - "api-cov-group-methods-by-endpoint": "group_methods_by_endpoint", - "api-cov-openapi-spec": "openapi_spec", - } - config = {} - for opt, key in cli_options.items(): + config: dict[str, Any] = {} + for opt, key in _CLI_OPTIONS.items(): value = session_config.getoption(f"--{opt}") - if value is not None and value != [] and value is not False: + if value not in _UNSET: config[key] = value - # Validating negation flags if session_config.getoption("--api-cov-hide-uncovered-endpoints"): config["show_uncovered_endpoints"] = False @@ -67,17 +72,14 @@ def read_session_config(session_config: Any) -> Dict[str, Any]: def supports_unicode() -> bool: - """Check if the environment supports Unicode characters.""" + """Check if the terminal supports Unicode output.""" if not sys.stdout.isatty(): return False - return bool(sys.stdout) and sys.stdout.encoding.lower() in ["utf-8", "utf8"] + return sys.stdout.encoding.lower() in ("utf-8", "utf8") def get_pytest_api_cov_report_config(session_config: Any) -> ApiCoverageReportConfig: - """Get the final API coverage configuration by merging sources. - - Priority: CLI > pyproject.toml > Defaults. - """ + """Build final config by merging sources. Priority: CLI > pyproject.toml > defaults.""" toml_config = read_toml_config() cli_config = read_session_config(session_config) diff --git a/src/pytest_api_cov/frameworks.py b/src/pytest_api_cov/frameworks.py index 8e7658a..979eca1 100644 --- a/src/pytest_api_cov/frameworks.py +++ b/src/pytest_api_cov/frameworks.py @@ -1,31 +1,33 @@ -"""Framework adapters for Flask and FastAPI.""" +"""Framework adapters for Flask, FastAPI, and Django.""" -from typing import TYPE_CHECKING, Any, List, Optional +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from .models import ApiCallRecorder -class BaseAdapter: - """Base adapter for framework applications.""" +class BaseAdapter(ABC): + """Abstract base for framework adapters.""" def __init__(self, app: Any) -> None: - """Initialize the adapter.""" self.app = app - def get_endpoints(self) -> List[str]: - """Return a list of all endpoint paths.""" - raise NotImplementedError + @abstractmethod + def get_endpoints(self) -> list[str]: + """Return a list of 'METHOD /path' strings.""" - def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: str) -> Any: - """Return a patched test client that records calls.""" - raise NotImplementedError + @abstractmethod + def get_tracked_client(self, recorder: ApiCallRecorder | None, test_name: str) -> Any: + """Return a test client that records calls.""" class FlaskAdapter(BaseAdapter): """Adapter for Flask applications.""" - def get_endpoints(self) -> List[str]: + def get_endpoints(self) -> list[str]: """Return list of 'METHOD /path' strings.""" excluded_rules = ("/static/",) endpoints = [ @@ -38,8 +40,8 @@ def get_endpoints(self) -> List[str]: return sorted(endpoints) - def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: str) -> Any: - """Return a patched test client that records calls.""" + def get_tracked_client(self, recorder: ApiCallRecorder | None, test_name: str) -> Any: + """Return a Flask test client with call tracking.""" from flask.testing import FlaskClient if recorder is None: @@ -65,13 +67,13 @@ def open(self, *args: Any, **kwargs: Any) -> Any: class FastAPIAdapter(BaseAdapter): """Adapter for FastAPI applications.""" - def get_endpoints(self) -> List[str]: + def get_endpoints(self) -> list[str]: """Return list of 'METHOD /path' strings.""" - endpoints: List[str] = [] + endpoints: list[str] = [] self._collect_routes(self.app.routes, "", endpoints) return sorted(endpoints) - def _collect_routes(self, routes: List[Any], prefix: str, endpoints: List[str]) -> None: + def _collect_routes(self, routes: list[Any], prefix: str, endpoints: list[str]) -> None: """Recursively collect endpoints from routes, including mounted sub-apps.""" from fastapi.routing import APIRoute from starlette.routing import Mount @@ -83,10 +85,8 @@ def _collect_routes(self, routes: List[Any], prefix: str, endpoints: List[str]) ) elif isinstance(route, Mount): mount_prefix = prefix + route.path - # Sub-app with its own routes (FastAPI/Starlette router) if hasattr(route, "routes") and route.routes: self._collect_routes(route.routes, mount_prefix, endpoints) - # WSGI middleware wrapping a supported framework elif hasattr(route, "app"): inner = _unwrap_wsgi_app(route.app) if inner is not None: @@ -95,8 +95,8 @@ def _collect_routes(self, routes: List[Any], prefix: str, endpoints: List[str]) method, path = ep.split(" ", 1) endpoints.append(f"{method} {mount_prefix}{path}") - def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: str) -> Any: - """Return a patched test client that records calls.""" + def get_tracked_client(self, recorder: ApiCallRecorder | None, test_name: str) -> Any: + """Return a FastAPI/Starlette test client with call tracking.""" from starlette.testclient import TestClient if recorder is None: @@ -117,14 +117,14 @@ def send(self, *args: Any, **kwargs: Any) -> Any: class DjangoAdapter(BaseAdapter): """Adapter for Django applications.""" - def get_endpoints(self) -> List[str]: + def get_endpoints(self) -> list[str]: """Return list of 'METHOD /path' strings.""" from django.urls import get_resolver # type: ignore[import-untyped] from django.urls.resolvers import URLPattern, URLResolver # type: ignore[import-untyped] - endpoints: List[str] = [] + endpoints: list[str] = [] - def _extract_patterns(patterns: List[Any], prefix: str = "") -> None: + def _extract_patterns(patterns: list[Any], prefix: str = "") -> None: for pattern in patterns: if isinstance(pattern, URLPattern): route = str(pattern.pattern).strip("^$") @@ -145,8 +145,8 @@ def _extract_patterns(patterns: List[Any], prefix: str = "") -> None: _extract_patterns(get_resolver().url_patterns) return sorted(endpoints) - def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: str) -> Any: - """Return a patched test client that records calls.""" + def get_tracked_client(self, recorder: ApiCallRecorder | None, test_name: str) -> Any: + """Return a Django test client with call tracking.""" from django.test import Client # type: ignore[import-untyped] if recorder is None: @@ -167,8 +167,6 @@ def request(self, **request: Any) -> Any: def _unwrap_wsgi_app(app: Any) -> Any: """Extract the inner WSGI app from middleware wrappers, if supported.""" - from .plugin import is_supported_framework - type_name = type(app).__name__ if type_name in ("WSGIMiddleware", "WSGIResponder"): inner = getattr(app, "app", None) @@ -177,6 +175,17 @@ def _unwrap_wsgi_app(app: Any) -> Any: return None +def is_supported_framework(app: Any) -> bool: + """Check if the app is a supported framework.""" + if app is None: + return False + try: + get_framework_adapter(app) + except TypeError: + return False + return True + + def get_framework_adapter(app: Any) -> BaseAdapter: """Detect the framework and return the appropriate adapter.""" app_type = type(app).__name__ @@ -186,16 +195,9 @@ def get_framework_adapter(app: Any) -> BaseAdapter: return FlaskAdapter(app) if module_name == "fastapi" and app_type == "FastAPI": return FastAPIAdapter(app) - - # Django detection - # Django apps are often WSGIHandlers or just the module 'django' is present if module_name == "django" or "django" in module_name: return DjangoAdapter(app) - # Check for Django WSGI handler specifically - if app_type == "WSGIHandler" and module_name == "django.core.handlers.wsgi": - return DjangoAdapter(app) - raise TypeError( f"Unsupported application type: {app_type}. pytest-api-coverage supports Flask, FastAPI, and Django." ) diff --git a/src/pytest_api_cov/models.py b/src/pytest_api_cov/models.py index 9c110d4..65c9970 100644 --- a/src/pytest_api_cov/models.py +++ b/src/pytest_api_cov/models.py @@ -1,101 +1,102 @@ """Data models for pytest-api-cov.""" -from typing import Any, Dict, Iterable, List, Set, Tuple +from __future__ import annotations + +from typing import Any from pydantic import BaseModel, Field class ApiCallRecorder(BaseModel): - """Model for tracking API endpoint calls during testing.""" + """Tracks API endpoint calls during testing.""" model_config = {"arbitrary_types_allowed": True} - calls: Dict[str, Set[str]] = Field(default_factory=dict) + calls: dict[str, set[str]] = Field(default_factory=dict) def record_call(self, endpoint: str, test_name: str, method: str = "GET") -> None: - """Record that a test called a specific method on an endpoint.""" - endpoint_method = self._format_endpoint_key(method, endpoint) - if endpoint_method not in self.calls: - self.calls[endpoint_method] = set() - self.calls[endpoint_method].add(test_name) + """Record that a test called a specific method+endpoint.""" + key = self._format_endpoint_key(method, endpoint) + self.calls.setdefault(key, set()).add(test_name) @staticmethod def _format_endpoint_key(method: str, endpoint: str) -> str: - """Format method and endpoint into a consistent key format.""" + """Format method and endpoint into a consistent key.""" return f"{method.upper()} {endpoint}" @staticmethod def _parse_endpoint_key(endpoint_key: str) -> tuple[str, str]: - """Parse an endpoint key back into method and endpoint parts.""" + """Parse an endpoint key back into (method, endpoint).""" if " " in endpoint_key: method, endpoint = endpoint_key.split(" ", 1) return method, endpoint - # Handle legacy format without method return "GET", endpoint_key - def merge(self, other: "ApiCallRecorder") -> None: + def merge(self, other: ApiCallRecorder) -> None: """Merge another recorder's data into this one.""" for endpoint, callers in other.calls.items(): - if endpoint not in self.calls: - self.calls[endpoint] = set() - self.calls[endpoint].update(callers) + self.calls.setdefault(endpoint, set()).update(callers) - def to_serializable(self) -> Dict[str, List[str]]: - """Convert to a serializable format (sets -> lists) for worker communication.""" + def to_serializable(self) -> dict[str, list[str]]: + """Convert to serializable format (sets -> lists) for worker communication.""" return {endpoint: list(callers) for endpoint, callers in self.calls.items()} @classmethod - def from_serializable(cls, data: Dict[str, List[str]]) -> "ApiCallRecorder": + def from_serializable(cls, data: dict[str, list[str]]) -> ApiCallRecorder: """Create from serializable format (lists -> sets).""" calls = {endpoint: set(callers) for endpoint, callers in data.items()} return cls(calls=calls) def __len__(self) -> int: - """Return number of endpoints recorded.""" return len(self.calls) def __contains__(self, endpoint: str) -> bool: - """Check if an endpoint has been recorded.""" return endpoint in self.calls - def items(self) -> Iterable[Tuple[str, Set[str]]]: - """Iterate over endpoint, callers pairs.""" + def items(self) -> Any: + """Iterate over (endpoint, callers) pairs.""" return self.calls.items() - def keys(self) -> Iterable[str]: + def keys(self) -> Any: """Get all recorded endpoints.""" return self.calls.keys() - def values(self) -> Iterable[Set[str]]: + def values(self) -> Any: """Get all caller sets.""" return self.calls.values() class EndpointDiscovery(BaseModel): - """Model for discovered API endpoints.""" + """Discovered API endpoints.""" - endpoints: List[str] = Field(default_factory=list) + endpoints: list[str] = Field(default_factory=list) + _seen: set[str] = set() discovery_source: str = Field(default="unknown") + def model_post_init(self, __context: Any) -> None: + """Sync the internal set with any pre-populated endpoints.""" + self._seen = set(self.endpoints) + def add_endpoint(self, endpoint: str, method: str = "GET") -> None: - """Add a discovered endpoint method.""" - endpoint_method = ApiCallRecorder._format_endpoint_key(method, endpoint) - if endpoint_method not in self.endpoints: - self.endpoints.append(endpoint_method) + """Add a discovered endpoint if not already present.""" + key = ApiCallRecorder._format_endpoint_key(method, endpoint) + if key not in self._seen: + self._seen.add(key) + self.endpoints.append(key) - def merge(self, other: "EndpointDiscovery") -> None: + def merge(self, other: EndpointDiscovery) -> None: """Merge another discovery's endpoints into this one.""" for endpoint in other.endpoints: - if endpoint not in self.endpoints: + if endpoint not in self._seen: + self._seen.add(endpoint) self.endpoints.append(endpoint) def __len__(self) -> int: - """Return number of discovered endpoints.""" return len(self.endpoints) class SessionData(BaseModel): - """Model for session-level API coverage data.""" + """Session-level API coverage data.""" recorder: ApiCallRecorder = Field(default_factory=ApiCallRecorder) discovered_endpoints: EndpointDiscovery = Field(default_factory=EndpointDiscovery) @@ -104,22 +105,13 @@ def record_call(self, endpoint: str, test_name: str, method: str = "GET") -> Non """Record an API call.""" self.recorder.record_call(endpoint, test_name, method) - def add_discovered_endpoint(self, endpoint: str, method_or_source: str = "GET", source: str = "unknown") -> None: - """Add a discovered endpoint method.""" - # Handle both old and new method signatures for backward compatibility - if method_or_source in ["flask_adapter", "fastapi_adapter", "worker", "unknown"]: - # Old signature: add_discovered_endpoint(endpoint, source) - method = "GET" - source = method_or_source - else: - # New signature: add_discovered_endpoint(endpoint, method, source) - method = method_or_source - + def add_discovered_endpoint(self, endpoint: str, method: str = "GET", source: str = "unknown") -> None: + """Add a discovered endpoint.""" if not self.discovered_endpoints.endpoints: self.discovered_endpoints.discovery_source = source self.discovered_endpoints.add_endpoint(endpoint, method) - def merge_worker_data(self, worker_recorder: Dict[str, Any], worker_endpoints: List[str]) -> None: + def merge_worker_data(self, worker_recorder: dict[str, Any], worker_endpoints: list[str]) -> None: """Merge data from a worker process.""" if isinstance(worker_recorder, dict): all_lists = worker_recorder and all(isinstance(v, list) for v in worker_recorder.values()) diff --git a/src/pytest_api_cov/openapi.py b/src/pytest_api_cov/openapi.py index 4bec5d3..23967cb 100644 --- a/src/pytest_api_cov/openapi.py +++ b/src/pytest_api_cov/openapi.py @@ -1,9 +1,10 @@ """OpenAPI spec parsing.""" +from __future__ import annotations + import json import logging from pathlib import Path -from typing import List import yaml @@ -12,7 +13,7 @@ HTTP_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "TRACE"} -def parse_openapi_spec(path: str) -> List[str]: +def parse_openapi_spec(path: str) -> list[str]: """Parse OpenAPI spec and return list of 'METHOD /path' strings.""" spec_path = Path(path).resolve() if not spec_path.exists(): @@ -26,7 +27,7 @@ def parse_openapi_spec(path: str) -> List[str]: logger.exception("Failed to parse OpenAPI spec", exc_info=True) return [] - endpoints: List[str] = [] + endpoints: list[str] = [] for path_key, path_item in spec.get("paths", {}).items(): endpoints.extend(f"{method.upper()} {path_key}" for method in path_item if method.upper() in HTTP_METHODS) diff --git a/src/pytest_api_cov/plugin.py b/src/pytest_api_cov/plugin.py index 2ebdcca..9048cd7 100644 --- a/src/pytest_api_cov/plugin.py +++ b/src/pytest_api_cov/plugin.py @@ -1,11 +1,12 @@ """pytest plugin for API coverage tracking.""" import logging -from typing import Any, Optional, Tuple +from typing import Any import pytest from .config import ApiCoverageReportConfig, get_pytest_api_cov_report_config +from .frameworks import get_framework_adapter, is_supported_framework from .models import SessionData from .openapi import parse_openapi_spec from .pytest_flags import add_pytest_api_cov_flags @@ -37,8 +38,6 @@ def _discover_app_endpoints(app: Any, coverage_data: SessionData, fixture_name: return try: - from .frameworks import get_framework_adapter - adapter = get_framework_adapter(app) endpoints = adapter.get_endpoints() framework_name = type(app).__name__ @@ -52,37 +51,18 @@ def _discover_app_endpoints(app: Any, coverage_data: SessionData, fixture_name: logger.warning(f"> Failed to discover endpoints from app: {e}") -def is_supported_framework(app: Any) -> bool: - """Check if the app is a supported framework (Flask or FastAPI).""" - if app is None: - return False - - app_type = type(app).__name__ - module_name = getattr(type(app), "__module__", "").split(".")[0] - - return ( - (module_name == "flask" and app_type == "Flask") - or (module_name == "flask_openapi3" and app_type == "OpenAPI") - or (module_name == "fastapi" and app_type == "FastAPI") - or (module_name == "django.core.handlers.wsgi" and app_type == "WSGIHandler") - or (module_name == "django" or "django" in module_name) - ) - - -def extract_app_from_client(client: Any) -> Optional[Any]: - """Extract app from various client types.""" - # Typical attributes used by popular clients +def extract_app_from_client(client: Any) -> Any | None: + """Extract app from various test client types.""" if client is None: return None - # common attribute for requests-like test clients if hasattr(client, "app"): return client.app if hasattr(client, "application"): return client.application - # Starlette/requests transport internals + # Starlette/httpx transport internals if hasattr(client, "_transport") and hasattr(client._transport, "app"): return client._transport.app @@ -102,22 +82,14 @@ def pytest_configure(config: pytest.Config) -> None: if config.getoption("--api-cov-report"): verbosity = config.option.verbose - if verbosity >= 2: # -vv or more + if verbosity >= 2: log_level = logging.DEBUG - elif verbosity >= 1: # -v + elif verbosity >= 1: log_level = logging.INFO else: log_level = logging.WARNING logger.setLevel(log_level) - - # if not logger.handlers: - # handler = logging.StreamHandler() - # handler.setLevel(log_level) - # formatter = logging.Formatter("%(message)s") - # handler.setFormatter(formatter) - # logger.addHandler(handler) - logger.info("Initializing API coverage plugin...") if config.pluginmanager.hasplugin("xdist"): @@ -125,40 +97,25 @@ def pytest_configure(config: pytest.Config) -> None: def pytest_sessionstart(session: pytest.Session) -> None: - """Initialize the call recorder at the start of the session.""" + """Initialize coverage data at session start.""" if session.config.getoption("--api-cov-report"): session.api_coverage_data = SessionData() # type: ignore[attr-defined] -def create_coverage_fixture(fixture_name: str, existing_fixture_name: Optional[str] = None) -> Any: +def create_coverage_fixture(fixture_name: str, existing_fixture_name: str | None = None) -> Any: """Create a coverage-enabled fixture with a custom name. - Args: - fixture_name: The name for the new fixture - existing_fixture_name: Optional name of existing fixture to wrap + Example:: - Returns: - A pytest fixture function that can be used in conftest.py - - Example usage in conftest.py: - import pytest - from pytest_api_cov.plugin import create_coverage_fixture - - # Create a new fixture my_client = create_coverage_fixture('my_client') - - # Wrap an existing fixture flask_client = create_coverage_fixture('flask_client', 'original_flask_client') - """ def fixture_func(request: pytest.FixtureRequest) -> Any: """Coverage-enabled client fixture.""" session = request.node.session - # Do not skip tests; if coverage is disabled or not initialized, try to return an existing client coverage_enabled = bool(session.config.getoption("--api-cov-report")) - coverage_data = getattr(session, "api_coverage_data", None) # Try to obtain an existing client if requested @@ -170,12 +127,11 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: except pytest.FixtureLookupError: logger.warning(f"> Existing fixture '{existing_fixture_name}' not found when creating '{fixture_name}'") - # If coverage is not enabled or recorder not available, return existing client (if any) + # Without coverage, just pass through the existing client if not coverage_enabled or coverage_data is None: if existing_client is not None: yield existing_client return - # Try to fall back to an app fixture to construct a client try: app = request.getfixturevalue("app") except pytest.FixtureLookupError: @@ -184,10 +140,7 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: ) yield None return - # if we have an app, attempt to create a tracked client using adapter without recorder try: - from .frameworks import get_framework_adapter - adapter = get_framework_adapter(app) client = adapter.get_tracked_client(None, request.node.name) except Exception: # noqa: BLE001 @@ -197,14 +150,10 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: yield client return - # At this point coverage is enabled and coverage_data exists config = get_pytest_api_cov_report_config(request.config) - - # Check for OpenAPI spec first _discover_openapi_endpoints(config, coverage_data) if existing_client is None: - # Try to find a client fixture by common names for name in config.client_fixture_names: try: existing_client = request.getfixturevalue(name) @@ -223,10 +172,8 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: except pytest.FixtureLookupError: app = None - # Discover endpoints from app if not already discovered _discover_app_endpoints(app, coverage_data, fixture_name) - # If we have an existing client, wrap it; otherwise try to create a tracked client from app if existing_client is not None: wrapped = wrap_client_with_coverage(existing_client, coverage_data.recorder, request.node.name) yield wrapped @@ -234,8 +181,6 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: if app is not None: try: - from .frameworks import get_framework_adapter - adapter = get_framework_adapter(app) client = adapter.get_tracked_client(coverage_data.recorder, request.node.name) except Exception as e: # noqa: BLE001 @@ -244,7 +189,7 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: yield client return - # Last resort: yield None but do not skip + # Last resort — yield None but don't skip, so tests still run logger.warning( f"> create_coverage_fixture('{fixture_name}') could not provide a client; " "tests will run without API coverage for this fixture." @@ -264,41 +209,27 @@ class CoverageWrapper: def __init__(self, wrapped_client: Any) -> None: self._wrapped = wrapped_client - def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> Optional[Tuple[str, str]]: - # Try several strategies to obtain a path and method - path = None - method = None - - # First, if args[0] looks like a string path + def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[str, str] | None: + """Pull path and HTTP method from the call arguments.""" if args: first = args[0] if isinstance(first, str): path = first.partition("?")[0] method = kwargs.get("method", name).upper() - if method == "OPEN": - method = "GET" + return path, ("GET" if method == "OPEN" else method) - return path, method - - # For starlette/requests TestClient, args[0] may be a Request or PreparedRequest if hasattr(first, "url") and hasattr(first.url, "path"): try: - path = first.url.path - method = getattr(first, "method", name).upper() + return first.url.path, getattr(first, "method", name).upper() except Exception: # noqa: BLE001 pass - else: - return path, method if kwargs: path_kw = kwargs.get("path") or kwargs.get("url") or kwargs.get("uri") if isinstance(path_kw, str): path = path_kw.partition("?")[0] method = kwargs.get("method", name).upper() - if method == "OPEN": - method = "GET" - - return path, method + return path, ("GET" if method == "OPEN" else method) return None @@ -335,128 +266,96 @@ def tracked_open(*args: Any, **kwargs: Any) -> Any: return CoverageWrapper(client) -@pytest.fixture -def coverage_client(request: pytest.FixtureRequest) -> Any: - """Smart client fixture that wrap's user's existing test client with coverage tracking.""" +def _coverage_client_impl(request: pytest.FixtureRequest) -> Any: + """Inner generator shared by coverage_client and create_coverage_fixture.""" session = request.node.session - if not session.config.getoption("--api-cov-report"): - pytest.skip("API coverage not enabled. Use --api-cov-report flag.") + coverage_enabled = bool(session.config.getoption("--api-cov-report")) + coverage_data = getattr(session, "api_coverage_data", None) + + if not coverage_enabled or coverage_data is None: + # Try common client fixture names then app fixture + for name in ("client", "test_client", "api_client", "app_client"): + try: + yield request.getfixturevalue(name) + return + except pytest.FixtureLookupError: + continue + try: + app = request.getfixturevalue("app") + adapter = get_framework_adapter(app) + yield adapter.get_tracked_client(None, request.node.name) + except (pytest.FixtureLookupError, Exception): # noqa: BLE001 + yield None + return config = get_pytest_api_cov_report_config(request.config) - coverage_data = getattr(session, "api_coverage_data", None) - if coverage_data is None: - pytest.skip("API coverage data not initialized. This should not happen.") - - # Check for OpenAPI spec first - if config.openapi_spec and not coverage_data.discovered_endpoints.endpoints: - endpoints = parse_openapi_spec(config.openapi_spec) - if endpoints: - for endpoint_method in endpoints: - method, path = endpoint_method.split(" ", 1) - coverage_data.add_discovered_endpoint(path, method, "openapi_spec") - logger.info( - f"> pytest-api-coverage: Discovered {len(endpoints)} endpoints from OpenAPI spec: {config.openapi_spec}" - ) - else: - logger.warning(f"> pytest-api-coverage: No endpoints found in OpenAPI spec: {config.openapi_spec}") + _discover_openapi_endpoints(config, coverage_data) + # Find a client fixture client = None - for fixture_name in config.client_fixture_names: + for name in config.client_fixture_names: try: - client = request.getfixturevalue(fixture_name) - logger.info(f"> Found custom fixture '{fixture_name}', wrapping with coverage tracking") + client = request.getfixturevalue(name) + logger.info(f"> Found client fixture '{name}'") break except pytest.FixtureLookupError: - logger.debug(f"> Custom fixture '{fixture_name}' not found") continue - if client is None: - # Try to fallback to an 'app' fixture and create a tracked client + app = extract_app_from_client(client) if client else None + if app is None: try: app = request.getfixturevalue("app") - logger.info("> Found 'app' fixture, creating tracked client from app") - from .frameworks import get_framework_adapter - - adapter = get_framework_adapter(app) - client = adapter.get_tracked_client(coverage_data.recorder, request.node.name) except pytest.FixtureLookupError: - logger.warning("> No test client fixture found and no 'app' fixture available. Falling back to None") - client = None - except Exception as e: # noqa: BLE001 - logger.warning(f"> Failed to create tracked client from 'app' fixture: {e}") - client = None - - if client is None: - logger.warning("> Coverage client could not be created; tests will run without API coverage for this session.") - return None - - app = extract_app_from_client(client) - logger.debug(f"> Extracted app from client: {app}, app type: {type(app).__name__ if app else None}") - - if app is None: - logger.warning("> No app found, returning client without coverage tracking") - return client + app = None - if not is_supported_framework(app): - logger.warning( - f"> Unsupported framework: {type(app).__name__}. pytest-api-coverage supports Flask, FastAPI, and Django." - ) - return client - - try: - from .frameworks import get_framework_adapter + _discover_app_endpoints(app, coverage_data, "coverage_client") - adapter = get_framework_adapter(app) - logger.debug(f"> Got adapter: {adapter}, adapter type: {type(adapter).__name__ if adapter else None}") - except TypeError as e: - logger.warning(f"> Framework detection failed: {e}") - return client + if client is not None: + yield wrap_client_with_coverage(client, coverage_data.recorder, request.node.name) + return - if not coverage_data.discovered_endpoints.endpoints: + if app is not None: try: - endpoints = adapter.get_endpoints() - logger.debug(f"> Adapter returned {len(endpoints)} endpoints") - framework_name = type(app).__name__ - for endpoint_method in endpoints: - method, path = endpoint_method.split(" ", 1) - coverage_data.add_discovered_endpoint(path, method, f"{framework_name.lower()}_adapter") - logger.info(f"> pytest-api-coverage: Discovered {len(endpoints)} endpoints.") - logger.debug(f"> Discovered endpoints: {endpoints}") + adapter = get_framework_adapter(app) + yield adapter.get_tracked_client(coverage_data.recorder, request.node.name) + return except Exception as e: # noqa: BLE001 - logger.warning(f"> pytest-api-coverage: Could not discover endpoints. Error: {e}") - return client + logger.warning(f"> Failed to create tracked client: {e}") - return wrap_client_with_coverage(client, coverage_data.recorder, request.node.name) + logger.warning("> coverage_client could not provide a client; tests will run without API coverage.") + yield None + + +@pytest.fixture +def coverage_client(request: pytest.FixtureRequest) -> Any: + """Smart client fixture that wraps the user's test client with coverage tracking.""" + yield from _coverage_client_impl(request) def pytest_sessionfinish(session: pytest.Session) -> None: - """Generate the API coverage report at the end of the session.""" + """Generate the API coverage report at session end.""" if session.config.getoption("--api-cov-report"): coverage_data = getattr(session, "api_coverage_data", None) if coverage_data is None: logger.warning("> No API coverage data found. Plugin may not have been properly initialized.") return - logger.debug(f"> pytest-api-coverage: Generating report for {len(coverage_data.recorder)} recorded endpoints.") + logger.debug(f"> Generating report for {len(coverage_data.recorder)} recorded endpoints.") if hasattr(session.config, "workeroutput"): serializable_recorder = coverage_data.recorder.to_serializable() session.config.workeroutput["api_call_recorder"] = serializable_recorder session.config.workeroutput["discovered_endpoints"] = coverage_data.discovered_endpoints.endpoints logger.debug("> Sent API call data and discovered endpoints to master process") else: - logger.debug("> No workeroutput found, generating report for master data.") - worker_recorder_data = getattr(session.config, "worker_api_call_recorder", {}) worker_endpoints = getattr(session.config, "worker_discovered_endpoints", []) - # Merge worker data into session data if worker_recorder_data or worker_endpoints: coverage_data.merge_worker_data(worker_recorder_data, worker_endpoints) logger.debug(f"> Merged worker data: {len(worker_recorder_data)} endpoints") logger.debug(f"> Final merged data: {len(coverage_data.recorder)} recorded endpoints") - logger.debug(f"> Using discovered endpoints: {coverage_data.discovered_endpoints.endpoints}") api_cov_config = get_pytest_api_cov_report_config(session.config) status = generate_pytest_api_cov_report( @@ -475,33 +374,21 @@ def pytest_sessionfinish(session: pytest.Session) -> None: class DeferXdistPlugin: - """Simple class to defer pytest-xdist hook until we know it is installed.""" + """Defers pytest-xdist hook until we know it is installed.""" def pytest_testnodedown(self, node: Any) -> None: """Collect API call data from each worker as they finish.""" - logger.debug("> pytest-api-coverage: Worker node down.") + logger.debug("> Worker node down.") worker_data = node.workeroutput.get("api_call_recorder", {}) discovered_endpoints = node.workeroutput.get("discovered_endpoints", []) - logger.debug(f"> Worker data: {worker_data}") - logger.debug(f"> Worker discovered endpoints: {discovered_endpoints}") - # Merge API call data if worker_data: - logger.debug("> Worker data found, merging with current data.") current = getattr(node.config, "worker_api_call_recorder", {}) - logger.debug(f"> Current data before merge: {current}") - # Merge the worker data into current for endpoint, calls in worker_data.items(): - if endpoint not in current: - current[endpoint] = set() - elif not isinstance(current[endpoint], set): - current[endpoint] = set(current[endpoint]) - current[endpoint].update(calls) - logger.debug(f"> Updated endpoint {endpoint} with calls: {calls}") + current.setdefault(endpoint, set()).update(calls) node.config.worker_api_call_recorder = current - logger.debug(f"> Updated current data: {current}") if discovered_endpoints and not getattr(node.config, "worker_discovered_endpoints", []): node.config.worker_discovered_endpoints = discovered_endpoints diff --git a/src/pytest_api_cov/report.py b/src/pytest_api_cov/report.py index 1cdae58..eb12004 100644 --- a/src/pytest_api_cov/report.py +++ b/src/pytest_api_cov/report.py @@ -1,10 +1,12 @@ """API coverage report generation.""" +from __future__ import annotations + import json import re from pathlib import Path from re import Pattern -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any from rich.console import Console @@ -19,48 +21,41 @@ def endpoint_to_regex(endpoint: str) -> Pattern[str]: def contains_escape_characters(endpoint: str) -> bool: - """Escape special characters in the endpoint string.""" + """Check whether an endpoint contains dynamic path segments.""" return ("<" in endpoint and ">" in endpoint) or ("{" in endpoint and "}" in endpoint) def categorise_endpoints( - endpoints: List[str], - called_data: Dict[str, Set[str]], - exclusion_patterns: List[str], -) -> Tuple[List[str], List[str], List[str]]: + endpoints: list[str], + called_data: dict[str, set[str]], + exclusion_patterns: list[str], +) -> tuple[list[str], list[str], list[str]]: """Categorise endpoints into covered, uncovered, and excluded. - Exclusion patterns support simple wildcard matching with negation and optional - HTTP method prefixes: - - Use * for wildcard (matches any characters) - - Use ! at the start to negate a pattern (include what would otherwise be excluded) - - Optionally prefix a pattern with one or more HTTP methods to target only those methods, - e.g. "GET /health" or "GET,POST /users/*" (methods are case-insensitive) - - All other characters are matched literally - - Examples: "/admin/*", "/health", "!users/bob", "GET /health", "GET,POST /users/*" - - Pattern order matters: exclusions are applied first, then negations override them + Exclusion patterns support wildcard matching with negation and optional + HTTP method prefixes. Pattern order matters: exclusions first, then + negations override them. """ - covered, uncovered, excluded = [], [], [] + covered: list[str] = [] + uncovered: list[str] = [] + excluded: list[str] = [] if not exclusion_patterns: compiled_exclusions = None compiled_negations = None else: - # Separate exclusion and negation patterns exclusion_only = [p for p in exclusion_patterns if not p.startswith("!")] - negation_only = [p[1:] for p in exclusion_patterns if p.startswith("!")] # Remove the '!' prefix + negation_only = [p[1:] for p in exclusion_patterns if p.startswith("!")] - def compile_patterns(patterns: List[str]) -> List[Tuple[Optional[Set[str]], Pattern[str]]]: - compiled: List[Tuple[Optional[Set[str]], Pattern[str]]] = [] + def compile_patterns(patterns: list[str]) -> list[tuple[set[str] | None, Pattern[str]]]: + compiled: list[tuple[set[str] | None, Pattern[str]]] = [] for pat in patterns: path_pattern = pat.strip() - methods: Optional[Set[str]] = None - # Detect method prefix + methods: set[str] | None = None m = re.match(r"^([A-Za-z,]+)\s+(.+)$", pat) if m: methods = {mname.strip().upper() for mname in m.group(1).split(",") if mname.strip()} path_pattern = m.group(2) - # Build regex from the path part regex = re.compile("^" + re.escape(path_pattern).replace(r"\*", ".*") + "$") compiled.append((methods, regex)) return compiled @@ -69,7 +64,6 @@ def compile_patterns(patterns: List[str]) -> List[Tuple[Optional[Set[str]], Patt compiled_negations = compile_patterns(negation_only) if negation_only else None for endpoint in endpoints: - # Check exclusion patterns against both full "METHOD /path" and just "/path" is_excluded = False endpoint_method = None path_only = endpoint @@ -80,37 +74,27 @@ def compile_patterns(patterns: List[str]) -> List[Tuple[Optional[Set[str]], Patt if compiled_exclusions: for methods_set, regex in compiled_exclusions: if methods_set: - if not endpoint_method: - continue - if endpoint_method not in methods_set: + if not endpoint_method or endpoint_method not in methods_set: continue if regex.match(path_only) or regex.match(endpoint): is_excluded = True break - # No methods specified elif regex.match(path_only) or regex.match(endpoint): is_excluded = True break - # Negation patterns if is_excluded and compiled_negations: - is_negated = False for methods_set, regex in compiled_negations: if methods_set: - if not endpoint_method: - continue - if endpoint_method not in methods_set: + if not endpoint_method or endpoint_method not in methods_set: continue if regex.match(path_only) or regex.match(endpoint): - is_negated = True + is_excluded = False break elif regex.match(path_only) or regex.match(endpoint): - is_negated = True + is_excluded = False break - if is_negated: - is_excluded = False # Negation overrides exclusion - if is_excluded: excluded.append(endpoint) continue @@ -126,21 +110,18 @@ def compile_patterns(patterns: List[str]) -> List[Tuple[Optional[Set[str]], Patt def print_endpoints( console: Console, label: str, - endpoints: List[str], + endpoints: list[str], symbol: str, style: str, ) -> None: - """Print a list of endpoints to the console with a label and style.""" + """Print a list of endpoints to the console.""" if endpoints: console.print(f"[{style}]{label}[/]:") for endpoint in endpoints: - # Format endpoint with consistent spacing for HTTP methods if " " in endpoint: method, path = endpoint.split(" ", 1) - # Pad method to 6 characters (longest common method is DELETE) formatted_endpoint = f"{method:<6} {path}" else: - # Handle legacy format without method formatted_endpoint = endpoint console.print(f" {symbol}\t[{style}]{formatted_endpoint}[/]") @@ -151,13 +132,13 @@ def compute_coverage(covered_count: int, uncovered_count: int) -> float: return round(100 * covered_count / total, 2) if total > 0 else 0.0 -def prepare_endpoint_detail(endpoints: List[str], called_data: Dict[str, Set[str]]) -> List[Dict[str, Any]]: - """Prepare endpoint details by mapping each endpoint to its callers.""" +def prepare_endpoint_detail(endpoints: list[str], called_data: dict[str, set[str]]) -> list[dict[str, Any]]: + """Map each endpoint to its callers for JSON report output.""" details = [] for endpoint in endpoints: if contains_escape_characters(endpoint): pattern = endpoint_to_regex(endpoint) - callers = set() + callers: set[str] = set() for call, call_set in called_data.items(): if pattern.match(call): callers.update(call_set) @@ -167,7 +148,7 @@ def prepare_endpoint_detail(endpoints: List[str], called_data: Dict[str, Set[str return sorted(details, key=lambda x: len(x["callers"])) -def write_report_file(report_data: Dict[str, Any], report_path: str) -> None: +def write_report_file(report_data: dict[str, Any], report_path: str) -> None: """Write the report data to a JSON file.""" path = Path(report_path).resolve() path.parent.mkdir(parents=True, exist_ok=True) @@ -177,8 +158,8 @@ def write_report_file(report_data: Dict[str, Any], report_path: str) -> None: def generate_pytest_api_cov_report( api_cov_config: ApiCoverageReportConfig, - called_data: Dict[str, Set[str]], - discovered_endpoints: List[str], + called_data: dict[str, set[str]], + discovered_endpoints: list[str], ) -> int: """Generate and print the API coverage report, returning an exit status.""" console = Console() diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 5cad875..34abdd0 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -1,7 +1,6 @@ """Unit tests for pytest-api-cov CLI module.""" -from pathlib import Path -from unittest.mock import Mock, mock_open, patch +from unittest.mock import patch import pytest @@ -13,32 +12,30 @@ class TestGenerateConftestContent: - """Tests for generate_conftest_content function.""" + """Tests for generate_conftest_content.""" def test_fastapi_conftest(self): - """Test generating conftest for FastAPI.""" + """Generate conftest for FastAPI.""" content = generate_conftest_content("FastAPI", "app.py", "app") assert "import pytest" in content assert "from fastapi.testclient import TestClient" in content assert "from app import app" in content assert "def client():" in content - assert "The pytest-api-cov plugin can extract the app from your client fixture" in content assert "return TestClient(app)" in content def test_flask_conftest(self): - """Test generating conftest for Flask.""" + """Generate conftest for Flask.""" content = generate_conftest_content("Flask", "main.py", "application") assert "import pytest" in content assert "from flask.testing import FlaskClient" in content assert "from main import application" in content assert "def client():" in content - assert "The pytest-api-cov plugin can extract the app from your client fixture" in content assert "return FlaskClient(app)" in content def test_subdirectory_conftest(self): - """Test generating conftest for app in subdirectory.""" + """Generate conftest for app in subdirectory.""" content = generate_conftest_content("FastAPI", "src/main.py", "app") assert "import pytest" in content @@ -47,7 +44,7 @@ def test_subdirectory_conftest(self): assert "return TestClient(app)" in content def test_nested_subdirectory_conftest(self): - """Test generating conftest for app in nested subdirectory.""" + """Generate conftest for app in nested subdirectory.""" content = generate_conftest_content("Flask", "example/src/main.py", "app") assert "import pytest" in content @@ -57,10 +54,10 @@ def test_nested_subdirectory_conftest(self): class TestGeneratePyprojectConfig: - """Tests for generate_pyproject_config function.""" + """Tests for generate_pyproject_config.""" def test_pyproject_config_structure(self): - """Test structure of generated pyproject config.""" + """Verify generated pyproject config contains expected sections.""" config = generate_pyproject_config() assert "[tool.pytest_api_cov]" in config @@ -74,10 +71,10 @@ def test_pyproject_config_structure(self): class TestMain: - """Tests for main function.""" + """Tests for main CLI entry point.""" def test_main_show_pyproject(self, monkeypatch): - """Test main prints pyproject snippet for show-pyproject.""" + """show-pyproject prints config snippet.""" monkeypatch.setattr("sys.argv", ["pytest-api-cov", "show-pyproject"]) with patch("builtins.print") as mock_print: result = main() @@ -85,7 +82,7 @@ def test_main_show_pyproject(self, monkeypatch): mock_print.assert_called() def test_main_show_conftest(self, monkeypatch): - """Test main prints conftest snippet for show-conftest.""" + """show-conftest prints conftest snippet.""" monkeypatch.setattr("sys.argv", ["pytest-api-cov", "show-conftest", "FastAPI", "src.main", "app"]) with patch("builtins.print") as mock_print: result = main() @@ -93,18 +90,18 @@ def test_main_show_conftest(self, monkeypatch): mock_print.assert_called() def test_main_no_command(self, monkeypatch): - """Test main with no command (should show help).""" + """No command returns exit code 1.""" monkeypatch.setattr("sys.argv", ["pytest-api-cov"]) result = main() assert result == 1 def test_main_unknown_command(self): - """Test main with unknown command.""" + """Unknown command exits with code 2.""" with pytest.raises(SystemExit) as exc_info: monkeypatch = pytest.MonkeyPatch() try: - monkeypatch.setenv("DUMMY", "1") # noop to obtain monkeypatch object + monkeypatch.setenv("DUMMY", "1") monkeypatch.setattr("sys.argv", ["pytest-api-cov", "unknown"]) main() finally: diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index dce520d..b7bbdb6 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,11 +1,11 @@ """Tests for configuration module.""" import os +from pathlib import Path from unittest.mock import Mock, patch import pytest import tomli -from pathlib import Path from pydantic import ValidationError from pytest_api_cov.config import ( @@ -21,7 +21,7 @@ class TestConfigLoading: """Tests for loading configuration from different sources.""" def test_read_toml_config_success(self, tmp_path): - """Verify reading a valid pyproject.toml.""" + """Read a valid pyproject.toml.""" pyproject_content = """ [tool.pytest_api_cov] fail_under = 95.5 @@ -41,19 +41,19 @@ def test_read_toml_config_success(self, tmp_path): os.chdir(original_cwd) def test_read_toml_config_file_not_found(self): - """Ensure it returns an empty dict if pyproject.toml is missing.""" + """Return empty dict when pyproject.toml is missing.""" with patch("pathlib.Path.open", side_effect=FileNotFoundError): config = read_toml_config() assert config == {} def test_read_toml_config_toml_decode_error(self): - """Ensure it returns an empty dict if pyproject.toml has syntax errors.""" + """Return empty dict on TOML syntax errors.""" with patch("pathlib.Path.open", side_effect=tomli.TOMLDecodeError("Invalid TOML", "", 0)): config = read_toml_config() assert config == {} def test_read_toml_config_missing_section(self, tmp_path): - """Ensure it returns an empty dict if the [tool.pytest_api_cov] section is missing.""" + """Return empty dict when [tool.pytest_api_cov] section is absent.""" (tmp_path / "pyproject.toml").write_text("[project]\nname = 'test'") original_cwd = Path.cwd() @@ -65,7 +65,7 @@ def test_read_toml_config_missing_section(self, tmp_path): os.chdir(original_cwd) def test_read_session_config(self): - """Verify reading config from pytest's session object (CLI flags).""" + """Read config from pytest CLI flags.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-fail-under": 80.0, @@ -80,7 +80,7 @@ def test_read_session_config(self): assert "show_excluded_endpoints" not in config def test_read_session_config_with_false_values(self): - """Test that False values are not included in config.""" + """False values are not included in config.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-show-covered-endpoints": False, @@ -92,7 +92,7 @@ def test_read_session_config_with_false_values(self): assert "exclusion_patterns" not in config def test_read_session_config_with_none_values(self): - """Test that None values are not included in config.""" + """None values are not included in config.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-fail-under": None, @@ -102,33 +102,31 @@ def test_read_session_config_with_none_values(self): assert "fail_under" not in config @pytest.mark.parametrize( - ("is_tty", "encoding", "stdout_bool", "expected"), + ("is_tty", "encoding", "expected"), [ - (False, "utf-8", True, False), - (True, "utf-8", True, True), - (True, "UTF8", True, True), - (True, "ascii", True, False), - (True, "utf-8", False, False), + (False, "utf-8", False), + (True, "utf-8", True), + (True, "UTF8", True), + (True, "ascii", False), ], ) - def test_supports_unicode(self, is_tty, encoding, stdout_bool, expected): - """Test supports_unicode with various configurations.""" + def test_supports_unicode(self, is_tty, encoding, expected): + """Check unicode support detection.""" mock_stdout = Mock() mock_stdout.isatty.return_value = is_tty mock_stdout.encoding = encoding - mock_stdout.__bool__ = Mock(return_value=stdout_bool) with patch("sys.stdout", mock_stdout): assert supports_unicode() == expected class TestConfigMerging: - """Tests the merging logic of different config sources.""" + """Tests for config merging logic.""" @patch("pytest_api_cov.config.read_session_config") @patch("pytest_api_cov.config.read_toml_config") def test_config_priority_cli_over_toml(self, mock_read_toml, mock_read_session): - """Ensure CLI arguments override pyproject.toml settings.""" + """CLI arguments override pyproject.toml settings.""" mock_read_toml.return_value = {"fail_under": 90.0, "report_path": "toml.json"} mock_read_session.return_value = {"fail_under": 75.0} @@ -142,7 +140,7 @@ def test_config_priority_cli_over_toml(self, mock_read_toml, mock_read_session): @patch("pytest_api_cov.config.read_session_config", return_value={}) @patch("pytest_api_cov.config.read_toml_config") def test_pydantic_model_validation(self, mock_read_toml, mock_read_session): - """Test that the Pydantic model correctly validates and sets defaults.""" + """Pydantic model validates and sets defaults correctly.""" mock_read_toml.return_value = {"fail_under": 90.0} final_config = get_pytest_api_cov_report_config(Mock()) @@ -156,7 +154,7 @@ def test_pydantic_model_validation(self, mock_read_toml, mock_read_session): @patch("pytest_api_cov.config.read_toml_config") @patch("pytest_api_cov.config.supports_unicode") def test_force_sugar_setting(self, mock_supports_unicode, mock_read_toml, mock_read_session): - """Test force_sugar setting logic.""" + """force_sugar respects disabled flag and unicode support.""" mock_supports_unicode.return_value = True mock_read_toml.return_value = {} @@ -173,12 +171,12 @@ def test_force_sugar_setting(self, mock_supports_unicode, mock_read_toml, mock_r assert config.force_sugar is False def test_pydantic_validation_error(self): - """Ensure invalid types raise a validation error.""" + """Invalid types raise ValidationError.""" with pytest.raises(ValidationError): ApiCoverageReportConfig.model_validate({"fail_under": "not-a-float"}) def test_read_session_config_empty_options(self): - """Test read_session_config with no options set.""" + """No options set returns empty dict.""" mock_session_config = Mock() mock_session_config.getoption.return_value = None @@ -186,7 +184,7 @@ def test_read_session_config_empty_options(self): assert config == {} def test_read_session_config_with_empty_list(self): - """Test read_session_config with empty list value.""" + """Empty list value is treated as unset.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-exclusion-patterns": [], @@ -196,7 +194,7 @@ def test_read_session_config_with_empty_list(self): assert "exclusion_patterns" not in config def test_read_session_config_with_false_boolean(self): - """Test read_session_config with False boolean value.""" + """False boolean value is treated as unset.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-show-covered-endpoints": False, @@ -206,7 +204,7 @@ def test_read_session_config_with_false_boolean(self): assert "show_covered_endpoints" not in config def test_read_session_config_with_none_value(self): - """Test read_session_config with None value.""" + """None value is treated as unset.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-fail-under": None, @@ -216,7 +214,7 @@ def test_read_session_config_with_none_value(self): assert "fail_under" not in config def test_read_session_config_with_openapi_spec(self): - """Test read_session_config with openapi_spec.""" + """openapi_spec is read from CLI flags.""" mock_session_config = Mock() mock_session_config.getoption.side_effect = lambda name: { "--api-cov-openapi-spec": "openapi.json", @@ -226,7 +224,7 @@ def test_read_session_config_with_openapi_spec(self): assert config["openapi_spec"] == "openapi.json" def test_read_toml_config_with_openapi_spec(self, tmp_path): - """Verify reading openapi_spec from pyproject.toml.""" + """openapi_spec is read from pyproject.toml.""" pyproject_content = """ [tool.pytest_api_cov] openapi_spec = "openapi.yaml" diff --git a/tests/unit/test_create_coverage_fixture.py b/tests/unit/test_create_coverage_fixture.py index f65cd77..c668d75 100644 --- a/tests/unit/test_create_coverage_fixture.py +++ b/tests/unit/test_create_coverage_fixture.py @@ -4,32 +4,30 @@ class TestCreateCoverageFixture: - """Tests for the create_coverage_fixture helper function.""" + """Tests for create_coverage_fixture.""" def test_create_coverage_fixture_returns_callable(self): - """Test that create_coverage_fixture returns a callable fixture function.""" + """Returns a callable fixture function.""" fixture_func = create_coverage_fixture("test_client") assert callable(fixture_func) assert fixture_func.__name__ == "test_client" def test_create_coverage_fixture_with_existing_fixture_name(self): - """Test create_coverage_fixture with existing fixture name parameter.""" + """Accepts an existing fixture name parameter.""" fixture_func = create_coverage_fixture("my_client", "existing_client") assert callable(fixture_func) assert fixture_func.__name__ == "my_client" def test_create_coverage_fixture_is_pytest_fixture(self): - """Test that the created function is a pytest fixture.""" + """Result is a pytest fixture.""" fixture_func = create_coverage_fixture("test_client") - # Should be a pytest fixture (has _pytestfixturefunction or similar attribute) assert hasattr(fixture_func, "_pytestfixturefunction") or hasattr(fixture_func, "_fixture_function") def test_create_coverage_fixture_preserves_name(self): - """Test that the fixture function preserves the specified name.""" + """Fixture name is preserved.""" fixture_func = create_coverage_fixture("custom_name") - # The function name should be preserved assert fixture_func.__name__ == "custom_name" diff --git a/tests/unit/test_frameworks.py b/tests/unit/test_frameworks.py index 3ca6b4b..8b1a07d 100644 --- a/tests/unit/test_frameworks.py +++ b/tests/unit/test_frameworks.py @@ -4,7 +4,7 @@ import pytest -from pytest_api_cov.frameworks import FastAPIAdapter, FlaskAdapter, get_framework_adapter +from pytest_api_cov.frameworks import BaseAdapter, FastAPIAdapter, FlaskAdapter, get_framework_adapter class MockFlaskRule: @@ -25,7 +25,7 @@ def __init__(self, path, methods=None): class TestFlaskAdapter: - """Tests for the Flask framework adapter.""" + """Tests for the Flask adapter.""" def setup_method(self): """Set up test fixtures.""" @@ -36,18 +36,18 @@ def setup_method(self): self.adapter = FlaskAdapter(self.mock_app) def test_flask_get_endpoints(self): - """Verify endpoint discovery for Flask.""" + """Discover Flask endpoints.""" endpoints = self.adapter.get_endpoints() expected = ["GET /", "GET /users/", "POST /", "POST /users/"] assert sorted(endpoints) == sorted(expected) def test_flask_get_tracked_client_no_recorder(self): - """Test that get_tracked_client returns normal client when recorder is None.""" + """No recorder returns a normal test client.""" client = self.adapter.get_tracked_client(None, "test_name") assert client == self.mock_app.test_client() def test_flask_get_tracked_client_with_recorder(self): - """Test that get_tracked_client returns tracking client when recorder is provided.""" + """Recorder returns a TrackingFlaskClient.""" self.mock_app.response_class = type("MockResponse", (), {}) recorder = {} @@ -56,7 +56,7 @@ def test_flask_get_tracked_client_with_recorder(self): assert "TrackingFlaskClient" in str(type(client)) def test_flask_tracking_client_open_method(self): - """Test the TrackingFlaskClient open method.""" + """TrackingFlaskClient.open forwards calls.""" with patch("flask.testing.FlaskClient"): recorder = {} client = self.adapter.get_tracked_client(recorder, "test_name") @@ -70,9 +70,7 @@ def test_flask_tracking_client_open_method(self): assert response == "response" def test_flask_tracking_client_exception_handling(self): - """Test exception handling in Flask tracking client.""" - from unittest.mock import Mock, patch - + """Exceptions during URL matching are silently caught.""" self.mock_app.response_class = type("MockResponse", (), {}) recorder = {} @@ -89,9 +87,10 @@ def test_flask_tracking_client_exception_handling(self): class TestFastAPIAdapter: - """Tests for the FastAPI framework adapter.""" + """Tests for the FastAPI adapter.""" def setup_method(self): + """Set up test fixtures.""" self.mock_app = Mock() self.mock_app.__module__ = "fastapi" type(self.mock_app).__name__ = "FastAPI" @@ -106,27 +105,27 @@ def setup_method(self): self.adapter = FastAPIAdapter(self.mock_app) def test_fastapi_get_endpoints(self): - """Verify endpoint discovery for FastAPI.""" + """Discover FastAPI endpoints.""" with patch("fastapi.routing.APIRoute", MockFastAPIRoute): endpoints = self.adapter.get_endpoints() expected = ["GET /", "GET /items/{item_id}", "POST /", "POST /items/{item_id}"] assert sorted(endpoints) == sorted(expected) def test_fastapi_get_tracked_client_no_recorder(self): - """Test that get_tracked_client returns normal client when recorder is None.""" + """No recorder returns a normal TestClient.""" with patch("starlette.testclient.TestClient") as MockTestClient: self.adapter.get_tracked_client(None, "test_name") MockTestClient.assert_called_once_with(self.mock_app) def test_fastapi_get_tracked_client_with_recorder(self): - """Test that get_tracked_client returns tracking client when recorder is provided.""" + """Recorder returns a TrackingFastAPIClient.""" recorder = {} client = self.adapter.get_tracked_client(recorder, "test_name") assert hasattr(client, "send") assert "TrackingFastAPIClient" in str(type(client)) def test_fastapi_tracking_client_send_method(self): - """Test the TrackingFastAPIClient send method exists and can be called.""" + """TrackingFastAPIClient.send records calls and forwards.""" from pytest_api_cov.models import ApiCallRecorder recorder = ApiCallRecorder() @@ -147,43 +146,43 @@ def test_fastapi_tracking_client_send_method(self): class TestBaseAdapter: - """Test the base adapter class.""" + """Tests for the abstract base adapter.""" - def test_base_adapter_get_endpoints_not_implemented(self): - """Test that BaseAdapter.get_endpoints raises NotImplementedError.""" - from pytest_api_cov.frameworks import BaseAdapter + def test_base_adapter_cannot_be_instantiated(self): + """BaseAdapter is abstract and cannot be instantiated directly.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + BaseAdapter(None) - base_adapter = BaseAdapter(None) - with pytest.raises(NotImplementedError): - base_adapter.get_endpoints() + def test_base_adapter_subclass_must_implement_methods(self): + """Subclasses missing abstract methods cannot be instantiated.""" - def test_base_adapter_get_tracked_client_not_implemented(self): - """Test that BaseAdapter.get_tracked_client raises NotImplementedError.""" - from pytest_api_cov.frameworks import BaseAdapter + class PartialAdapter(BaseAdapter): + def get_endpoints(self): + return [] - base_adapter = BaseAdapter(None) - with pytest.raises(NotImplementedError): - base_adapter.get_tracked_client({}, "test") + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + PartialAdapter(None) + def test_base_adapter_complete_subclass(self): + """Subclasses implementing all methods can be instantiated.""" -class TestAdapterFactory: - """Tests the factory function for getting adapters.""" + class CompleteAdapter(BaseAdapter): + def get_endpoints(self): + return [] - def test_get_framework_adapter(self): - """Test the factory function for selecting the correct adapter.""" + def get_tracked_client(self, recorder, test_name): + return None - class MockFlask: - __module__ = "flask.app" - __name__ = "Flask" + adapter = CompleteAdapter(None) + assert adapter.get_endpoints() == [] + assert adapter.get_tracked_client(None, "test") is None - class MockFastAPI: - __module__ = "fastapi.applications" - __name__ = "FastAPI" - class MockWSGIHandler: - __module__ = "django.core.handlers.wsgi" - __name__ = "WSGIHandler" +class TestAdapterFactory: + """Tests for the adapter factory function.""" + def test_get_framework_adapter(self): + """Factory returns correct adapter per framework.""" mock_flask_app = Mock() mock_flask_app.__class__.__module__ = "flask.app" mock_flask_app.__class__.__name__ = "Flask" @@ -203,7 +202,6 @@ class MockWSGIHandler: assert isinstance(get_framework_adapter(mock_flask_app), FlaskAdapter) assert isinstance(get_framework_adapter(mock_fastapi_app), FastAPIAdapter) - # Django is now supported from pytest_api_cov.frameworks import DjangoAdapter assert isinstance(get_framework_adapter(mock_django_app), DjangoAdapter) @@ -212,7 +210,7 @@ class MockWSGIHandler: get_framework_adapter(mock_unsupported_app) def test_get_framework_adapter_with_missing_module(self): - """Test factory function with app that has no __module__ attribute.""" + """App without __module__ raises TypeError.""" mock_app = Mock() mock_class = Mock() mock_class.__name__ = "UnknownApp" diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index c668b1e..8a5d5ea 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -6,16 +6,16 @@ class TestApiCallRecorder: - """Tests for ApiCallRecorder model.""" + """Tests for ApiCallRecorder.""" def test_init_default(self): - """Test ApiCallRecorder initialization with defaults.""" + """Default init has empty calls.""" recorder = ApiCallRecorder() assert recorder.calls == {} assert len(recorder) == 0 def test_record_call_new_endpoint(self): - """Test recording a call to a new endpoint.""" + """Record a call to a new endpoint.""" recorder = ApiCallRecorder() recorder.record_call("/test", "test_func") @@ -24,7 +24,7 @@ def test_record_call_new_endpoint(self): assert len(recorder) == 1 def test_record_call_with_method(self): - """Test recording a call with specific HTTP method.""" + """Record a call with a specific HTTP method.""" recorder = ApiCallRecorder() recorder.record_call("/test", "test_func", "POST") @@ -33,7 +33,7 @@ def test_record_call_with_method(self): assert len(recorder) == 1 def test_record_call_different_methods_same_endpoint(self): - """Test recording calls to same endpoint with different methods.""" + """Same path with different methods creates separate entries.""" recorder = ApiCallRecorder() recorder.record_call("/test", "test_get", "GET") recorder.record_call("/test", "test_post", "POST") @@ -45,7 +45,7 @@ def test_record_call_different_methods_same_endpoint(self): assert len(recorder) == 2 def test_record_call_existing_endpoint(self): - """Test recording additional calls to existing endpoint.""" + """Multiple calls to the same endpoint accumulate callers.""" recorder = ApiCallRecorder() recorder.record_call("/test", "test_func1") recorder.record_call("/test", "test_func2") @@ -56,7 +56,7 @@ def test_record_call_existing_endpoint(self): assert len(callers) == 2 def test_record_call_duplicate(self): - """Test recording duplicate calls (should not create duplicates).""" + """Duplicate calls are deduplicated (set behavior).""" recorder = ApiCallRecorder() recorder.record_call("/test", "test_func") recorder.record_call("/test", "test_func") @@ -66,7 +66,7 @@ def test_record_call_duplicate(self): assert "test_func" in callers def test_calls_keys(self): - """Test accessing recorded endpoint keys directly.""" + """Recorded endpoint keys are accessible.""" recorder = ApiCallRecorder() recorder.record_call("/endpoint1", "test1") recorder.record_call("/endpoint2", "test2") @@ -77,13 +77,13 @@ def test_calls_keys(self): assert "GET /endpoint2" in endpoints def test_calls_nonexistent(self): - """Test accessing callers for non-existent endpoint.""" + """Non-existent endpoint returns empty set via .get().""" recorder = ApiCallRecorder() callers = recorder.calls.get("GET /nonexistent", set()) assert callers == set() def test_merge_empty_recorder(self): - """Test merging with an empty recorder.""" + """Merging an empty recorder changes nothing.""" recorder1 = ApiCallRecorder() recorder1.record_call("/test", "test1") @@ -94,7 +94,7 @@ def test_merge_empty_recorder(self): assert "test1" in recorder1.calls["GET /test"] def test_merge_with_data(self): - """Test merging two recorders with data.""" + """Merging two recorders combines all data.""" recorder1 = ApiCallRecorder() recorder1.record_call("/endpoint1", "test1") recorder1.record_call("/shared", "test1") @@ -115,7 +115,7 @@ def test_merge_with_data(self): assert len(shared_callers) == 2 def test_to_serializable(self): - """Test converting to serializable format.""" + """Convert to serializable format (sets -> lists).""" recorder = ApiCallRecorder() recorder.record_call("/test1", "func1") recorder.record_call("/test1", "func2") @@ -131,7 +131,7 @@ def test_to_serializable(self): assert serializable["GET /test2"] == ["func3"] def test_from_serializable(self): - """Test creating from serializable format.""" + """Create from serializable format (lists -> sets).""" data = {"GET /test1": ["func1", "func2"], "POST /test2": ["func3"]} recorder = ApiCallRecorder.from_serializable(data) @@ -143,7 +143,7 @@ def test_from_serializable(self): assert recorder.calls["POST /test2"] == {"func3"} def test_contains(self): - """Test __contains__ method.""" + """__contains__ checks for endpoint key presence.""" recorder = ApiCallRecorder() recorder.record_call("/test", "func") @@ -152,7 +152,7 @@ def test_contains(self): assert "GET /nonexistent" not in recorder def test_items(self): - """Test items() method.""" + """items() iterates over (endpoint, callers) pairs.""" recorder = ApiCallRecorder() recorder.record_call("/test1", "func1") recorder.record_call("/test2", "func2") @@ -165,7 +165,7 @@ def test_items(self): assert "GET /test2" in endpoints def test_keys(self): - """Test keys() method.""" + """keys() returns all recorded endpoint keys.""" recorder = ApiCallRecorder() recorder.record_call("/test1", "func1") recorder.record_call("/test2", "func2") @@ -176,7 +176,7 @@ def test_keys(self): assert "GET /test2" in keys def test_values(self): - """Test values() method.""" + """values() returns all caller sets.""" recorder = ApiCallRecorder() recorder.record_call("/test1", "func1") recorder.record_call("/test2", "func2") @@ -189,17 +189,17 @@ def test_values(self): class TestEndpointDiscovery: - """Tests for EndpointDiscovery model.""" + """Tests for EndpointDiscovery.""" def test_init_default(self): - """Test EndpointDiscovery initialization with defaults.""" + """Default init has empty endpoints.""" discovery = EndpointDiscovery() assert discovery.endpoints == [] assert discovery.discovery_source == "unknown" assert len(discovery) == 0 def test_init_with_data(self): - """Test EndpointDiscovery initialization with data.""" + """Init with pre-populated endpoints.""" endpoints = ["GET /test1", "POST /test2"] discovery = EndpointDiscovery(endpoints=endpoints, discovery_source="test") @@ -208,7 +208,7 @@ def test_init_with_data(self): assert len(discovery) == 2 def test_add_endpoint_new(self): - """Test adding a new endpoint.""" + """Add a new endpoint.""" discovery = EndpointDiscovery() discovery.add_endpoint("/test") @@ -216,7 +216,7 @@ def test_add_endpoint_new(self): assert "GET /test" in discovery.endpoints def test_add_endpoint_with_method(self): - """Test adding an endpoint with specific method.""" + """Add an endpoint with a specific method.""" discovery = EndpointDiscovery() discovery.add_endpoint("/test", "POST") @@ -224,7 +224,7 @@ def test_add_endpoint_with_method(self): assert "POST /test" in discovery.endpoints def test_add_endpoint_duplicate(self): - """Test adding duplicate endpoint (should not create duplicates).""" + """Duplicate endpoints are not added twice.""" discovery = EndpointDiscovery() discovery.add_endpoint("/test") discovery.add_endpoint("/test") @@ -233,7 +233,7 @@ def test_add_endpoint_duplicate(self): assert discovery.endpoints.count("GET /test") == 1 def test_merge_empty(self): - """Test merging with empty discovery.""" + """Merging an empty discovery changes nothing.""" discovery1 = EndpointDiscovery() discovery1.add_endpoint("/test1") @@ -244,7 +244,7 @@ def test_merge_empty(self): assert "GET /test1" in discovery1.endpoints def test_merge_with_data(self): - """Test merging with another discovery containing data.""" + """Merging combines endpoints and deduplicates.""" discovery1 = EndpointDiscovery() discovery1.add_endpoint("/test1") discovery1.add_endpoint("/shared") @@ -262,10 +262,10 @@ def test_merge_with_data(self): class TestSessionData: - """Tests for SessionData model.""" + """Tests for SessionData.""" def test_init_default(self): - """Test SessionData initialization with defaults.""" + """Default init creates empty recorder and discovery.""" session = SessionData() assert isinstance(session.recorder, ApiCallRecorder) @@ -274,7 +274,7 @@ def test_init_default(self): assert len(session.discovered_endpoints) == 0 def test_record_call(self): - """Test record_call convenience method.""" + """record_call delegates to recorder.""" session = SessionData() session.record_call("/test", "test_func") @@ -282,7 +282,7 @@ def test_record_call(self): assert "test_func" in session.recorder.calls["GET /test"] def test_record_call_with_method(self): - """Test record_call with specific method.""" + """record_call with specific method.""" session = SessionData() session.record_call("/test", "test_func", "POST") @@ -290,7 +290,7 @@ def test_record_call_with_method(self): assert "test_func" in session.recorder.calls["POST /test"] def test_add_discovered_endpoint(self): - """Test add_discovered_endpoint convenience method.""" + """add_discovered_endpoint with method and source.""" session = SessionData() session.add_discovered_endpoint("/test", "GET", "flask_adapter") @@ -298,7 +298,7 @@ def test_add_discovered_endpoint(self): assert session.discovered_endpoints.discovery_source == "flask_adapter" def test_add_discovered_endpoint_multiple(self): - """Test adding multiple discovered endpoints.""" + """Multiple discovered endpoints accumulate.""" session = SessionData() session.add_discovered_endpoint("/test1", "GET", "flask_adapter") session.add_discovered_endpoint("/test2", "POST", "flask_adapter") @@ -309,7 +309,7 @@ def test_add_discovered_endpoint_multiple(self): assert session.discovered_endpoints.discovery_source == "flask_adapter" def test_merge_worker_data_dict_serializable(self): - """Test merging worker data in serializable format.""" + """Merge worker data in serializable format.""" session = SessionData() session.record_call("/session", "session_test") @@ -326,7 +326,7 @@ def test_merge_worker_data_dict_serializable(self): assert "POST /worker_endpoint" in session.discovered_endpoints.endpoints def test_merge_worker_data_dict_raw(self): - """Test merging worker data in raw dict format.""" + """Merge worker data in raw dict format.""" session = SessionData() session.record_call("/session", "session_test") @@ -339,7 +339,7 @@ def test_merge_worker_data_dict_raw(self): assert "worker_test" in session.recorder.calls["/worker"] def test_merge_worker_data_dict_mixed(self): - """Test merging worker data with mixed types.""" + """Merge worker data with mixed value types.""" session = SessionData() worker_recorder = { @@ -368,7 +368,7 @@ def test_merge_worker_data_dict_mixed(self): def test_merge_worker_data_edge_cases( self, worker_recorder, worker_endpoints, expected_recorder_len, expected_endpoints ): - """Test merging worker data with various edge cases.""" + """Merge worker data with edge cases (empty, non-dict, None).""" session = SessionData() if expected_recorder_len > 0: session.record_call("/session", "session_test") @@ -381,19 +381,19 @@ def test_merge_worker_data_edge_cases( for endpoint in expected_endpoints: assert endpoint in session.discovered_endpoints.endpoints - def test_add_discovered_endpoint_first_endpoint(self): - """Test adding the first endpoint sets the discovery source.""" + def test_add_discovered_endpoint_first_sets_source(self): + """First endpoint sets the discovery source.""" session = SessionData() - session.add_discovered_endpoint("/first", "flask_adapter") + session.add_discovered_endpoint("/first", "GET", "flask_adapter") assert session.discovered_endpoints.discovery_source == "flask_adapter" assert "GET /first" in session.discovered_endpoints.endpoints - def test_add_discovered_endpoint_subsequent_endpoints(self): - """Test adding subsequent endpoints doesn't change the discovery source.""" + def test_add_discovered_endpoint_subsequent_keeps_source(self): + """Subsequent endpoints don't change the discovery source.""" session = SessionData() - session.add_discovered_endpoint("/first", "flask_adapter") - session.add_discovered_endpoint("/second", "fastapi_adapter") + session.add_discovered_endpoint("/first", "GET", "flask_adapter") + session.add_discovered_endpoint("/second", "GET", "fastapi_adapter") assert session.discovered_endpoints.discovery_source == "flask_adapter" assert "GET /first" in session.discovered_endpoints.endpoints diff --git a/tests/unit/test_openapi.py b/tests/unit/test_openapi.py index c6f9f78..ced9a2f 100644 --- a/tests/unit/test_openapi.py +++ b/tests/unit/test_openapi.py @@ -1,17 +1,18 @@ """Unit tests for OpenAPI parser.""" import json + import pytest import yaml -from unittest.mock import patch, mock_open + from pytest_api_cov.openapi import parse_openapi_spec class TestParseOpenApiSpec: - """Tests for parse_openapi_spec function.""" + """Tests for parse_openapi_spec.""" def test_parse_json_spec_success(self, tmp_path): - """Test parsing a valid JSON OpenAPI spec.""" + """Parse a valid JSON OpenAPI spec.""" spec_content = { "openapi": "3.0.0", "paths": {"/users": {"get": {}, "post": {}}, "/items/{itemId}": {"put": {}}}, @@ -27,7 +28,7 @@ def test_parse_json_spec_success(self, tmp_path): assert "PUT /items/{itemId}" in endpoints def test_parse_yaml_spec_success(self, tmp_path): - """Test parsing a valid YAML OpenAPI spec.""" + """Parse a valid YAML OpenAPI spec.""" spec_content = """ openapi: 3.0.0 paths: @@ -48,12 +49,12 @@ def test_parse_yaml_spec_success(self, tmp_path): assert "PUT /items/{itemId}" in endpoints def test_file_not_found(self): - """Test handling of non-existent file.""" + """Non-existent file returns empty list.""" endpoints = parse_openapi_spec("non_existent.json") assert endpoints == [] def test_invalid_json_syntax(self, tmp_path): - """Test handling of invalid JSON syntax.""" + """Invalid JSON returns empty list.""" spec_file = tmp_path / "invalid.json" spec_file.write_text("{invalid json") @@ -61,7 +62,7 @@ def test_invalid_json_syntax(self, tmp_path): assert endpoints == [] def test_invalid_yaml_syntax(self, tmp_path): - """Test handling of invalid YAML syntax.""" + """Invalid YAML returns empty list.""" spec_file = tmp_path / "invalid.yaml" spec_file.write_text("invalid: yaml: :") @@ -69,7 +70,7 @@ def test_invalid_yaml_syntax(self, tmp_path): assert endpoints == [] def test_missing_paths_key(self, tmp_path): - """Test handling of spec without 'paths' key.""" + """Missing 'paths' key returns empty list.""" spec_content = {"openapi": "3.0.0", "info": {}} spec_file = tmp_path / "openapi.json" spec_file.write_text(json.dumps(spec_content)) @@ -78,12 +79,9 @@ def test_missing_paths_key(self, tmp_path): assert endpoints == [] def test_unsupported_file_extension(self, tmp_path): - """Test handling of unsupported file extension.""" + """Unsupported extension falls back to JSON parsing.""" spec_file = tmp_path / "spec.txt" spec_file.write_text("{}") - # Should probably log an error or return empty, depending on implementation. - # Assuming current implementation tries to parse based on extension or content. - # Let's check the implementation if needed, but for now expect empty or handled. endpoints = parse_openapi_spec(str(spec_file)) assert endpoints == [] diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py index 25fc2ed..330ba4b 100644 --- a/tests/unit/test_plugin.py +++ b/tests/unit/test_plugin.py @@ -5,50 +5,50 @@ import pytest +from pytest_api_cov.frameworks import is_supported_framework from pytest_api_cov.models import SessionData from pytest_api_cov.plugin import ( DeferXdistPlugin, - is_supported_framework, + create_coverage_fixture, + extract_app_from_client, pytest_addoption, pytest_configure, pytest_sessionfinish, pytest_sessionstart, - extract_app_from_client, wrap_client_with_coverage, - create_coverage_fixture, ) class TestSupportedFramework: - """Tests for framework detection utility functions.""" + """Tests for framework detection.""" def test_is_supported_framework_none(self): - """Test framework detection with None.""" + """None is not a supported framework.""" assert is_supported_framework(None) is False def test_is_supported_framework_flask(self): - """Test framework detection with Flask app.""" + """Flask app is supported.""" mock_app = Mock() mock_app.__class__.__name__ = "Flask" mock_app.__class__.__module__ = "flask.app" assert is_supported_framework(mock_app) is True def test_is_supported_framework_fastapi(self): - """Test framework detection with FastAPI app.""" + """FastAPI app is supported.""" mock_app = Mock() mock_app.__class__.__name__ = "FastAPI" mock_app.__class__.__module__ = "fastapi.applications" assert is_supported_framework(mock_app) is True def test_is_supported_framework_django(self): - """Test framework detection with Django app.""" + """Django app is supported.""" mock_app = Mock() mock_app.__class__.__name__ = "Django" mock_app.__class__.__module__ = "django.core" assert is_supported_framework(mock_app) is True def test_is_supported_framework_unsupported(self): - """Test framework detection with unsupported framework.""" + """Unsupported framework returns False.""" mock_app = Mock() mock_app.__class__.__name__ = "Bottle" mock_app.__class__.__module__ = "bottle" @@ -59,7 +59,7 @@ class TestPluginHooks: """Tests for pytest plugin hooks.""" def test_pytest_addoption(self): - """Test that pytest_addoption adds the required flags.""" + """pytest_addoption registers flags.""" mock_parser = Mock() pytest_addoption(mock_parser) @@ -67,7 +67,7 @@ def test_pytest_addoption(self): assert callable(pytest_addoption) def test_pytest_sessionstart_with_api_cov_report(self): - """Test pytest_sessionstart when --api-cov-report is enabled.""" + """Session start creates coverage data when flag is set.""" mock_session = Mock() mock_session.config.getoption.return_value = True @@ -79,7 +79,7 @@ def test_pytest_sessionstart_with_api_cov_report(self): assert hasattr(mock_session.api_coverage_data, "discovered_endpoints") def test_pytest_sessionstart_without_api_cov_report(self): - """Test pytest_sessionstart when --api-cov-report is disabled.""" + """Session start skips coverage data when flag is off.""" class SimpleSession: def __init__(self): @@ -95,7 +95,7 @@ def __init__(self): @patch("pytest_api_cov.plugin.get_pytest_api_cov_report_config") @patch("pytest_api_cov.plugin.generate_pytest_api_cov_report") def test_pytest_sessionfinish_with_api_cov_report(self, mock_generate_report, mock_get_config): - """Test pytest_sessionfinish when --api-cov-report is enabled.""" + """Session finish generates report when flag is set.""" mock_session = Mock() mock_session.config.getoption.side_effect = lambda flag: flag == "--api-cov-report" @@ -120,7 +120,7 @@ def test_pytest_sessionfinish_with_api_cov_report(self, mock_generate_report, mo assert mock_session.exitstatus == 1 def test_pytest_sessionfinish_without_api_cov_report(self): - """Test pytest_sessionfinish when --api-cov-report is disabled.""" + """Session finish is a no-op when flag is off.""" class SimpleSession: def __init__(self): @@ -136,7 +136,7 @@ def __init__(self): @patch("pytest_api_cov.config.get_pytest_api_cov_report_config") @patch("pytest_api_cov.report.generate_pytest_api_cov_report") def test_pytest_sessionfinish_with_workeroutput(self, mock_generate_report, mock_get_config): - """Test pytest_sessionfinish with workeroutput (parallel execution).""" + """Session finish serializes data for xdist workers.""" mock_session = Mock() mock_session.config.getoption.return_value = True @@ -161,7 +161,7 @@ def test_pytest_sessionfinish_with_workeroutput(self, mock_generate_report, mock @patch("pytest_api_cov.plugin.get_pytest_api_cov_report_config") @patch("pytest_api_cov.plugin.generate_pytest_api_cov_report") def test_pytest_sessionfinish_with_worker_data(self, mock_generate_report, mock_get_config): - """Test pytest_sessionfinish with worker data merging.""" + """Session finish merges worker data on the master.""" mock_session = Mock() mock_session.config.getoption.side_effect = lambda flag: flag == "--api-cov-report" @@ -192,7 +192,7 @@ def test_pytest_sessionfinish_with_worker_data(self, mock_generate_report, mock_ @patch("pytest_api_cov.plugin.get_pytest_api_cov_report_config") @patch("pytest_api_cov.plugin.generate_pytest_api_cov_report") def test_pytest_sessionfinish_with_non_dict_worker_data(self, mock_generate_report, mock_get_config): - """Test pytest_sessionfinish with non-dict worker data.""" + """Session finish handles non-dict worker data gracefully.""" mock_session = Mock() mock_session.config.getoption.side_effect = lambda flag: flag == "--api-cov-report" @@ -202,8 +202,6 @@ def test_pytest_sessionfinish_with_non_dict_worker_data(self, mock_generate_repo mock_session.api_coverage_data = coverage_data mock_session.exitstatus = 0 - from collections import defaultdict - class NonDictWorkerData: def __init__(self): self.data = defaultdict(set) @@ -245,22 +243,21 @@ def __bool__(self): mock_generate_report.assert_called_once() def test_pytest_configure_with_xdist(self): - """Test pytest_configure when pytest-xdist is available.""" + """xdist plugin is registered when available.""" mock_config = Mock() - mock_config.getoption.return_value = True # --api-cov-report is enabled - mock_config.option.verbose = 1 # -v verbosity level + mock_config.getoption.return_value = True + mock_config.option.verbose = 1 mock_config.pluginmanager.hasplugin.return_value = True pytest_configure(mock_config) - # DeferXdistPlugin mock_config.pluginmanager.register.assert_called_once() def test_pytest_configure_without_xdist(self): - """Test pytest_configure when pytest-xdist is not available.""" + """No xdist registration when plugin is absent.""" mock_config = Mock() - mock_config.getoption.return_value = True # --api-cov-report is enabled - mock_config.option.verbose = 0 # no verbosity + mock_config.getoption.return_value = True + mock_config.option.verbose = 0 mock_config.pluginmanager.hasplugin.return_value = False pytest_configure(mock_config) @@ -268,9 +265,9 @@ def test_pytest_configure_without_xdist(self): mock_config.pluginmanager.register.assert_not_called() def test_pytest_configure_without_api_cov_report(self): - """Test pytest_configure when --api-cov-report is not enabled.""" + """Logging is skipped when api-cov-report is off.""" mock_config = Mock() - mock_config.getoption.return_value = False # --api-cov-report is not enabled + mock_config.getoption.return_value = False mock_config.pluginmanager.hasplugin.return_value = True pytest_configure(mock_config) @@ -280,19 +277,19 @@ def test_pytest_configure_without_api_cov_report(self): @pytest.mark.parametrize( ("verbose_level", "expected_log_level"), [ - (0, "WARNING"), # normal run - (1, "INFO"), # -v - (2, "DEBUG"), # -vv or more - (3, "DEBUG"), # -vvv + (0, "WARNING"), + (1, "INFO"), + (2, "DEBUG"), + (3, "DEBUG"), ], ) @patch("pytest_api_cov.plugin.logger") def test_pytest_configure_logging_levels(self, mock_logger, verbose_level, expected_log_level): - """Test that logging levels are set correctly based on verbosity.""" + """Log level matches verbosity flag.""" import logging mock_config = Mock() - mock_config.getoption.return_value = True # --api-cov-report enabled + mock_config.getoption.return_value = True mock_config.option.verbose = verbose_level mock_config.pluginmanager.hasplugin.return_value = False mock_logger.handlers = [] @@ -304,7 +301,7 @@ def test_pytest_configure_logging_levels(self, mock_logger, verbose_level, expec @patch("pytest_api_cov.plugin.logger") def test_pytest_configure_existing_handler(self, mock_logger): - """Test that no new handler is added if one already exists.""" + """No duplicate handlers are added.""" mock_config = Mock() mock_config.getoption.return_value = True mock_config.option.verbose = 1 @@ -317,10 +314,10 @@ def test_pytest_configure_existing_handler(self, mock_logger): class TestDeferXdistPlugin: - """Tests for the DeferXdistPlugin class.""" + """Tests for the xdist deferred plugin.""" def test_pytest_testnodedown_with_worker_data(self): - """Test pytest_testnodedown when worker data is available.""" + """Worker data is merged on node down.""" mock_node = Mock() mock_node.workeroutput = {"api_call_recorder": {"/test": ["test_func"]}} @@ -334,7 +331,7 @@ def test_pytest_testnodedown_with_worker_data(self): assert "test_func" in worker_data["/test"] def test_pytest_testnodedown_without_worker_data(self): - """Test pytest_testnodedown when no worker data is available.""" + """No-op when worker has no data.""" mock_node = Mock() mock_node.workeroutput = {} mock_node.config.worker_api_call_recorder = {} @@ -345,7 +342,7 @@ def test_pytest_testnodedown_without_worker_data(self): assert mock_node.config.worker_api_call_recorder == {} def test_pytest_testnodedown_with_existing_worker_data(self): - """Test pytest_testnodedown when worker data already exists.""" + """New worker data is merged with existing data.""" mock_node = Mock() mock_node.workeroutput = {"api_call_recorder": {"/new": ["new_test"]}} @@ -432,7 +429,7 @@ def open(self, *args, **kwargs): def test_create_coverage_fixture_returns_existing_client_when_coverage_disabled(): - """create_coverage_fixture yields existing fixture when coverage disabled.""" + """Existing fixture is yielded when coverage is disabled.""" fixture = create_coverage_fixture("my_client", existing_fixture_name="existing") class SimpleSession: @@ -461,9 +458,9 @@ def getfixturevalue(self, name): next(gen) -@patch("pytest_api_cov.frameworks.get_framework_adapter") +@patch("pytest_api_cov.plugin.get_framework_adapter") def test_create_coverage_fixture_falls_back_to_app_when_no_existing_and_coverage_disabled(mock_get_adapter): - """When no existing client but an app fixture exists and coverage disabled, create tracked client.""" + """App fixture fallback when no existing client and coverage disabled.""" fixture = create_coverage_fixture("my_client", existing_fixture_name=None) class SimpleSession: @@ -488,7 +485,6 @@ def getfixturevalue(self, name): mock_get_adapter.return_value = adapter req = Req() - # Unwrap pytest.fixture wrapper to call the inner generator directly raw_fixture = getattr(fixture, "__wrapped__", fixture) gen = raw_fixture(req) got = next(gen) @@ -501,25 +497,22 @@ def getfixturevalue(self, name): @patch("pytest_api_cov.plugin.get_pytest_api_cov_report_config") @patch("pytest_api_cov.plugin.parse_openapi_spec") def test_create_coverage_fixture_with_openapi_spec(mock_parse_spec, mock_get_config): - """Test that endpoints are discovered from OpenAPI spec if configured.""" + """Endpoints are discovered from OpenAPI spec when configured.""" fixture = create_coverage_fixture("my_client") - # Mock config to have openapi_spec mock_config = Mock() mock_config.openapi_spec = "openapi.json" mock_config.client_fixture_names = ["client"] mock_get_config.return_value = mock_config - # Mock parse_openapi_spec mock_parse_spec.return_value = ["GET /users", "POST /users"] - # Mock session and coverage data coverage_data = SessionData() class SimpleSession: def __init__(self): self.config = Mock() - self.config.getoption.return_value = True # coverage enabled + self.config.getoption.return_value = True self.api_coverage_data = coverage_data session = SimpleSession() @@ -538,12 +531,10 @@ def getfixturevalue(self, name): req = Req() - # Execute fixture raw_fixture = getattr(fixture, "__wrapped__", fixture) gen = raw_fixture(req) client = next(gen) - # Verify mock_parse_spec.assert_called_once_with("openapi.json") assert "GET /users" in coverage_data.discovered_endpoints.endpoints assert "POST /users" in coverage_data.discovered_endpoints.endpoints diff --git a/tests/unit/test_report.py b/tests/unit/test_report.py index e2f003b..765dada 100644 --- a/tests/unit/test_report.py +++ b/tests/unit/test_report.py @@ -17,16 +17,16 @@ class TestEndpointCategorization: - """Tests for endpoint classification logic.""" + """Tests for endpoint categorisation logic.""" def test_endpoint_to_regex_conversion(self): - """Verify regex creation for Flask and FastAPI style placeholders.""" + """Regex creation for Flask and FastAPI style placeholders.""" assert endpoint_to_regex("/users/").pattern == "^/users/(.+)$" assert endpoint_to_regex("/items/{item_id}/data").pattern == "^/items/(.+)/data$" assert endpoint_to_regex("/static/path").pattern == "^/static/path$" def test_categorise_endpoints(self): - """Test the main categorization logic.""" + """Standard categorisation with exclusions.""" discovered = ["/users", "/users/{user_id}", "/health", "/admin/dashboard", "/a/b/c", "/a/b/d/c"] called = {"/users", "/users/123", "/admin/dashboard"} excluded = ["*admin*", "/a/*/c"] @@ -38,7 +38,7 @@ def test_categorise_endpoints(self): assert set(excluded_out) == {"/admin/dashboard", "/a/b/c", "/a/b/d/c"} def test_categorise_with_no_exclusions(self): - """Ensure it works correctly with no exclusion patterns.""" + """No exclusion patterns passes all endpoints through.""" discovered = ["/a", "/b"] called = {"/a"} covered, uncovered, excluded = categorise_endpoints(discovered, called, []) @@ -47,7 +47,7 @@ def test_categorise_with_no_exclusions(self): assert excluded == [] def test_categorise_with_exclusion_patterns(self): - """Test categorization with exclusion patterns.""" + """Exact exclusion pattern removes an endpoint.""" discovered = ["/public", "/admin", "/internal"] called = {"/public", "/admin"} excluded = ["/admin"] @@ -58,7 +58,7 @@ def test_categorise_with_exclusion_patterns(self): assert set(excluded_out) == {"/admin"} def test_categorise_with_wildcard_exclusions(self): - """Test categorization with wildcard exclusion patterns.""" + """Wildcard exclusion pattern removes matching endpoints.""" discovered = ["/public", "/admin/users", "/admin/settings", "/internal"] called = {"/public", "/admin/users"} excluded = ["/admin/*"] @@ -69,7 +69,7 @@ def test_categorise_with_wildcard_exclusions(self): assert set(excluded_out) == {"/admin/users", "/admin/settings"} def test_categorise_with_literal_dot_patterns(self): - """Test that dots in patterns are treated literally, not as regex wildcards.""" + """Dots in patterns are treated literally, not as regex wildcards.""" discovered = ["/api/v1.0/users", "/api/v1x0/users", "/api/v2.0/users"] called = set() excluded = ["/api/v1.0/*"] @@ -80,62 +80,62 @@ def test_categorise_with_literal_dot_patterns(self): assert set(excluded_out) == {"/api/v1.0/users"} def test_categorise_with_negation_patterns(self): - """Test categorization with negation patterns that override exclusions.""" + """Negation pattern overrides an exclusion.""" discovered = ["/users/alice", "/users/bob", "/users/charlie", "/admin/settings"] called = {"/users/alice", "/users/bob"} - patterns = ["/users/*", "!/users/bob"] # Exclude all users except bob + patterns = ["/users/*", "!/users/bob"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) - assert set(covered) == {"/users/bob"} # bob is negated from exclusion + assert set(covered) == {"/users/bob"} assert set(uncovered) == {"/admin/settings"} - assert set(excluded_out) == {"/users/alice", "/users/charlie"} # alice and charlie are excluded + assert set(excluded_out) == {"/users/alice", "/users/charlie"} def test_categorise_with_multiple_negation_patterns(self): - """Test categorization with multiple negation patterns.""" + """Multiple negation patterns re-include multiple endpoints.""" discovered = ["/api/v1/users", "/api/v1/admin", "/api/v1/public", "/api/v2/users", "/health"] called = {"/api/v1/users", "/api/v1/public"} - patterns = ["/api/v1/*", "!/api/v1/users", "!/api/v1/public"] # Exclude v1 except users and public + patterns = ["/api/v1/*", "!/api/v1/users", "!/api/v1/public"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) - assert set(covered) == {"/api/v1/users", "/api/v1/public"} # negated from exclusion + assert set(covered) == {"/api/v1/users", "/api/v1/public"} assert set(uncovered) == {"/api/v2/users", "/health"} - assert set(excluded_out) == {"/api/v1/admin"} # only admin is excluded + assert set(excluded_out) == {"/api/v1/admin"} def test_categorise_with_negation_wildcard_patterns(self): - """Test negation patterns with wildcards.""" + """Negation with wildcards re-includes a subtree.""" discovered = ["/admin/users/alice", "/admin/users/bob", "/admin/settings", "/public"] called = {"/admin/users/alice"} - patterns = ["/admin/*", "!/admin/users/*"] # Exclude all admin except admin/users/* + patterns = ["/admin/*", "!/admin/users/*"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) assert set(covered) == {"/admin/users/alice"} - assert set(uncovered) == {"/admin/users/bob", "/public"} # bob is uncovered but not excluded - assert set(excluded_out) == {"/admin/settings"} # settings is excluded + assert set(uncovered) == {"/admin/users/bob", "/public"} + assert set(excluded_out) == {"/admin/settings"} def test_categorise_with_method_endpoint_negation(self): - """Test negation patterns work with METHOD /path format.""" + """Negation works with METHOD /path format endpoints.""" discovered = ["GET /users/alice", "POST /users/alice", "GET /users/bob", "GET /admin"] called = {"GET /users/alice", "GET /users/bob"} - patterns = ["/users/*", "!/users/bob"] # Exclude all users except bob + patterns = ["/users/*", "!/users/bob"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) - assert set(covered) == {"GET /users/bob"} # bob is negated from exclusion + assert set(covered) == {"GET /users/bob"} assert set(uncovered) == {"GET /admin"} - assert set(excluded_out) == {"GET /users/alice", "POST /users/alice"} # alice endpoints excluded + assert set(excluded_out) == {"GET /users/alice", "POST /users/alice"} def test_categorise_negation_without_matching_exclusion(self): - """Test that negation patterns without matching exclusions don't affect anything.""" + """Negation without a matching exclusion has no effect.""" discovered = ["/users/alice", "/users/bob", "/admin"] called = {"/users/alice"} - patterns = ["!/users/charlie"] # Negation for non-existent exclusion + patterns = ["!/users/charlie"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) assert set(covered) == {"/users/alice"} assert set(uncovered) == {"/users/bob", "/admin"} - assert set(excluded_out) == set() # Nothing excluded + assert set(excluded_out) == set() def test_categorise_complex_exclusion_negation_scenario(self): - """Test complex scenario with multiple exclusions and negations.""" + """Complex mix of exclusions and negations.""" discovered = [ "/api/v1/users", "/api/v1/admin", @@ -148,10 +148,10 @@ def test_categorise_complex_exclusion_negation_scenario(self): ] called = {"/api/v1/users", "/api/v1/public", "/health"} patterns = [ - "/api/v1/*", # Exclude all v1 endpoints - "/metrics", # Exclude metrics - "!/api/v1/users", # But include v1/users - "!/api/v1/public", # But include v1/public + "/api/v1/*", + "/metrics", + "!/api/v1/users", + "!/api/v1/public", ] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) @@ -160,10 +160,10 @@ def test_categorise_complex_exclusion_negation_scenario(self): assert set(excluded_out) == {"/api/v1/admin", "/metrics"} def test_categorise_with_method_specific_exclusion(self): - """Exclude a specific HTTP method for an endpoint using 'METHOD /path' patterns.""" + """Exclude a specific HTTP method for an endpoint.""" discovered = ["GET /items", "POST /items", "GET /health"] called = {"POST /items"} - patterns = ["GET /items"] # Exclude only GET /items + patterns = ["GET /items"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) assert set(covered) == {"POST /items"} @@ -171,10 +171,10 @@ def test_categorise_with_method_specific_exclusion(self): assert set(excluded_out) == {"GET /items"} def test_categorise_with_multiple_method_prefixes(self): - """Support comma-separated method prefixes to exclude multiple methods for a path.""" + """Comma-separated method prefixes exclude multiple methods.""" discovered = ["GET /users/1", "POST /users/1", "PUT /users/1"] called = {"PUT /users/1"} - patterns = ["GET,POST /users/*"] # Exclude GET and POST for users + patterns = ["GET,POST /users/*"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) assert set(covered) == {"PUT /users/1"} @@ -182,10 +182,10 @@ def test_categorise_with_multiple_method_prefixes(self): assert set(excluded_out) == {"GET /users/1", "POST /users/1"} def test_categorise_method_prefixed_negation(self): - """Negation with a method prefix should re-include only that method.""" + """Negation with a method prefix re-includes only that method.""" discovered = ["GET /users/alice", "POST /users/alice", "GET /users/bob"] called = {"GET /users/bob"} - patterns = ["/users/*", "!GET /users/bob"] # Exclude all users but re-include GET /users/bob + patterns = ["/users/*", "!GET /users/bob"] covered, uncovered, excluded_out = categorise_endpoints(discovered, called, patterns) assert set(covered) == {"GET /users/bob"} @@ -194,18 +194,18 @@ def test_categorise_method_prefixed_negation(self): class TestCoverageCalculationAndReporting: - """Tests for coverage computation and report generation.""" + """Tests for coverage computation and report output.""" @pytest.mark.parametrize( ("covered", "uncovered", "expected"), [(10, 0, 100.0), (0, 10, 0.0), (5, 5, 50.0), (3, 1, 75.0), (0, 0, 0.0)], ) def test_compute_coverage(self, covered, uncovered, expected): - """Test coverage percentage calculation.""" + """Coverage percentage calculation.""" assert compute_coverage(covered, uncovered) == expected def test_prepare_endpoint_detail(self): - """Verify that caller information is correctly aggregated.""" + """Caller information is correctly aggregated.""" endpoints = ["/static", "/users/{user_id}"] called_data = { "/static": {"test_a"}, @@ -222,7 +222,7 @@ def test_prepare_endpoint_detail(self): @patch("pytest_api_cov.report.Console") def test_generate_report_success(self, mock_console_cls): - """Test report generation when coverage meets the requirement.""" + """Report shows SUCCESS when coverage meets threshold.""" mock_console = mock_console_cls.return_value config = ApiCoverageReportConfig.model_validate({"fail_under": 70.0}) discovered = ["/a", "/b", "/c", "/d"] @@ -230,13 +230,13 @@ def test_generate_report_success(self, mock_console_cls): status = generate_pytest_api_cov_report(config, called, discovered) - assert status == 0 # Success + assert status == 0 success_print = next(c for c in mock_console.print.call_args_list if "SUCCESS" in c.args[0]) assert "Coverage of 75.0%" in success_print.args[0] @patch("pytest_api_cov.report.Console") def test_generate_report_failure(self, mock_console_cls): - """Test report generation when coverage is below the requirement.""" + """Report shows FAIL when coverage is below threshold.""" mock_console = mock_console_cls.return_value config = ApiCoverageReportConfig.model_validate({"fail_under": 80.0}) discovered = ["/a", "/b", "/c", "/d"] @@ -244,13 +244,13 @@ def test_generate_report_failure(self, mock_console_cls): status = generate_pytest_api_cov_report(config, called, discovered) - assert status == 1 # Failure + assert status == 1 fail_print = next(c for c in mock_console.print.call_args_list if "FAIL" in c.args[0]) assert "FAIL: Required coverage of 80.0% not met. Actual coverage: 75.0%" in fail_print.args[0] @patch("pytest_api_cov.report.write_report_file") def test_json_report_generation(self, mock_write_report): - """Ensure the JSON report is generated when a path is provided.""" + """JSON report is written when report_path is set.""" config = ApiCoverageReportConfig.model_validate({"report_path": "coverage.json"}) status = generate_pytest_api_cov_report(config, {"/a": "foo"}, ["/a", "/b"]) @@ -262,7 +262,7 @@ def test_json_report_generation(self, mock_write_report): @patch("pytest_api_cov.report.Console") def test_generate_report_no_endpoints(self, mock_console_cls): - """Test report generation when no endpoints are discovered.""" + """Empty endpoint list prints a warning.""" mock_console = mock_console_cls.return_value config = ApiCoverageReportConfig.model_validate({}) discovered = [] @@ -277,13 +277,13 @@ def test_generate_report_no_endpoints(self, mock_console_cls): @pytest.mark.parametrize( ("force_sugar", "expected_symbols"), [ - (True, ["❌", "✅", "🚫"]), # Unicode symbols - (False, ["[X]", "[.]", "[-]"]), # ASCII symbols + (True, ["❌", "✅", "🚫"]), + (False, ["[X]", "[.]", "[-]"]), ], ) @patch("pytest_api_cov.report.Console") def test_generate_report_sugar_symbols(self, mock_console_cls, force_sugar, expected_symbols): - """Test report generation with different symbol configurations.""" + """Sugar/ASCII symbols match force_sugar setting.""" mock_console = mock_console_cls.return_value config = ApiCoverageReportConfig.model_validate( { @@ -306,23 +306,23 @@ def test_generate_report_sugar_symbols(self, mock_console_cls, force_sugar, expe class TestPrintEndpoints: - """Tests for the print_endpoints function.""" + """Tests for print_endpoints.""" @patch("pytest_api_cov.report.Console") def test_print_endpoints_with_endpoints(self, mock_console_cls): - """Test print_endpoints when there are endpoints to print.""" + """Prints label + one line per endpoint.""" mock_console = mock_console_cls.return_value endpoints = ["/a", "/b"] print_endpoints(mock_console, "Test Label", endpoints, "✓", "green") - assert mock_console.print.call_count == 3 # Label + 2 endpoints + assert mock_console.print.call_count == 3 label_call = mock_console.print.call_args_list[0] assert "Test Label" in label_call.args[0] @patch("pytest_api_cov.report.Console") def test_print_endpoints_without_endpoints(self, mock_console_cls): - """Test print_endpoints when there are no endpoints to print.""" + """No output when endpoint list is empty.""" mock_console = mock_console_cls.return_value endpoints = [] @@ -332,12 +332,12 @@ def test_print_endpoints_without_endpoints(self, mock_console_cls): class TestWriteReportFile: - """Tests for the write_report_file function.""" + """Tests for write_report_file.""" @patch("pathlib.Path.open") @patch("json.dump") def test_write_report_file(self, mock_json_dump, mock_open): - """Test that write_report_file writes data correctly.""" + """Report data is written as JSON.""" report_data = {"coverage": 100.0, "endpoints": ["/a", "/b"]} report_path = "test_report.json" diff --git a/uv.lock b/uv.lock index c5f8aed..8a31461 100644 --- a/uv.lock +++ b/uv.lock @@ -529,6 +529,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "pybencher" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5f/44/b8db7439b8022a424c124f2a0a1e728a1ab84e86ad4bb3468e400982d038/pybencher-2.1.0.tar.gz", hash = "sha256:c7de5fccdc4a253a71dece74cb87ab813a86f5274afdc8fc1d53b910ebbbbc81", size = 30105, upload-time = "2026-03-29T20:58:33.181Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/c3/71c13804badf6c651cdd7f6d5e1c1e956882dec079a9428465262d41d0a8/pybencher-2.1.0-py3-none-any.whl", hash = "sha256:52474b219515d472d166d0511f02e52c21adec2c2f81660792c57eb1612808b9", size = 7553, upload-time = "2026-03-29T20:58:34.376Z" }, +] + [[package]] name = "pydantic" version = "2.12.3" @@ -687,7 +696,7 @@ wheels = [ [[package]] name = "pytest-api-cov" -version = "1.3.4" +version = "1.3.5" source = { editable = "." } dependencies = [ { name = "pydantic" }, @@ -718,6 +727,7 @@ dev = [ { name = "httpx" }, { name = "mypy" }, { name = "path" }, + { name = "pybencher" }, { name = "pytest-cov" }, { name = "pytest-sugar" }, { name = "pytest-xdist" }, @@ -751,6 +761,7 @@ dev = [ { name = "httpx", specifier = ">=0.20.0" }, { name = "mypy", specifier = ">=1.17.0" }, { name = "path", specifier = ">=16.0.0" }, + { name = "pybencher", specifier = ">=2.1.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, { name = "pytest-sugar", specifier = ">=1.0.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, From 5dc231b7b6bfdb8ae8b516f4938b7df68a98f7cc Mon Sep 17 00:00:00 2001 From: BarnabasG Date: Sun, 12 Apr 2026 23:18:06 +0100 Subject: [PATCH 2/2] Fix formatting --- src/pytest_api_cov/config.py | 2 +- src/pytest_api_cov/frameworks.py | 1 + src/pytest_api_cov/models.py | 5 +++- src/pytest_api_cov/plugin.py | 39 ++++++++++++++++++-------------- src/pytest_api_cov/report.py | 5 ++-- 5 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/pytest_api_cov/config.py b/src/pytest_api_cov/config.py index 43285d3..40e7618 100644 --- a/src/pytest_api_cov/config.py +++ b/src/pytest_api_cov/config.py @@ -54,7 +54,7 @@ def read_toml_config() -> dict[str, Any]: "api-cov-openapi-spec": "openapi_spec", } -_UNSET = (None, [], False) +_UNSET: tuple[Any, ...] = (None, [], False) def read_session_config(session_config: Any) -> dict[str, Any]: diff --git a/src/pytest_api_cov/frameworks.py b/src/pytest_api_cov/frameworks.py index 979eca1..297a97c 100644 --- a/src/pytest_api_cov/frameworks.py +++ b/src/pytest_api_cov/frameworks.py @@ -13,6 +13,7 @@ class BaseAdapter(ABC): """Abstract base for framework adapters.""" def __init__(self, app: Any) -> None: + """Bind the framework app instance.""" self.app = app @abstractmethod diff --git a/src/pytest_api_cov/models.py b/src/pytest_api_cov/models.py index 65c9970..60a6426 100644 --- a/src/pytest_api_cov/models.py +++ b/src/pytest_api_cov/models.py @@ -48,9 +48,11 @@ def from_serializable(cls, data: dict[str, list[str]]) -> ApiCallRecorder: return cls(calls=calls) def __len__(self) -> int: + """Return the number of distinct endpoints recorded.""" return len(self.calls) def __contains__(self, endpoint: str) -> bool: + """Check if an endpoint has been recorded.""" return endpoint in self.calls def items(self) -> Any: @@ -73,7 +75,7 @@ class EndpointDiscovery(BaseModel): _seen: set[str] = set() discovery_source: str = Field(default="unknown") - def model_post_init(self, __context: Any) -> None: + def model_post_init(self, _: Any, /) -> None: """Sync the internal set with any pre-populated endpoints.""" self._seen = set(self.endpoints) @@ -92,6 +94,7 @@ def merge(self, other: EndpointDiscovery) -> None: self.endpoints.append(endpoint) def __len__(self) -> int: + """Return the number of discovered endpoints.""" return len(self.endpoints) diff --git a/src/pytest_api_cov/plugin.py b/src/pytest_api_cov/plugin.py index 9048cd7..f0f8230 100644 --- a/src/pytest_api_cov/plugin.py +++ b/src/pytest_api_cov/plugin.py @@ -102,6 +102,16 @@ def pytest_sessionstart(session: pytest.Session) -> None: session.api_coverage_data = SessionData() # type: ignore[attr-defined] +def _try_get_fixture(request: pytest.FixtureRequest, names: tuple[str, ...] | list[str]) -> Any | None: + """Try fixture names in order, return the first found or None.""" + for name in names: + try: + return request.getfixturevalue(name) + except pytest.FixtureLookupError: # noqa: PERF203 + continue + return None + + def create_coverage_fixture(fixture_name: str, existing_fixture_name: str | None = None) -> Any: """Create a coverage-enabled fixture with a custom name. @@ -275,32 +285,26 @@ def _coverage_client_impl(request: pytest.FixtureRequest) -> Any: if not coverage_enabled or coverage_data is None: # Try common client fixture names then app fixture - for name in ("client", "test_client", "api_client", "app_client"): - try: - yield request.getfixturevalue(name) - return - except pytest.FixtureLookupError: - continue + found = _try_get_fixture(request, ("client", "test_client", "api_client", "app_client")) + if found is not None: + yield found + return try: app = request.getfixturevalue("app") adapter = get_framework_adapter(app) - yield adapter.get_tracked_client(None, request.node.name) except (pytest.FixtureLookupError, Exception): # noqa: BLE001 yield None + else: + yield adapter.get_tracked_client(None, request.node.name) return config = get_pytest_api_cov_report_config(request.config) _discover_openapi_endpoints(config, coverage_data) # Find a client fixture - client = None - for name in config.client_fixture_names: - try: - client = request.getfixturevalue(name) - logger.info(f"> Found client fixture '{name}'") - break - except pytest.FixtureLookupError: - continue + client = _try_get_fixture(request, config.client_fixture_names) + if client is not None: + logger.info("> Found client fixture") app = extract_app_from_client(client) if client else None if app is None: @@ -318,10 +322,11 @@ def _coverage_client_impl(request: pytest.FixtureRequest) -> Any: if app is not None: try: adapter = get_framework_adapter(app) - yield adapter.get_tracked_client(coverage_data.recorder, request.node.name) - return except Exception as e: # noqa: BLE001 logger.warning(f"> Failed to create tracked client: {e}") + else: + yield adapter.get_tracked_client(coverage_data.recorder, request.node.name) + return logger.warning("> coverage_client could not provide a client; tests will run without API coverage.") yield None diff --git a/src/pytest_api_cov/report.py b/src/pytest_api_cov/report.py index eb12004..1db77e1 100644 --- a/src/pytest_api_cov/report.py +++ b/src/pytest_api_cov/report.py @@ -6,11 +6,12 @@ import re from pathlib import Path from re import Pattern -from typing import Any +from typing import TYPE_CHECKING, Any from rich.console import Console -from .config import ApiCoverageReportConfig +if TYPE_CHECKING: + from .config import ApiCoverageReportConfig def endpoint_to_regex(endpoint: str) -> Pattern[str]: