diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 05af2d67d8..0ca3d40c9b 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -22,6 +22,7 @@ get_http_target_regex_matching_callback_function, ) from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget +from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig @@ -50,6 +51,7 @@ "get_http_target_regex_matching_callback_function", "HTTPTarget", "HTTPXAPITarget", + "MCPAuthBypassTarget", "HuggingFaceChatTarget", "HuggingFaceEndpointTarget", "limit_requests_per_minute", diff --git a/pyrit/prompt_target/http_target/mcp_auth_bypass_target.py b/pyrit/prompt_target/http_target/mcp_auth_bypass_target.py new file mode 100644 index 0000000000..0db3cf69f9 --- /dev/null +++ b/pyrit/prompt_target/http_target/mcp_auth_bypass_target.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import logging +from typing import Any, Optional + +import httpx + +from pyrit.models import ( + MessagePiece, + construct_response_from_request, +) +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.utils import limit_requests_per_minute + +logger = logging.getLogger(__name__) + + +class MCPAuthBypassTarget(PromptTarget): + """ + MCPAuthBypassTarget tests MCP server endpoints for authentication bypass vulnerabilities. + Implements OWASP MCP-07 (Insufficient Authentication/Authorization) testing. + + Args: + mcp_server_url (str): The base URL of the MCP server endpoint. + bypass_technique (str): One of: no_auth, empty_token, malformed_token, role_escalation + mcp_method (str): The MCP JSON-RPC method to test. Defaults to tools/list. + timeout (int): Request timeout in seconds. Defaults to 30. + max_requests_per_minute (int, Optional): Rate limit for requests. + """ + + BYPASS_TECHNIQUES = ["no_auth", "empty_token", "malformed_token", "role_escalation"] + + def __init__( + self, + mcp_server_url: str, + bypass_technique: str = "no_auth", + mcp_method: str = "tools/list", + timeout: int = 30, + max_requests_per_minute: Optional[int] = None, + ) -> None: + """ + Initialize MCPAuthBypassTarget. + + Args: + mcp_server_url: The base URL of the MCP server endpoint. + bypass_technique: The auth bypass technique to use. + mcp_method: The MCP JSON-RPC method to test. + timeout: Request timeout in seconds. + max_requests_per_minute: Rate limit for requests. + + Raises: + ValueError: If bypass_technique is not a valid option. + """ + if bypass_technique not in self.BYPASS_TECHNIQUES: + raise ValueError( + f"Invalid bypass_technique '{bypass_technique}'. " + f"Must be one of: {self.BYPASS_TECHNIQUES}" + ) + super().__init__( + max_requests_per_minute=max_requests_per_minute, + endpoint=mcp_server_url, + ) + self.mcp_server_url = mcp_server_url + self.bypass_technique = bypass_technique + self.mcp_method = mcp_method + self.timeout = timeout + + def _build_headers(self) -> dict[str, str]: + base_headers = {"Content-Type": "application/json", "Accept": "application/json"} + if self.bypass_technique == "no_auth": + return base_headers + if self.bypass_technique == "empty_token": + return {**base_headers, "Authorization": "Bearer "} + if self.bypass_technique == "malformed_token": + return {**base_headers, "Authorization": "Bearer eyJhbGciOiJIUzI1NiJ9.invalid"} + if self.bypass_technique == "role_escalation": + tampered = "eyJhbGciOiJub25lIn0.eyJzdWIiOiJ0ZXN0Iiwicm9sZSI6ImFkbWluIn0." + return {**base_headers, "Authorization": f"Bearer {tampered}"} + return base_headers + + def _build_jsonrpc_payload(self, prompt: str) -> dict[str, Any]: + return {"jsonrpc": "2.0", "id": 1, "method": self.mcp_method, "params": {"prompt": prompt}} + + def _evaluate_response(self, status_code: int, response_body: str) -> str: + if status_code == 200: + return ( + f"[VULNERABILITY DETECTED] MCP-07 Auth Bypass succeeded using " + f"'{self.bypass_technique}'. Server returned HTTP 200.\nResponse: {response_body[:500]}" + ) + if status_code in (401, 403): + return ( + f"[SECURE] Server correctly rejected with HTTP {status_code} " + f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}" + ) + return ( + f"[INVESTIGATE] Unexpected HTTP {status_code} " + f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}" + ) + + def _validate_request(self, *, message) -> None: + """ + Validate the request message. MCP target accepts all text messages. + + Raises: + ValueError: If the message is None or empty. + """ + if not message: + raise ValueError("Message cannot be None or empty.") + + @limit_requests_per_minute + async def send_prompt_async(self, *, prompt_request: MessagePiece) -> MessagePiece: + """ + Send a prompt to the MCP server using the configured auth bypass technique. + + Args: + prompt_request: The prompt request to send. + + Returns: + MessagePiece: The response containing bypass test results. + """ + prompt_text = prompt_request.converted_value + headers = self._build_headers() + payload = self._build_jsonrpc_payload(prompt_text) + logger.info(f"MCPAuthBypassTarget: Testing '{self.bypass_technique}' against {self.mcp_server_url}") + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post(self.mcp_server_url, headers=headers, content=json.dumps(payload)) + result = self._evaluate_response(response.status_code, response.text) + except httpx.TimeoutException: + result = f"[ERROR] Request timed out after {self.timeout}s" + except httpx.ConnectError as e: + result = f"[ERROR] Connection failed to {self.mcp_server_url}: {e}" + except Exception as e: + result = f"[ERROR] Unexpected error: {type(e).__name__}: {e}" + return construct_response_from_request(request=prompt_request, response_text_pieces=[result]) diff --git a/tests/unit/build_scripts/test_check_links.py b/tests/unit/build_scripts/test_check_links.py new file mode 100644 index 0000000000..658a6dc4dc --- /dev/null +++ b/tests/unit/build_scripts/test_check_links.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tempfile +from pathlib import Path + +import pytest + +from build_scripts.check_links import extract_urls, resolve_relative_url, strip_fragment + + +class TestStripFragment: + def test_removes_fragment(self) -> None: + assert strip_fragment("https://example.com/page#section") == "https://example.com/page" + + def test_no_fragment_unchanged(self) -> None: + assert strip_fragment("https://example.com/page") == "https://example.com/page" + + def test_empty_fragment(self) -> None: + assert strip_fragment("https://example.com/page#") == "https://example.com/page" + + def test_preserves_query_string(self) -> None: + result = strip_fragment("https://example.com/page?q=1#section") + assert "q=1" in result + assert "section" not in result + + +class TestResolveRelativeUrl: + def test_http_url_unchanged(self) -> None: + url = "https://example.com" + assert resolve_relative_url("/some/file.md", url) == url + + def test_mailto_unchanged(self) -> None: + url = "mailto:test@example.com" + assert resolve_relative_url("/some/file.md", url) == url + + def test_relative_url_resolved(self, tmp_path: Path) -> None: + base = str(tmp_path / "docs" / "file.md") + target = str(tmp_path / "docs" / "other.md") + Path(target).parent.mkdir(parents=True, exist_ok=True) + Path(target).write_text("# Other") + result = resolve_relative_url(base, "other.md") + assert "other" in result + + def test_relative_url_with_md_extension(self, tmp_path: Path) -> None: + base = str(tmp_path / "docs" / "file.md") + target = tmp_path / "docs" / "other.md" + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text("# Other") + result = resolve_relative_url(base, "other") + assert result.endswith(".md") + + +class TestExtractUrls: + def test_extracts_markdown_links(self, tmp_path: Path) -> None: + f = tmp_path / "test.md" + f.write_text("[Click here](https://example.com)") + urls = extract_urls(str(f)) + assert "https://example.com" in urls + + def test_extracts_href_links(self, tmp_path: Path) -> None: + f = tmp_path / "test.html" + f.write_text('link') + urls = extract_urls(str(f)) + assert "https://example.com" in urls + + def test_extracts_src_links(self, tmp_path: Path) -> None: + f = tmp_path / "test.html" + f.write_text('') + urls = extract_urls(str(f)) + assert "https://example.com/image.png" in urls + + def test_empty_file_returns_no_urls(self, tmp_path: Path) -> None: + f = tmp_path / "empty.md" + f.write_text("") + urls = extract_urls(str(f)) + assert urls == [] + + def test_strips_fragments_from_extracted_urls(self, tmp_path: Path) -> None: + f = tmp_path / "test.md" + f.write_text("[link](https://example.com/page#section)") + urls = extract_urls(str(f)) + assert "https://example.com/page" in urls + assert not any("#section" in u for u in urls) diff --git a/tests/unit/build_scripts/test_generate_rss.py b/tests/unit/build_scripts/test_generate_rss.py new file mode 100644 index 0000000000..44f25d3be2 --- /dev/null +++ b/tests/unit/build_scripts/test_generate_rss.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tempfile +from pathlib import Path + +import pytest + +from build_scripts.generate_rss import extract_date_from_filename, parse_blog_markdown + + +class TestExtractDateFromFilename: + def test_standard_date(self) -> None: + assert extract_date_from_filename("2024_12_3.md") == "2024-12-03" + + def test_double_digit_day_and_month(self) -> None: + assert extract_date_from_filename("2023_11_25.md") == "2023-11-25" + + def test_single_digit_month(self) -> None: + assert extract_date_from_filename("2024_1_15.md") == "2024-01-15" + + def test_returns_empty_for_invalid_filename(self) -> None: + assert extract_date_from_filename("no_date_here.md") == "" + + def test_returns_empty_for_non_numeric(self) -> None: + assert extract_date_from_filename("intro.md") == "" + + +class TestParseBlogMarkdown: + def test_extracts_title(self, tmp_path: Path) -> None: + f = tmp_path / "2024_01_01.md" + f.write_text("# My Blog Title\n\nSome description here.") + title, _ = parse_blog_markdown(f) + assert title == "My Blog Title" + + def test_extracts_description(self, tmp_path: Path) -> None: + f = tmp_path / "2024_01_01.md" + f.write_text("# Title\n\nThis is the description paragraph.") + _, desc = parse_blog_markdown(f) + assert "This is the description paragraph." in desc + + def test_skips_small_tag_in_description(self, tmp_path: Path) -> None: + f = tmp_path / "2024_01_01.md" + f.write_text("# Title\n\ndate info\n\nReal description here.") + _, desc = parse_blog_markdown(f) + assert "small" not in desc + assert "Real description here." in desc + + def test_empty_title_when_no_heading(self, tmp_path: Path) -> None: + f = tmp_path / "2024_01_01.md" + f.write_text("No heading here.\n\nJust paragraphs.") + title, _ = parse_blog_markdown(f) + assert title == "" + + def test_multiline_description_joined(self, tmp_path: Path) -> None: + f = tmp_path / "2024_01_01.md" + f.write_text("# Title\n\nLine one.\nLine two.") + _, desc = parse_blog_markdown(f) + assert "Line one." in desc + assert "Line two." in desc diff --git a/tests/unit/build_scripts/test_prepare_package.py b/tests/unit/build_scripts/test_prepare_package.py new file mode 100644 index 0000000000..047d9767cd --- /dev/null +++ b/tests/unit/build_scripts/test_prepare_package.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import shutil +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from build_scripts.prepare_package import build_frontend, copy_frontend_to_package + + +class TestBuildFrontend: + def test_returns_false_when_npm_not_found(self, tmp_path: Path) -> None: + with patch("subprocess.run", side_effect=FileNotFoundError): + result = build_frontend(tmp_path) + assert result is False + + def test_returns_false_when_package_json_missing(self, tmp_path: Path) -> None: + mock_run = MagicMock() + mock_run.return_value.stdout = "10.0.0\n" + with patch("subprocess.run", mock_run): + result = build_frontend(tmp_path) + assert result is False + + def test_returns_false_when_npm_install_fails(self, tmp_path: Path) -> None: + import subprocess + (tmp_path / "package.json").write_text("{}") + responses = [ + MagicMock(stdout="10.0.0\n"), + subprocess.CalledProcessError(1, "npm install", output="error"), + ] + with patch("subprocess.run", side_effect=responses): + result = build_frontend(tmp_path) + assert result is False + + def test_returns_false_when_npm_build_fails(self, tmp_path: Path) -> None: + import subprocess + (tmp_path / "package.json").write_text("{}") + responses = [ + MagicMock(stdout="10.0.0\n"), + MagicMock(), + subprocess.CalledProcessError(1, "npm run build", output="error"), + ] + with patch("subprocess.run", side_effect=responses): + result = build_frontend(tmp_path) + assert result is False + + def test_returns_true_when_build_succeeds(self, tmp_path: Path) -> None: + (tmp_path / "package.json").write_text("{}") + with patch("subprocess.run", return_value=MagicMock(stdout="10.0.0\n")): + result = build_frontend(tmp_path) + assert result is True + + +class TestCopyFrontendToPackage(object): + def test_returns_false_when_dist_missing(self, tmp_path: Path) -> None: + result = copy_frontend_to_package(tmp_path / "dist", tmp_path / "out") + assert result is False + + def test_returns_false_when_index_html_missing(self, tmp_path: Path) -> None: + dist = tmp_path / "dist" + dist.mkdir() + (dist / "main.js").write_text("console.log('hi')") + out = tmp_path / "out" + result = copy_frontend_to_package(dist, out) + assert result is False + + def test_returns_true_when_copy_succeeds(self, tmp_path: Path) -> None: + dist = tmp_path / "dist" + dist.mkdir() + (dist / "index.html").write_text("") + out = tmp_path / "out" + result = copy_frontend_to_package(dist, out) + assert result is True + assert (out / "index.html").exists() + + def test_removes_existing_output_dir(self, tmp_path: Path) -> None: + dist = tmp_path / "dist" + dist.mkdir() + (dist / "index.html").write_text("") + out = tmp_path / "out" + out.mkdir() + (out / "old_file.txt").write_text("old") + copy_frontend_to_package(dist, out) + assert not (out / "old_file.txt").exists() + assert (out / "index.html").exists() diff --git a/tests/unit/build_scripts/test_validate_docs.py b/tests/unit/build_scripts/test_validate_docs.py new file mode 100644 index 0000000000..bea434a49d --- /dev/null +++ b/tests/unit/build_scripts/test_validate_docs.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tempfile +from pathlib import Path + +import pytest + +from build_scripts.validate_docs import find_orphaned_files, parse_toc_files, validate_toc_files + + +class TestParseTocFiles: + def test_extracts_single_file(self) -> None: + toc = [{"file": "intro"}] + result = parse_toc_files(toc) + assert "intro" in result + + def test_extracts_nested_children(self) -> None: + toc = [{"file": "parent", "children": [{"file": "child"}]}] + result = parse_toc_files(toc) + assert "parent" in result + assert "child" in result + + def test_ignores_entries_without_file(self) -> None: + toc = [{"title": "No file here"}] + result = parse_toc_files(toc) + assert len(result) == 0 + + def test_empty_toc(self) -> None: + result = parse_toc_files([]) + assert result == set() + + def test_normalizes_backslashes(self) -> None: + toc = [{"file": "setup\\install"}] + result = parse_toc_files(toc) + assert "setup/install" in result + + +class TestValidateTocFiles: + def test_no_errors_when_files_exist(self, tmp_path: Path) -> None: + (tmp_path / "intro.md").write_text("# Intro") + errors = validate_toc_files({"intro.md"}, tmp_path) + assert errors == [] + + def test_error_when_file_missing(self, tmp_path: Path) -> None: + errors = validate_toc_files({"missing.md"}, tmp_path) + assert len(errors) == 1 + assert "missing.md" in errors[0] + + def test_skips_api_generated_files(self, tmp_path: Path) -> None: + errors = validate_toc_files({"api/some_module"}, tmp_path) + assert errors == [] + + def test_multiple_missing_files(self, tmp_path: Path) -> None: + errors = validate_toc_files({"a.md", "b.md"}, tmp_path) + assert len(errors) == 2 + + +class TestFindOrphanedFiles: + def test_no_orphans_when_all_referenced(self, tmp_path: Path) -> None: + (tmp_path / "intro.md").write_text("# Intro") + orphaned = find_orphaned_files({"intro.md"}, tmp_path) + assert orphaned == [] + + def test_detects_orphaned_markdown(self, tmp_path: Path) -> None: + (tmp_path / "orphan.md").write_text("# Orphan") + orphaned = find_orphaned_files(set(), tmp_path) + assert any("orphan.md" in o for o in orphaned) + + def test_skips_build_directory(self, tmp_path: Path) -> None: + build_dir = tmp_path / "_build" + build_dir.mkdir() + (build_dir / "generated.md").write_text("# Generated") + orphaned = find_orphaned_files(set(), tmp_path) + assert not any("_build" in o for o in orphaned) + + def test_skips_myst_yml(self, tmp_path: Path) -> None: + (tmp_path / "myst.yml").write_text("project:") + orphaned = find_orphaned_files(set(), tmp_path) + assert not any("myst.yml" in o for o in orphaned) + + def test_skips_py_companion_files(self, tmp_path: Path) -> None: + (tmp_path / "notebook.ipynb").write_text("{}") + (tmp_path / "notebook.py").write_text("# companion") + orphaned = find_orphaned_files(set(), tmp_path) + assert not any("notebook.py" in o for o in orphaned) diff --git a/tests/unit/target/test_mcp_auth_bypass_target.py b/tests/unit/target/test_mcp_auth_bypass_target.py new file mode 100644 index 0000000000..d34fbea911 --- /dev/null +++ b/tests/unit/target/test_mcp_auth_bypass_target.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import MagicMock + +import pytest + +from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget + + +def make_mock_request(text="test prompt"): + req = MagicMock() + req.converted_value = text + return req + + +class TestMCPAuthBypassTargetInit: + def test_valid_bypass_technique(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert target.bypass_technique == "no_auth" + + def test_invalid_bypass_technique_raises(self, sqlite_instance): + with pytest.raises(ValueError, match="Invalid bypass_technique"): + MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="invalid") + + def test_default_values(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080") + assert target.bypass_technique == "no_auth" + assert target.mcp_method == "tools/list" + assert target.timeout == 30 + + +class TestMCPAuthBypassTargetHeaders: + def test_no_auth_has_no_authorization_header(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "Authorization" not in target._build_headers() + + def test_empty_token_has_empty_bearer(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="empty_token") + assert target._build_headers()["Authorization"] == "Bearer " + + def test_malformed_token_has_invalid_jwt(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="malformed_token") + assert "invalid" in target._build_headers()["Authorization"] + + def test_role_escalation_has_tampered_token(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="role_escalation") + assert "eyJhbGciOiJub25lIn0" in target._build_headers()["Authorization"] + + +class TestMCPAuthBypassTargetEvaluate: + def test_200_detected_as_vulnerability(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "VULNERABILITY DETECTED" in target._evaluate_response(200, "ok") + + def test_401_detected_as_secure(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "SECURE" in target._evaluate_response(401, "Unauthorized") + + def test_403_detected_as_secure(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "SECURE" in target._evaluate_response(403, "Forbidden") + + def test_500_flagged_for_investigation(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "INVESTIGATE" in target._evaluate_response(500, "Server Error")