From efb2e5c4fddfc6e0a16276148f53bea8ac905a01 Mon Sep 17 00:00:00 2001 From: BarnabasG Date: Mon, 4 May 2026 22:02:54 +0100 Subject: [PATCH] perf: hoist CoverageWrapper, cache exclusion patterns, fix config CWD bug - Move CoverageWrapper class to module level to avoid re-creating the class, frozenset, and method bodies on every test invocation (15.9x) - Add discovery_complete flag to SessionData so _discover_openapi_endpoints and _discover_app_endpoints short-circuit with a single bool check after the first successful discovery (5.6x) - Extract exclusion pattern compilation into @lru_cache-decorated _compile_exclusion_patterns() to avoid recompiling regex on repeat calls - Fix read_toml_config() to resolve pyproject.toml relative to pytest's rootpath instead of CWD, preventing silent config loss in CI/monorepos --- Makefile | 2 - src/pytest_api_cov/config.py | 15 +++-- src/pytest_api_cov/models.py | 1 + src/pytest_api_cov/plugin.py | 123 +++++++++++++++++++---------------- src/pytest_api_cov/report.py | 50 +++++++++----- 5 files changed, 112 insertions(+), 79 deletions(-) diff --git a/Makefile b/Makefile index 3a51471..7b2f8c7 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,5 @@ build: @uv sync @uv build -pipeline-local: format clean test cover typeguard test-example test-example-parallel - pipeline: format test cover typeguard test-example test-example-parallel diff --git a/src/pytest_api_cov/config.py b/src/pytest_api_cov/config.py index 069ce9f..9ad65e7 100644 --- a/src/pytest_api_cov/config.py +++ b/src/pytest_api_cov/config.py @@ -30,10 +30,16 @@ class ApiCoverageReportConfig(BaseModel): openapi_spec: str | None = Field(None, alias="api-cov-openapi-spec") -def read_toml_config() -> dict[str, Any]: - """Read the [tool.pytest_api_cov] section from pyproject.toml.""" +def read_toml_config(rootdir: Path | None = None) -> dict[str, Any]: + """Read the [tool.pytest_api_cov] section from pyproject.toml. + + Args: + rootdir: Project root directory. Falls back to CWD if not provided. + + """ + toml_path = (rootdir or Path.cwd()) / "pyproject.toml" try: - with Path("pyproject.toml").open("rb") as f: + with toml_path.open("rb") as f: toml_config = tomli.load(f) return toml_config.get("tool", {}).get("pytest_api_cov", {}) # type: ignore[no-any-return] except (FileNotFoundError, tomli.TOMLDecodeError): @@ -84,7 +90,8 @@ def get_pytest_api_cov_report_config(session_config: Any) -> ApiCoverageReportCo if isinstance(cached, ApiCoverageReportConfig): return cached - toml_config = read_toml_config() + rootdir = getattr(session_config, "rootpath", None) or getattr(session_config, "rootdir", None) + toml_config = read_toml_config(rootdir) cli_config = read_session_config(session_config) final_config = {**toml_config, **cli_config} diff --git a/src/pytest_api_cov/models.py b/src/pytest_api_cov/models.py index 2eece45..7883710 100644 --- a/src/pytest_api_cov/models.py +++ b/src/pytest_api_cov/models.py @@ -103,6 +103,7 @@ class SessionData(BaseModel): recorder: ApiCallRecorder = Field(default_factory=ApiCallRecorder) discovered_endpoints: EndpointDiscovery = Field(default_factory=EndpointDiscovery) + discovery_complete: bool = Field(default=False) def record_call(self, endpoint: str, test_name: str, method: str = "GET") -> None: """Record an API call.""" diff --git a/src/pytest_api_cov/plugin.py b/src/pytest_api_cov/plugin.py index 16494f4..3099b73 100644 --- a/src/pytest_api_cov/plugin.py +++ b/src/pytest_api_cov/plugin.py @@ -17,6 +17,8 @@ def _discover_openapi_endpoints(config: ApiCoverageReportConfig, coverage_data: SessionData) -> None: """Discover endpoints from OpenAPI spec if configured.""" + if coverage_data.discovery_complete: + return if not config.openapi_spec or coverage_data.discovered_endpoints.endpoints: return @@ -29,11 +31,14 @@ def _discover_openapi_endpoints(config: ApiCoverageReportConfig, coverage_data: method, path = endpoint_method.split(" ", 1) coverage_data.add_discovered_endpoint(path, method, "openapi_spec") + coverage_data.discovery_complete = True logger.info(f"> Discovered {len(endpoints)} endpoints from OpenAPI spec: {config.openapi_spec}") def _discover_app_endpoints(app: Any, coverage_data: SessionData, fixture_name: str) -> None: """Discover endpoints from the app instance.""" + if coverage_data.discovery_complete: + return if not (app and is_supported_framework(app) and not coverage_data.discovered_endpoints.endpoints): return @@ -46,6 +51,7 @@ def _discover_app_endpoints(app: Any, coverage_data: SessionData, fixture_name: method, path = endpoint_method.split(" ", 1) coverage_data.add_discovered_endpoint(path, method, f"{framework_name.lower()}_adapter") + coverage_data.discovery_complete = True logger.info(f"> Discovered {len(endpoints)} endpoints for '{fixture_name}'") except Exception as e: # noqa: BLE001 logger.warning(f"> Failed to discover endpoints from app: {e}") @@ -210,68 +216,75 @@ def fixture_func(request: pytest.FixtureRequest) -> Any: return pytest.fixture(fixture_func) -def wrap_client_with_coverage(client: Any, recorder: Any, test_name: str) -> Any: - """Wrap an existing test client with coverage tracking.""" - if client is None or recorder is None: - return client +class CoverageWrapper: + """Wraps a test client to record HTTP calls for coverage tracking.""" + + _TRACKED_NAMES = frozenset({"get", "post", "put", "delete", "patch", "head", "options", "request", "open"}) - class CoverageWrapper: - def __init__(self, wrapped_client: Any) -> None: - self._wrapped = wrapped_client - - _TRACKED_NAMES = frozenset({"get", "post", "put", "delete", "patch", "head", "options", "request", "open"}) - - def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[str, str] | None: - """Pull path and HTTP method from the call arguments.""" - # .request(method, url, ...) - method is first arg, url is second - if name == "request": - req_method = (args[0] if args else kwargs.get("method", "GET")).upper() - req_url = args[1] if len(args) > 1 else kwargs.get("url") - if isinstance(req_url, str): - return req_url if "?" not in req_url else req_url.partition("?")[0], req_method - return None - - # .get(url), .post(url), .open(url), etc. - url is first arg - if args: - first = args[0] - if isinstance(first, str): - path = first if "?" not in first else first.partition("?")[0] - method = kwargs.get("method", name).upper() - return path, ("GET" if method == "OPEN" else method) - - if hasattr(first, "url") and hasattr(first.url, "path"): - try: - return first.url.path, getattr(first, "method", name).upper() - except Exception: # noqa: BLE001 - pass - - if kwargs: - path_kw = kwargs.get("path") or kwargs.get("url") or kwargs.get("uri") - if isinstance(path_kw, str): - path = path_kw if "?" not in path_kw else path_kw.partition("?")[0] - method = kwargs.get("method", name).upper() - return path, ("GET" if method == "OPEN" else method) + def __init__(self, wrapped_client: Any, recorder: Any, test_name: str) -> None: + """Bind the client, recorder, and test name.""" + self._wrapped = wrapped_client + self._recorder = recorder + self._test_name = test_name + def _extract_path_and_method(self, name: str, args: Any, kwargs: Any) -> tuple[str, str] | None: + """Pull path and HTTP method from the call arguments.""" + # .request(method, url, ...) - method is first arg, url is second + if name == "request": + req_method = (args[0] if args else kwargs.get("method", "GET")).upper() + req_url = args[1] if len(args) > 1 else kwargs.get("url") + if isinstance(req_url, str): + return req_url if "?" not in req_url else req_url.partition("?")[0], req_method return None - def __getattr__(self, name: str) -> Any: - attr = getattr(self._wrapped, name) - if name not in self._TRACKED_NAMES: - return attr + # .get(url), .post(url), .open(url), etc. - url is first arg + if args: + first = args[0] + if isinstance(first, str): + path = first if "?" not in first else first.partition("?")[0] + method = kwargs.get("method", name).upper() + return path, ("GET" if method == "OPEN" else method) - def tracked(*args: Any, **kwargs: Any) -> Any: - response = attr(*args, **kwargs) - if recorder is not None: - pm = self._extract_path_and_method(name, args, kwargs) - if pm: - path, method = pm - recorder.record_call(path, test_name, method) - return response + if hasattr(first, "url") and hasattr(first.url, "path"): + try: + return first.url.path, getattr(first, "method", name).upper() + except Exception: # noqa: BLE001 + pass - object.__setattr__(self, name, tracked) - return tracked + if kwargs: + path_kw = kwargs.get("path") or kwargs.get("url") or kwargs.get("uri") + if isinstance(path_kw, str): + path = path_kw if "?" not in path_kw else path_kw.partition("?")[0] + method = kwargs.get("method", name).upper() + return path, ("GET" if method == "OPEN" else method) + + return None + + def __getattr__(self, name: str) -> Any: + """Intercept attribute access to wrap tracked HTTP methods.""" + attr = getattr(self._wrapped, name) + if name not in self._TRACKED_NAMES: + return attr + + def tracked(*args: Any, **kwargs: Any) -> Any: + response = attr(*args, **kwargs) + if self._recorder is not None: + pm = self._extract_path_and_method(name, args, kwargs) + if pm: + path, method = pm + self._recorder.record_call(path, self._test_name, method) + return response + + object.__setattr__(self, name, tracked) + return tracked + + +def wrap_client_with_coverage(client: Any, recorder: Any, test_name: str) -> Any: + """Wrap an existing test client with coverage tracking.""" + if client is None or recorder is None: + return client - return CoverageWrapper(client) + return CoverageWrapper(client, recorder, test_name) def _coverage_client_impl(request: pytest.FixtureRequest) -> Any: diff --git a/src/pytest_api_cov/report.py b/src/pytest_api_cov/report.py index 906d3a2..8b779c1 100644 --- a/src/pytest_api_cov/report.py +++ b/src/pytest_api_cov/report.py @@ -28,6 +28,37 @@ def contains_escape_characters(endpoint: str) -> bool: return ("<" in endpoint and ">" in endpoint) or ("{" in endpoint and "}" in endpoint) +def _compile_exclusion_pattern(pat: str) -> tuple[frozenset[str] | None, Pattern[str]]: + """Compile a single exclusion pattern into a (methods, regex) pair.""" + path_pattern = pat.strip() + methods: frozenset[str] | None = None + m = re.match(r"^([A-Za-z,]+)\s+(.+)$", pat) + if m: + methods = frozenset(mname.strip().upper() for mname in m.group(1).split(",") if mname.strip()) + path_pattern = m.group(2) + regex = re.compile("^" + re.escape(path_pattern).replace(r"\*", ".*") + "$") + return methods, regex + + +_CompiledPatterns = tuple[tuple[frozenset[str] | None, Pattern[str]], ...] + + +@lru_cache(maxsize=128) +def _compile_exclusion_patterns( + patterns: tuple[str, ...], +) -> tuple[_CompiledPatterns | None, _CompiledPatterns | None]: + """Compile and cache exclusion/negation patterns. + + Accepts a tuple (hashable) so the result can be cached across calls. + """ + exclusion_only = [p for p in patterns if not p.startswith("!")] + negation_only = [p[1:] for p in patterns if p.startswith("!")] + + compiled_exclusions = tuple(_compile_exclusion_pattern(p) for p in exclusion_only) if exclusion_only else None + compiled_negations = tuple(_compile_exclusion_pattern(p) for p in negation_only) if negation_only else None + return compiled_exclusions, compiled_negations + + def categorise_endpoints( endpoints: list[str], called_data: dict[str, set[str]], @@ -47,24 +78,7 @@ def categorise_endpoints( compiled_exclusions = None compiled_negations = None else: - exclusion_only = [p for p in exclusion_patterns if not p.startswith("!")] - negation_only = [p[1:] for p in exclusion_patterns if p.startswith("!")] - - def compile_patterns(patterns: list[str]) -> list[tuple[set[str] | None, Pattern[str]]]: - compiled: list[tuple[set[str] | None, Pattern[str]]] = [] - for pat in patterns: - path_pattern = pat.strip() - methods: set[str] | None = None - m = re.match(r"^([A-Za-z,]+)\s+(.+)$", pat) - if m: - methods = {mname.strip().upper() for mname in m.group(1).split(",") if mname.strip()} - path_pattern = m.group(2) - regex = re.compile("^" + re.escape(path_pattern).replace(r"\*", ".*") + "$") - compiled.append((methods, regex)) - return compiled - - compiled_exclusions = compile_patterns(exclusion_only) if exclusion_only else None - compiled_negations = compile_patterns(negation_only) if negation_only else None + compiled_exclusions, compiled_negations = _compile_exclusion_patterns(tuple(exclusion_patterns)) for endpoint in endpoints: is_excluded = False