Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,5 @@ build:
@uv sync
@uv build

pipeline-local: format clean test cover typeguard test-example test-example-parallel

pipeline: format test cover typeguard test-example test-example-parallel

15 changes: 11 additions & 4 deletions src/pytest_api_cov/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ class ApiCoverageReportConfig(BaseModel):
openapi_spec: str | None = Field(None, alias="api-cov-openapi-spec")


def read_toml_config() -> dict[str, Any]:
"""Read the [tool.pytest_api_cov] section from pyproject.toml."""
def read_toml_config(rootdir: Path | None = None) -> dict[str, Any]:
"""Read the [tool.pytest_api_cov] section from pyproject.toml.

Args:
rootdir: Project root directory. Falls back to CWD if not provided.

"""
toml_path = (rootdir or Path.cwd()) / "pyproject.toml"
try:
with Path("pyproject.toml").open("rb") as f:
with toml_path.open("rb") as f:
toml_config = tomli.load(f)
return toml_config.get("tool", {}).get("pytest_api_cov", {}) # type: ignore[no-any-return]
except (FileNotFoundError, tomli.TOMLDecodeError):
Expand Down Expand Up @@ -84,7 +90,8 @@ def get_pytest_api_cov_report_config(session_config: Any) -> ApiCoverageReportCo
if isinstance(cached, ApiCoverageReportConfig):
return cached

toml_config = read_toml_config()
rootdir = getattr(session_config, "rootpath", None) or getattr(session_config, "rootdir", None)
toml_config = read_toml_config(rootdir)
cli_config = read_session_config(session_config)

final_config = {**toml_config, **cli_config}
Expand Down
1 change: 1 addition & 0 deletions src/pytest_api_cov/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class SessionData(BaseModel):

recorder: ApiCallRecorder = Field(default_factory=ApiCallRecorder)
discovered_endpoints: EndpointDiscovery = Field(default_factory=EndpointDiscovery)
discovery_complete: bool = Field(default=False)

def record_call(self, endpoint: str, test_name: str, method: str = "GET") -> None:
"""Record an API call."""
Expand Down
123 changes: 68 additions & 55 deletions src/pytest_api_cov/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

def _discover_openapi_endpoints(config: ApiCoverageReportConfig, coverage_data: SessionData) -> None:
"""Discover endpoints from OpenAPI spec if configured."""
if coverage_data.discovery_complete:
return
if not config.openapi_spec or coverage_data.discovered_endpoints.endpoints:
return

Expand All @@ -29,11 +31,14 @@ def _discover_openapi_endpoints(config: ApiCoverageReportConfig, coverage_data:
method, path = endpoint_method.split(" ", 1)
coverage_data.add_discovered_endpoint(path, method, "openapi_spec")

coverage_data.discovery_complete = True
logger.info(f"> Discovered {len(endpoints)} endpoints from OpenAPI spec: {config.openapi_spec}")


def _discover_app_endpoints(app: Any, coverage_data: SessionData, fixture_name: str) -> None:
"""Discover endpoints from the app instance."""
if coverage_data.discovery_complete:
return
if not (app and is_supported_framework(app) and not coverage_data.discovered_endpoints.endpoints):
return

Expand All @@ -46,6 +51,7 @@ def _discover_app_endpoints(app: Any, coverage_data: SessionData, fixture_name:
method, path = endpoint_method.split(" ", 1)
coverage_data.add_discovered_endpoint(path, method, f"{framework_name.lower()}_adapter")

coverage_data.discovery_complete = True
logger.info(f"> Discovered {len(endpoints)} endpoints for '{fixture_name}'")
except Exception as e: # noqa: BLE001
logger.warning(f"> Failed to discover endpoints from app: {e}")
Expand Down Expand Up @@ -210,68 +216,75 @@ def fixture_func(request: pytest.FixtureRequest) -> Any:
return pytest.fixture(fixture_func)


def wrap_client_with_coverage(client: Any, recorder: Any, test_name: str) -> Any:
"""Wrap an existing test client with coverage tracking."""
if client is None or recorder is None:
return client
class CoverageWrapper:
"""Wraps a test client to record HTTP calls for coverage tracking."""

_TRACKED_NAMES = frozenset({"get", "post", "put", "delete", "patch", "head", "options", "request", "open"})

class CoverageWrapper:
def __init__(self, wrapped_client: Any) -> None:
self._wrapped = wrapped_client

_TRACKED_NAMES = frozenset({"get", "post", "put", "delete", "patch", "head", "options", "request", "open"})

def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[str, str] | None:
"""Pull path and HTTP method from the call arguments."""
# .request(method, url, ...) - method is first arg, url is second
if name == "request":
req_method = (args[0] if args else kwargs.get("method", "GET")).upper()
req_url = args[1] if len(args) > 1 else kwargs.get("url")
if isinstance(req_url, str):
return req_url if "?" not in req_url else req_url.partition("?")[0], req_method
return None

# .get(url), .post(url), .open(url), etc. - url is first arg
if args:
first = args[0]
if isinstance(first, str):
path = first if "?" not in first else first.partition("?")[0]
method = kwargs.get("method", name).upper()
return path, ("GET" if method == "OPEN" else method)

if hasattr(first, "url") and hasattr(first.url, "path"):
try:
return first.url.path, getattr(first, "method", name).upper()
except Exception: # noqa: BLE001
pass

if kwargs:
path_kw = kwargs.get("path") or kwargs.get("url") or kwargs.get("uri")
if isinstance(path_kw, str):
path = path_kw if "?" not in path_kw else path_kw.partition("?")[0]
method = kwargs.get("method", name).upper()
return path, ("GET" if method == "OPEN" else method)
def __init__(self, wrapped_client: Any, recorder: Any, test_name: str) -> None:
"""Bind the client, recorder, and test name."""
self._wrapped = wrapped_client
self._recorder = recorder
self._test_name = test_name

def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[str, str] | None:
"""Pull path and HTTP method from the call arguments."""
# .request(method, url, ...) - method is first arg, url is second
if name == "request":
req_method = (args[0] if args else kwargs.get("method", "GET")).upper()
req_url = args[1] if len(args) > 1 else kwargs.get("url")
if isinstance(req_url, str):
return req_url if "?" not in req_url else req_url.partition("?")[0], req_method
return None

def __getattr__(self, name: str) -> Any:
attr = getattr(self._wrapped, name)
if name not in self._TRACKED_NAMES:
return attr
# .get(url), .post(url), .open(url), etc. - url is first arg
if args:
first = args[0]
if isinstance(first, str):
path = first if "?" not in first else first.partition("?")[0]
method = kwargs.get("method", name).upper()
return path, ("GET" if method == "OPEN" else method)

def tracked(*args: Any, **kwargs: Any) -> Any:
response = attr(*args, **kwargs)
if recorder is not None:
pm = self._extract_path_and_method(name, args, kwargs)
if pm:
path, method = pm
recorder.record_call(path, test_name, method)
return response
if hasattr(first, "url") and hasattr(first.url, "path"):
try:
return first.url.path, getattr(first, "method", name).upper()
except Exception: # noqa: BLE001
pass

object.__setattr__(self, name, tracked)
return tracked
if kwargs:
path_kw = kwargs.get("path") or kwargs.get("url") or kwargs.get("uri")
if isinstance(path_kw, str):
path = path_kw if "?" not in path_kw else path_kw.partition("?")[0]
method = kwargs.get("method", name).upper()
return path, ("GET" if method == "OPEN" else method)

return None

def __getattr__(self, name: str) -> Any:
"""Intercept attribute access to wrap tracked HTTP methods."""
attr = getattr(self._wrapped, name)
if name not in self._TRACKED_NAMES:
return attr

def tracked(*args: Any, **kwargs: Any) -> Any:
response = attr(*args, **kwargs)
if self._recorder is not None:
pm = self._extract_path_and_method(name, args, kwargs)
if pm:
path, method = pm
self._recorder.record_call(path, self._test_name, method)
return response

object.__setattr__(self, name, tracked)
return tracked


def wrap_client_with_coverage(client: Any, recorder: Any, test_name: str) -> Any:
"""Wrap an existing test client with coverage tracking."""
if client is None or recorder is None:
return client

return CoverageWrapper(client)
return CoverageWrapper(client, recorder, test_name)


def _coverage_client_impl(request: pytest.FixtureRequest) -> Any:
Expand Down
50 changes: 32 additions & 18 deletions src/pytest_api_cov/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,37 @@ def contains_escape_characters(endpoint: str) -> bool:
return ("<" in endpoint and ">" in endpoint) or ("{" in endpoint and "}" in endpoint)


def _compile_exclusion_pattern(pat: str) -> tuple[frozenset[str] | None, Pattern[str]]:
"""Compile a single exclusion pattern into a (methods, regex) pair."""
path_pattern = pat.strip()
methods: frozenset[str] | None = None
m = re.match(r"^([A-Za-z,]+)\s+(.+)$", pat)
if m:
methods = frozenset(mname.strip().upper() for mname in m.group(1).split(",") if mname.strip())
path_pattern = m.group(2)
regex = re.compile("^" + re.escape(path_pattern).replace(r"\*", ".*") + "$")
return methods, regex


_CompiledPatterns = tuple[tuple[frozenset[str] | None, Pattern[str]], ...]


@lru_cache(maxsize=128)
def _compile_exclusion_patterns(
patterns: tuple[str, ...],
) -> tuple[_CompiledPatterns | None, _CompiledPatterns | None]:
"""Compile and cache exclusion/negation patterns.

Accepts a tuple (hashable) so the result can be cached across calls.
"""
exclusion_only = [p for p in patterns if not p.startswith("!")]
negation_only = [p[1:] for p in patterns if p.startswith("!")]

compiled_exclusions = tuple(_compile_exclusion_pattern(p) for p in exclusion_only) if exclusion_only else None
compiled_negations = tuple(_compile_exclusion_pattern(p) for p in negation_only) if negation_only else None
return compiled_exclusions, compiled_negations


def categorise_endpoints(
endpoints: list[str],
called_data: dict[str, set[str]],
Expand All @@ -47,24 +78,7 @@ def categorise_endpoints(
compiled_exclusions = None
compiled_negations = None
else:
exclusion_only = [p for p in exclusion_patterns if not p.startswith("!")]
negation_only = [p[1:] for p in exclusion_patterns if p.startswith("!")]

def compile_patterns(patterns: list[str]) -> list[tuple[set[str] | None, Pattern[str]]]:
compiled: list[tuple[set[str] | None, Pattern[str]]] = []
for pat in patterns:
path_pattern = pat.strip()
methods: set[str] | None = None
m = re.match(r"^([A-Za-z,]+)\s+(.+)$", pat)
if m:
methods = {mname.strip().upper() for mname in m.group(1).split(",") if mname.strip()}
path_pattern = m.group(2)
regex = re.compile("^" + re.escape(path_pattern).replace(r"\*", ".*") + "$")
compiled.append((methods, regex))
return compiled

compiled_exclusions = compile_patterns(exclusion_only) if exclusion_only else None
compiled_negations = compile_patterns(negation_only) if negation_only else None
compiled_exclusions, compiled_negations = _compile_exclusion_patterns(tuple(exclusion_patterns))

for endpoint in endpoints:
is_excluded = False
Expand Down
Loading