Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_http_target_regex_matching_callback_function,
)
from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget
from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget
from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget
from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget
from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig
Expand Down Expand Up @@ -50,6 +51,7 @@
"get_http_target_regex_matching_callback_function",
"HTTPTarget",
"HTTPXAPITarget",
"MCPAuthBypassTarget",
"HuggingFaceChatTarget",
"HuggingFaceEndpointTarget",
"limit_requests_per_minute",
Expand Down
137 changes: 137 additions & 0 deletions pyrit/prompt_target/http_target/mcp_auth_bypass_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
from typing import Any, Optional

import httpx

from pyrit.models import (
MessagePiece,
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)


class MCPAuthBypassTarget(PromptTarget):
"""
MCPAuthBypassTarget tests MCP server endpoints for authentication bypass vulnerabilities.
Implements OWASP MCP-07 (Insufficient Authentication/Authorization) testing.

Args:
mcp_server_url (str): The base URL of the MCP server endpoint.
bypass_technique (str): One of: no_auth, empty_token, malformed_token, role_escalation
mcp_method (str): The MCP JSON-RPC method to test. Defaults to tools/list.
timeout (int): Request timeout in seconds. Defaults to 30.
max_requests_per_minute (int, Optional): Rate limit for requests.
"""

BYPASS_TECHNIQUES = ["no_auth", "empty_token", "malformed_token", "role_escalation"]

def __init__(
self,
mcp_server_url: str,
bypass_technique: str = "no_auth",
mcp_method: str = "tools/list",
timeout: int = 30,
max_requests_per_minute: Optional[int] = None,
) -> None:
"""
Initialize MCPAuthBypassTarget.

Args:
mcp_server_url: The base URL of the MCP server endpoint.
bypass_technique: The auth bypass technique to use.
mcp_method: The MCP JSON-RPC method to test.
timeout: Request timeout in seconds.
max_requests_per_minute: Rate limit for requests.

Raises:
ValueError: If bypass_technique is not a valid option.
"""
if bypass_technique not in self.BYPASS_TECHNIQUES:
raise ValueError(
f"Invalid bypass_technique '{bypass_technique}'. "
f"Must be one of: {self.BYPASS_TECHNIQUES}"
)
super().__init__(
max_requests_per_minute=max_requests_per_minute,
endpoint=mcp_server_url,
)
self.mcp_server_url = mcp_server_url
self.bypass_technique = bypass_technique
self.mcp_method = mcp_method
self.timeout = timeout

def _build_headers(self) -> dict[str, str]:
base_headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.bypass_technique == "no_auth":
return base_headers
if self.bypass_technique == "empty_token":
return {**base_headers, "Authorization": "Bearer "}
if self.bypass_technique == "malformed_token":
return {**base_headers, "Authorization": "Bearer eyJhbGciOiJIUzI1NiJ9.invalid"}
if self.bypass_technique == "role_escalation":
tampered = "eyJhbGciOiJub25lIn0.eyJzdWIiOiJ0ZXN0Iiwicm9sZSI6ImFkbWluIn0."
return {**base_headers, "Authorization": f"Bearer {tampered}"}
return base_headers

def _build_jsonrpc_payload(self, prompt: str) -> dict[str, Any]:
return {"jsonrpc": "2.0", "id": 1, "method": self.mcp_method, "params": {"prompt": prompt}}

def _evaluate_response(self, status_code: int, response_body: str) -> str:
if status_code == 200:
return (
f"[VULNERABILITY DETECTED] MCP-07 Auth Bypass succeeded using "
f"'{self.bypass_technique}'. Server returned HTTP 200.\nResponse: {response_body[:500]}"
)
if status_code in (401, 403):
return (
f"[SECURE] Server correctly rejected with HTTP {status_code} "
f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}"
)
return (
f"[INVESTIGATE] Unexpected HTTP {status_code} "
f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}"
)

def _validate_request(self, *, message) -> None:
"""
Validate the request message. MCP target accepts all text messages.

Raises:
ValueError: If the message is None or empty.
"""
if not message:
raise ValueError("Message cannot be None or empty.")

@limit_requests_per_minute
async def send_prompt_async(self, *, prompt_request: MessagePiece) -> MessagePiece:
"""
Send a prompt to the MCP server using the configured auth bypass technique.

Args:
prompt_request: The prompt request to send.

Returns:
MessagePiece: The response containing bypass test results.
"""
prompt_text = prompt_request.converted_value
headers = self._build_headers()
payload = self._build_jsonrpc_payload(prompt_text)
logger.info(f"MCPAuthBypassTarget: Testing '{self.bypass_technique}' against {self.mcp_server_url}")
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(self.mcp_server_url, headers=headers, content=json.dumps(payload))
result = self._evaluate_response(response.status_code, response.text)
except httpx.TimeoutException:
result = f"[ERROR] Request timed out after {self.timeout}s"
except httpx.ConnectError as e:
result = f"[ERROR] Connection failed to {self.mcp_server_url}: {e}"
except Exception as e:
result = f"[ERROR] Unexpected error: {type(e).__name__}: {e}"
return construct_response_from_request(request=prompt_request, response_text_pieces=[result])
84 changes: 84 additions & 0 deletions tests/unit/build_scripts/test_check_links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tempfile
from pathlib import Path

import pytest

from build_scripts.check_links import extract_urls, resolve_relative_url, strip_fragment


class TestStripFragment:
def test_removes_fragment(self) -> None:
assert strip_fragment("https://example.com/page#section") == "https://example.com/page"

def test_no_fragment_unchanged(self) -> None:
assert strip_fragment("https://example.com/page") == "https://example.com/page"

def test_empty_fragment(self) -> None:
assert strip_fragment("https://example.com/page#") == "https://example.com/page"

def test_preserves_query_string(self) -> None:
result = strip_fragment("https://example.com/page?q=1#section")
assert "q=1" in result
assert "section" not in result


class TestResolveRelativeUrl:
def test_http_url_unchanged(self) -> None:
url = "https://example.com"
assert resolve_relative_url("/some/file.md", url) == url

def test_mailto_unchanged(self) -> None:
url = "mailto:test@example.com"
assert resolve_relative_url("/some/file.md", url) == url

def test_relative_url_resolved(self, tmp_path: Path) -> None:
base = str(tmp_path / "docs" / "file.md")
target = str(tmp_path / "docs" / "other.md")
Path(target).parent.mkdir(parents=True, exist_ok=True)
Path(target).write_text("# Other")
result = resolve_relative_url(base, "other.md")
assert "other" in result

def test_relative_url_with_md_extension(self, tmp_path: Path) -> None:
base = str(tmp_path / "docs" / "file.md")
target = tmp_path / "docs" / "other.md"
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text("# Other")
result = resolve_relative_url(base, "other")
assert result.endswith(".md")


class TestExtractUrls:
def test_extracts_markdown_links(self, tmp_path: Path) -> None:
f = tmp_path / "test.md"
f.write_text("[Click here](https://example.com)")
urls = extract_urls(str(f))
assert "https://example.com" in urls

def test_extracts_href_links(self, tmp_path: Path) -> None:
f = tmp_path / "test.html"
f.write_text('<a href="https://example.com">link</a>')
urls = extract_urls(str(f))
assert "https://example.com" in urls

def test_extracts_src_links(self, tmp_path: Path) -> None:
f = tmp_path / "test.html"
f.write_text('<img src="https://example.com/image.png">')
urls = extract_urls(str(f))
assert "https://example.com/image.png" in urls

def test_empty_file_returns_no_urls(self, tmp_path: Path) -> None:
f = tmp_path / "empty.md"
f.write_text("")
urls = extract_urls(str(f))
assert urls == []

def test_strips_fragments_from_extracted_urls(self, tmp_path: Path) -> None:
f = tmp_path / "test.md"
f.write_text("[link](https://example.com/page#section)")
urls = extract_urls(str(f))
assert "https://example.com/page" in urls
assert not any("#section" in u for u in urls)
60 changes: 60 additions & 0 deletions tests/unit/build_scripts/test_generate_rss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tempfile
from pathlib import Path

import pytest

from build_scripts.generate_rss import extract_date_from_filename, parse_blog_markdown


class TestExtractDateFromFilename:
def test_standard_date(self) -> None:
assert extract_date_from_filename("2024_12_3.md") == "2024-12-03"

def test_double_digit_day_and_month(self) -> None:
assert extract_date_from_filename("2023_11_25.md") == "2023-11-25"

def test_single_digit_month(self) -> None:
assert extract_date_from_filename("2024_1_15.md") == "2024-01-15"

def test_returns_empty_for_invalid_filename(self) -> None:
assert extract_date_from_filename("no_date_here.md") == ""

def test_returns_empty_for_non_numeric(self) -> None:
assert extract_date_from_filename("intro.md") == ""


class TestParseBlogMarkdown:
def test_extracts_title(self, tmp_path: Path) -> None:
f = tmp_path / "2024_01_01.md"
f.write_text("# My Blog Title\n\nSome description here.")
title, _ = parse_blog_markdown(f)
assert title == "My Blog Title"

def test_extracts_description(self, tmp_path: Path) -> None:
f = tmp_path / "2024_01_01.md"
f.write_text("# Title\n\nThis is the description paragraph.")
_, desc = parse_blog_markdown(f)
assert "This is the description paragraph." in desc

def test_skips_small_tag_in_description(self, tmp_path: Path) -> None:
f = tmp_path / "2024_01_01.md"
f.write_text("# Title\n\n<small>date info</small>\n\nReal description here.")
_, desc = parse_blog_markdown(f)
assert "small" not in desc
assert "Real description here." in desc

def test_empty_title_when_no_heading(self, tmp_path: Path) -> None:
f = tmp_path / "2024_01_01.md"
f.write_text("No heading here.\n\nJust paragraphs.")
title, _ = parse_blog_markdown(f)
assert title == ""

def test_multiline_description_joined(self, tmp_path: Path) -> None:
f = tmp_path / "2024_01_01.md"
f.write_text("# Title\n\nLine one.\nLine two.")
_, desc = parse_blog_markdown(f)
assert "Line one." in desc
assert "Line two." in desc
88 changes: 88 additions & 0 deletions tests/unit/build_scripts/test_prepare_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from build_scripts.prepare_package import build_frontend, copy_frontend_to_package


class TestBuildFrontend:
def test_returns_false_when_npm_not_found(self, tmp_path: Path) -> None:
with patch("subprocess.run", side_effect=FileNotFoundError):
result = build_frontend(tmp_path)
assert result is False

def test_returns_false_when_package_json_missing(self, tmp_path: Path) -> None:
mock_run = MagicMock()
mock_run.return_value.stdout = "10.0.0\n"
with patch("subprocess.run", mock_run):
result = build_frontend(tmp_path)
assert result is False

def test_returns_false_when_npm_install_fails(self, tmp_path: Path) -> None:
import subprocess
(tmp_path / "package.json").write_text("{}")
responses = [
MagicMock(stdout="10.0.0\n"),
subprocess.CalledProcessError(1, "npm install", output="error"),
]
with patch("subprocess.run", side_effect=responses):
result = build_frontend(tmp_path)
assert result is False

def test_returns_false_when_npm_build_fails(self, tmp_path: Path) -> None:
import subprocess
(tmp_path / "package.json").write_text("{}")
responses = [
MagicMock(stdout="10.0.0\n"),
MagicMock(),
subprocess.CalledProcessError(1, "npm run build", output="error"),
]
with patch("subprocess.run", side_effect=responses):
result = build_frontend(tmp_path)
assert result is False

def test_returns_true_when_build_succeeds(self, tmp_path: Path) -> None:
(tmp_path / "package.json").write_text("{}")
with patch("subprocess.run", return_value=MagicMock(stdout="10.0.0\n")):
result = build_frontend(tmp_path)
assert result is True


class TestCopyFrontendToPackage(object):
def test_returns_false_when_dist_missing(self, tmp_path: Path) -> None:
result = copy_frontend_to_package(tmp_path / "dist", tmp_path / "out")
assert result is False

def test_returns_false_when_index_html_missing(self, tmp_path: Path) -> None:
dist = tmp_path / "dist"
dist.mkdir()
(dist / "main.js").write_text("console.log('hi')")
out = tmp_path / "out"
result = copy_frontend_to_package(dist, out)
assert result is False

def test_returns_true_when_copy_succeeds(self, tmp_path: Path) -> None:
dist = tmp_path / "dist"
dist.mkdir()
(dist / "index.html").write_text("<html></html>")
out = tmp_path / "out"
result = copy_frontend_to_package(dist, out)
assert result is True
assert (out / "index.html").exists()

def test_removes_existing_output_dir(self, tmp_path: Path) -> None:
dist = tmp_path / "dist"
dist.mkdir()
(dist / "index.html").write_text("<html></html>")
out = tmp_path / "out"
out.mkdir()
(out / "old_file.txt").write_text("old")
copy_frontend_to_package(dist, out)
assert not (out / "old_file.txt").exists()
assert (out / "index.html").exists()
Loading