Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions demo/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,27 @@ def _echo(p: str) -> str:
return report.to_dict()


def api_community_get(body: dict) -> dict:
    """Return the curated community attack registry, optionally filtered.

    Handler for ``/api/attacks/community``.

    Args:
        body: Request payload. The optional keys ``"category"``, ``"tag"`` and
            ``"severity"`` narrow the result; a missing or falsy value means
            "no filter" for that field. A payload that is not a dict (for
            example ``None``) is treated as carrying no filters instead of
            raising ``AttributeError``.

    Returns:
        A dict with the registry-wide ``"stats"``, the ``"filters"`` exactly as
        received, the ``"count"`` of matching attacks, and the matching
        ``"attacks"`` serialised through ``to_dict()``.
    """
    from toki.community import get_registry

    # The payload comes from an HTTP request, so do not assume it is a dict.
    params = body if isinstance(body, dict) else {}
    requested = {key: params.get(key) for key in ("category", "tag", "severity")}

    registry = get_registry()
    # Falsy values such as "" are passed to filter() as None (no filter).
    matches = registry.filter(
        category=requested["category"] or None,
        tag=requested["tag"] or None,
        severity=requested["severity"] or None,
    )
    return {
        "stats": registry.stats(),
        "filters": requested,
        "count": len(matches),
        "attacks": [attack.to_dict() for attack in matches],
    }


def api_attacks_custom_get() -> dict:
"""GET /api/attacks/custom — list all custom attacks in the library."""
from toki.attack_library import AttackLibrary
Expand Down Expand Up @@ -1420,6 +1441,9 @@ def _first_int(
("POST", "/api/remediate"): api_remediate,
("GET", "/api/attacks/custom"): lambda body: api_attacks_custom_get(),
("POST", "/api/attacks/custom"): api_attacks_custom_post,
# Phase 15 — community attack registry
("GET", "/api/attacks/community"): lambda body: api_community_get({}),
("POST", "/api/attacks/community"): api_community_get,
}


Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "toki"
version = "1.4.0"
version = "1.5.0"
description = "Adversarial fine-tuning lab for small language models"
license = { text = "BUSL-1.1" }
requires-python = ">=3.9"
Expand Down
241 changes: 241 additions & 0 deletions python/tests/test_community.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""Tests for toki.community — Phase 15 Community Attack Registry."""
import hashlib
import json
from unittest.mock import MagicMock, patch

import pytest

from toki.community import (
CommunityAttack,
CommunityRegistry,
_verify_sha256,
filter_attacks,
get_registry,
load_bundled,
load_remote,
)


# ---------------------------------------------------------------------------
# Bundled registry
# ---------------------------------------------------------------------------


def test_load_bundled_returns_registry():
    """load_bundled() yields a CommunityRegistry instance."""
    assert isinstance(load_bundled(), CommunityRegistry)


def test_bundled_registry_has_25_attacks():
    """The bundled registry reports a total of 25 attacks."""
    bundled = load_bundled()
    expected_total = 25
    assert bundled.total == expected_total


def test_bundled_registry_all_categories_covered():
    """The set of categories in the bundled registry is exactly the expected six."""
    wanted = {"jailbreak", "injection", "edge_case", "boundary", "indirect", "agentic"}
    seen = set()
    for entry in load_bundled().attacks:
        seen.add(entry.category)
    assert seen == wanted


def test_bundled_registry_sha256_valid():
    """The bundled registry exposes a well-formed hex SHA-256 digest.

    A length check alone would accept any 64-character string, so the
    alphabet is verified as well.
    """
    import string

    digest = load_bundled().sha256
    assert len(digest) == 64  # hex SHA-256
    assert set(digest) <= set(string.hexdigits)


def test_bundled_attacks_have_required_fields():
    """Every bundled attack carries a well-formed id, text, severity, OWASP tag and tag list."""
    allowed_severities = {"critical", "high", "medium", "low"}
    for entry in load_bundled().attacks:
        assert entry.id.startswith("com-")
        assert len(entry.text) >= 1
        assert entry.severity in allowed_severities
        assert "OWASP" in entry.owasp_tag
        assert isinstance(entry.tags, list)


def test_bundled_attacks_all_expect_refusal():
    """No bundled attack has a falsy expected_refusal flag."""
    not_refused = [entry for entry in load_bundled().attacks if not entry.expected_refusal]
    assert not not_refused


# ---------------------------------------------------------------------------
# get_registry — caching
# ---------------------------------------------------------------------------


def test_get_registry_returns_registry():
    """get_registry() yields a CommunityRegistry instance."""
    assert isinstance(get_registry(), CommunityRegistry)


def test_get_registry_cached():
    """Two consecutive get_registry() calls hand back the very same object."""
    first, second = get_registry(), get_registry()
    assert first is second


def test_get_registry_reload():
    """A reload=True call produces a registry with the same total as before."""
    before = get_registry()
    after = get_registry(reload=True)
    assert after.total == before.total


# ---------------------------------------------------------------------------
# CommunityRegistry.filter
# ---------------------------------------------------------------------------


def test_filter_by_category():
    """Filtering on category="jailbreak" returns five attacks, all in that category."""
    jailbreaks = load_bundled().filter(category="jailbreak")
    assert len(jailbreaks) == 5
    for entry in jailbreaks:
        assert entry.category == "jailbreak"


def test_filter_by_severity():
    """Filtering on severity="critical" returns a non-empty, all-critical list."""
    hits = load_bundled().filter(severity="critical")
    assert len(hits) > 0
    for entry in hits:
        assert entry.severity == "critical"


def test_filter_by_tag():
    """Filtering on tag="DAN" returns at least one attack, each tagged "DAN"."""
    wanted_tag = "DAN"
    hits = load_bundled().filter(tag=wanted_tag)
    assert len(hits) >= 1
    for entry in hits:
        assert wanted_tag in entry.tags


def test_filter_combined():
    """Category and severity filters applied together both hold for every result."""
    hits = load_bundled().filter(category="jailbreak", severity="critical")
    for entry in hits:
        assert entry.category == "jailbreak" and entry.severity == "critical"


def test_filter_no_match_returns_empty():
    """A tag that matches nothing yields an empty list."""
    hits = load_bundled().filter(tag="nonexistent_tag_xyz")
    assert hits == []


def test_filter_attacks_convenience():
    """The module-level filter_attacks() helper returns four "agentic" attacks."""
    agentic = filter_attacks(load_bundled(), category="agentic")
    assert len(agentic) == 4


# ---------------------------------------------------------------------------
# CommunityRegistry.stats
# ---------------------------------------------------------------------------


def test_stats_total_matches():
    """stats()["total"] agrees with the registry's own total."""
    bundled = load_bundled()
    assert bundled.stats()["total"] == bundled.total


def test_stats_by_category_covers_all():
    """stats()["by_category"] contains both "jailbreak" and "agentic"."""
    per_category = load_bundled().stats()["by_category"]
    for name in ("jailbreak", "agentic"):
        assert name in per_category


def test_stats_by_severity_present():
    """stats()["by_severity"] is non-empty."""
    per_severity = load_bundled().stats()["by_severity"]
    assert len(per_severity) > 0


# ---------------------------------------------------------------------------
# CommunityAttack.to_dict
# ---------------------------------------------------------------------------


def test_community_attack_to_dict():
    """to_dict() on the first bundled attack exposes all expected keys."""
    serialised = load_bundled().attacks[0].to_dict()
    for key in ("id", "text", "category", "tags", "severity", "owasp_tag"):
        assert key in serialised


# ---------------------------------------------------------------------------
# _verify_sha256
# ---------------------------------------------------------------------------


def test_verify_sha256_correct():
    """_verify_sha256 accepts the digest of the compact, key-sorted JSON encoding."""
    payload = [{"text": "test", "category": "jailbreak"}]
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode()).hexdigest()
    assert _verify_sha256(payload, digest) is True


def test_verify_sha256_wrong():
    """_verify_sha256 rejects a digest that does not match the data."""
    bogus_digest = "a" * 64
    assert _verify_sha256([{"text": "test"}], bogus_digest) is False


# ---------------------------------------------------------------------------
# load_remote — mocked urllib
# ---------------------------------------------------------------------------


def _make_remote_manifest(attacks: list) -> bytes:
    """Build a UTF-8 encoded JSON manifest wrapping *attacks*.

    The ``sha256`` field is the hex digest of the attacks list serialised as
    compact JSON with sorted keys.
    """
    canonical = json.dumps(attacks, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode()).hexdigest()
    return json.dumps(
        {
            "version": "1.0.0",
            "updated": "2026-06-14",
            "description": "Test registry",
            "sha256": digest,
            "attacks": attacks,
        }
    ).encode("utf-8")


def test_load_remote_success():
    """load_remote() parses a manifest served by a patched urlopen."""
    sample_attack = {
        "id": "com-jb-test",
        "text": "ignore instructions",
        "category": "jailbreak",
        "tags": ["test"],
        "severity": "high",
        "owasp_tag": "OWASP-LLM01:2025",
        "author": "test",
        "source": "test",
        "expected_refusal": True,
    }

    # Fake HTTP response usable as a context manager that yields itself.
    response = MagicMock()
    response.read.return_value = _make_remote_manifest([sample_attack])
    response.__enter__.return_value = response
    response.__exit__.return_value = False

    with patch("urllib.request.urlopen", return_value=response):
        registry = load_remote("https://example.com/registry.json")

    assert registry.total == 1
    assert registry.attacks[0].id == "com-jb-test"


def test_load_remote_sha256_mismatch_raises():
    """load_remote() raises ValueError when expected_sha256 does not match."""
    sample_attack = {
        "id": "x",
        "text": "t",
        "category": "jailbreak",
        "tags": [],
        "severity": "low",
        "owasp_tag": "",
        "author": "",
        "source": "",
    }

    # Fake HTTP response usable as a context manager that yields itself.
    response = MagicMock()
    response.read.return_value = _make_remote_manifest([sample_attack])
    response.__enter__.return_value = response
    response.__exit__.return_value = False

    wrong_digest = "a" * 64
    with patch("urllib.request.urlopen", return_value=response), pytest.raises(
        ValueError, match="SHA-256 mismatch"
    ):
        load_remote("https://example.com/registry.json", expected_sha256=wrong_digest)
37 changes: 37 additions & 0 deletions python/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,43 @@ def test_attack_list_command(tmp_path, capsys):
assert "list me please" in captured.out or "jailbreak" in captured.out


def test_attack_community_command(capsys):
    """attack-community prints a listing of the community registry."""
    main(["attack-community"])
    printed = capsys.readouterr().out
    assert "Community registry" in printed
    assert any(marker in printed for marker in ("jailbreak", "com-"))


def test_attack_community_json_format(capsys):
    """attack-community --json prints a JSON array of 25 attack objects."""
    main(["attack-community", "--json"])
    attacks = json.loads(capsys.readouterr().out)
    assert isinstance(attacks, list)
    assert len(attacks) == 25
    first = attacks[0]
    for key in ("id", "owasp_tag"):
        assert key in first


def test_attack_community_category_filter(capsys):
    """attack-community --category agentic prints only the four agentic attacks."""
    main(["attack-community", "--category", "agentic", "--json"])
    attacks = json.loads(capsys.readouterr().out)
    for item in attacks:
        assert item["category"] == "agentic"
    assert len(attacks) == 4


def test_attack_community_severity_filter(capsys):
    """attack-community --severity critical prints a non-empty, all-critical list."""
    main(["attack-community", "--severity", "critical", "--json"])
    attacks = json.loads(capsys.readouterr().out)
    for item in attacks:
        assert item["severity"] == "critical"
    assert len(attacks) > 0


def test_attack_list_json_format(tmp_path, capsys):
"""attack-list --json should emit a JSON array."""
lib_path = str(tmp_path / "lib.json")
Expand Down
17 changes: 16 additions & 1 deletion python/toki/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Toki — adversarial fine-tuning lab for small LLMs."""
from __future__ import annotations

__version__ = "1.4.0"
__version__ = "1.5.0"

from toki.generate import AdversarialGenerator
from toki.evaluate import RobustnessEvaluator
Expand Down Expand Up @@ -75,6 +75,14 @@
CustomAttack,
VALID_CATEGORIES,
)
from toki.community import (
CommunityAttack,
CommunityRegistry,
filter_attacks as filter_community_attacks,
get_registry,
load_bundled,
load_remote,
)
from toki.campaign import (
CampaignConfig,
CampaignResult,
Expand Down Expand Up @@ -236,6 +244,13 @@
"AttackLibrary",
"CustomAttack",
"VALID_CATEGORIES",
# Phase 15 — community attack registry
"CommunityAttack",
"CommunityRegistry",
"filter_community_attacks",
"get_registry",
"load_bundled",
"load_remote",
"CampaignConfig",
"CampaignResult",
"RedTeamCampaign",
Expand Down
Loading
Loading