Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pytest-api-cov"
version = "1.3.6"
version = "1.3.7"
description = "Pytest Plugin to provide API Coverage statistics for Python Web Frameworks"
readme = "README.md"
authors = [{ name = "Barnaby Gill", email = "barnabasgill@gmail.com" }]
Expand All @@ -13,6 +13,7 @@ dependencies = [
"tomli>=1.2.0",
"pytest>=6.0.0",
"PyYAML>=6.0",
"backports.strenum>=1.3.1; python_version < '3.11'",
]

[project.optional-dependencies]
Expand Down
8 changes: 7 additions & 1 deletion src/pytest_api_cov/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def supports_unicode() -> bool:

def get_pytest_api_cov_report_config(session_config: Any) -> ApiCoverageReportConfig:
"""Build final config by merging sources. Priority: CLI > pyproject.toml > defaults."""
cached = getattr(session_config, "_api_cov_config_cache", None)
if isinstance(cached, ApiCoverageReportConfig):
return cached

toml_config = read_toml_config()
cli_config = read_session_config(session_config)

Expand All @@ -90,4 +94,6 @@ def get_pytest_api_cov_report_config(session_config: Any) -> ApiCoverageReportCo
elif "force_sugar" not in final_config:
final_config["force_sugar"] = supports_unicode()

return ApiCoverageReportConfig.model_validate(final_config)
result = ApiCoverageReportConfig.model_validate(final_config)
session_config._api_cov_config_cache = result
return result
70 changes: 50 additions & 20 deletions src/pytest_api_cov/frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,24 @@

from __future__ import annotations

import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from backports.strenum import StrEnum


class SupportedFramework(StrEnum):
"""String enum representing officially supported web frameworks."""

FLASK = "flask"
FASTAPI = "fastapi"
DJANGO = "django"


if TYPE_CHECKING:
from .models import ApiCallRecorder

Expand Down Expand Up @@ -48,14 +63,18 @@ def get_tracked_client(self, recorder: ApiCallRecorder | None, test_name: str) -
if recorder is None:
return self.app.test_client()

url_adapter = None
if hasattr(self.app.url_map, "bind"):
url_adapter = self.app.url_map.bind("")

class TrackingFlaskClient(FlaskClient):
def open(self, *args: Any, **kwargs: Any) -> Any:
path = kwargs.get("path") or (args[0] if args else None)
method = kwargs.get("method", "GET").upper()

if path and hasattr(self.application.url_map, "bind"):
if path and url_adapter is not None:
try:
endpoint_name, _ = self.application.url_map.bind("").match(path, method=method)
endpoint_name, _ = url_adapter.match(path, method=method)
endpoint_rule_string = next(self.application.url_map.iter_rules(endpoint_name)).rule
recorder.record_call(endpoint_rule_string, test_name, method) # type: ignore[union-attr]
except Exception: # noqa: BLE001
Expand Down Expand Up @@ -176,29 +195,40 @@ def _unwrap_wsgi_app(app: Any) -> Any:
return None


def _detect_framework(app: Any) -> SupportedFramework | None:
"""Lightweight check to detect the framework."""
app_type = type(app).__name__
module_name = getattr(type(app), "__module__", "").split(".")[0]

match (module_name, app_type):
case ("flask", "Flask") | ("flask_openapi3", "OpenAPI"):
return SupportedFramework.FLASK
case ("fastapi", "FastAPI"):
return SupportedFramework.FASTAPI
case (module, _) if module == "django" or "django" in module:
return SupportedFramework.DJANGO
case _:
return None


def is_supported_framework(app: Any) -> bool:
"""Check if the app is a supported framework."""
if app is None:
return False
try:
get_framework_adapter(app)
except TypeError:
return False
return True
return _detect_framework(app) is not None


def get_framework_adapter(app: Any) -> BaseAdapter:
"""Detect the framework and return the appropriate adapter."""
app_type = type(app).__name__
module_name = getattr(type(app), "__module__", "").split(".")[0]

if (module_name == "flask" and app_type == "Flask") or (module_name == "flask_openapi3" and app_type == "OpenAPI"):
return FlaskAdapter(app)
if module_name == "fastapi" and app_type == "FastAPI":
return FastAPIAdapter(app)
if module_name == "django" or "django" in module_name:
return DjangoAdapter(app)

raise TypeError(
f"Unsupported application type: {app_type}. pytest-api-coverage supports Flask, FastAPI, and Django."
)
match _detect_framework(app):
case SupportedFramework.FLASK:
return FlaskAdapter(app)
case SupportedFramework.FASTAPI:
return FastAPIAdapter(app)
case SupportedFramework.DJANGO:
return DjangoAdapter(app)
case _:
app_type = type(app).__name__
raise TypeError(
f"Unsupported application type: {app_type}. pytest-api-coverage supports Flask, FastAPI, and Django."
)
4 changes: 2 additions & 2 deletions src/pytest_api_cov/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr


class ApiCallRecorder(BaseModel):
Expand Down Expand Up @@ -72,7 +72,7 @@ class EndpointDiscovery(BaseModel):
"""Discovered API endpoints."""

endpoints: list[str] = Field(default_factory=list)
_seen: set[str] = set()
_seen: set[str] = PrivateAttr(default_factory=set)
discovery_source: str = Field(default="unknown")

def model_post_init(self, _: Any, /) -> None:
Expand Down
12 changes: 8 additions & 4 deletions src/pytest_api_cov/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

from __future__ import annotations

import json
import logging
from pathlib import Path

import yaml

logger = logging.getLogger(__name__)

HTTP_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "TRACE"}
Expand All @@ -22,7 +19,14 @@ def parse_openapi_spec(path: str) -> list[str]:

try:
with spec_path.open("r", encoding="utf-8") as f:
spec = yaml.safe_load(f) if spec_path.suffix.lower() in (".yaml", ".yml") else json.load(f)
if spec_path.suffix.lower() in (".yaml", ".yml"):
import yaml

spec = yaml.safe_load(f)
else:
import json

spec = json.load(f)
except Exception:
logger.exception("Failed to parse OpenAPI spec", exc_info=True)
return []
Expand Down
34 changes: 17 additions & 17 deletions src/pytest_api_cov/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[s
req_method = (args[0] if args else kwargs.get("method", "GET")).upper()
req_url = args[1] if len(args) > 1 else kwargs.get("url")
if isinstance(req_url, str):
return req_url.partition("?")[0], req_method
return req_url if "?" not in req_url else req_url.partition("?")[0], req_method
return None

# .get(url), .post(url), .open(url), etc. - url is first arg
if args:
first = args[0]
if isinstance(first, str):
path = first.partition("?")[0]
path = first if "?" not in first else first.partition("?")[0]
method = kwargs.get("method", name).upper()
return path, ("GET" if method == "OPEN" else method)

Expand All @@ -248,28 +248,28 @@ def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[s
if kwargs:
path_kw = kwargs.get("path") or kwargs.get("url") or kwargs.get("uri")
if isinstance(path_kw, str):
path = path_kw.partition("?")[0]
path = path_kw if "?" not in path_kw else path_kw.partition("?")[0]
method = kwargs.get("method", name).upper()
return path, ("GET" if method == "OPEN" else method)

return None

def __getattr__(self, name: str) -> Any:
attr = getattr(self._wrapped, name)
if name in self._TRACKED_NAMES:

def tracked(*args: Any, **kwargs: Any) -> Any:
response = attr(*args, **kwargs)
if recorder is not None:
pm = self._extract_path_and_method(name, args, kwargs)
if pm:
path, method = pm
recorder.record_call(path, test_name, method)
return response

return tracked

return attr
if name not in self._TRACKED_NAMES:
return attr

def tracked(*args: Any, **kwargs: Any) -> Any:
response = attr(*args, **kwargs)
if recorder is not None:
pm = self._extract_path_and_method(name, args, kwargs)
if pm:
path, method = pm
recorder.record_call(path, test_name, method)
return response

object.__setattr__(self, name, tracked)
return tracked

return CoverageWrapper(client)

Expand Down
2 changes: 2 additions & 0 deletions src/pytest_api_cov/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import re
from functools import lru_cache
from pathlib import Path
from re import Pattern
from typing import TYPE_CHECKING, Any
Expand All @@ -14,6 +15,7 @@
from .config import ApiCoverageReportConfig


@lru_cache(maxsize=512)
def endpoint_to_regex(endpoint: str) -> Pattern[str]:
"""Create a regex pattern from an endpoint by replacing dynamic segments."""
placeholder = "___PLACEHOLDER___"
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_config_priority_cli_over_toml(self, mock_read_toml, mock_read_session):
mock_read_toml.return_value = {"fail_under": 90.0, "report_path": "toml.json"}
mock_read_session.return_value = {"fail_under": 75.0}

mock_session_config = Mock()
mock_session_config = Mock(spec=["getoption"])
final_config = get_pytest_api_cov_report_config(mock_session_config)

assert final_config.fail_under == 75.0
Expand All @@ -143,7 +143,8 @@ def test_pydantic_model_validation(self, mock_read_toml, mock_read_session):
"""Pydantic model validates and sets defaults correctly."""
mock_read_toml.return_value = {"fail_under": 90.0}

final_config = get_pytest_api_cov_report_config(Mock())
mock_session_config = Mock(spec=["getoption"])
final_config = get_pytest_api_cov_report_config(mock_session_config)

assert final_config.fail_under == 90.0
assert final_config.show_covered_endpoints is False
Expand All @@ -159,15 +160,18 @@ def test_force_sugar_setting(self, mock_supports_unicode, mock_read_toml, mock_r
mock_read_toml.return_value = {}

mock_read_session.return_value = {"force_sugar_disabled": True}
config = get_pytest_api_cov_report_config(Mock())
mock_session_config = Mock(spec=["getoption"])
config = get_pytest_api_cov_report_config(mock_session_config)
assert config.force_sugar is False

mock_read_session.return_value = {}
config = get_pytest_api_cov_report_config(Mock())
mock_session_config = Mock(spec=["getoption"])
config = get_pytest_api_cov_report_config(mock_session_config)
assert config.force_sugar is True

mock_read_session.return_value = {"force_sugar": False}
config = get_pytest_api_cov_report_config(Mock())
mock_session_config = Mock(spec=["getoption"])
config = get_pytest_api_cov_report_config(mock_session_config)
assert config.force_sugar is False

def test_pydantic_validation_error(self):
Expand Down
13 changes: 12 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading