diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab7a6b140..456eecafd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,18 @@ repos: files: ^(doc/.*\.(py|ipynb|md)|doc/myst\.yml)$ pass_filenames: false additional_dependencies: ['pyyaml'] + - id: enforce_alembic_revision_immutability + name: Enforce Alembic Revision Immutability + entry: python ./build_scripts/enforce_alembic_revision_immutability.py + language: python + files: ^pyrit/memory/alembic/versions/.*\.py$ + pass_filenames: false + - id: memory-migrations-check + name: Check Memory Migrations + entry: python ./build_scripts/memory_migrations.py check + language: system + pass_filenames: false + files: ^pyrit/memory/(memory_models\.py|alembic/.*|migration\.py)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 diff --git a/.pyrit_conf_example b/.pyrit_conf_example index e136b7a91..f66c075d1 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -95,6 +95,14 @@ operation: op_trash_panda # - /path/to/.env # - /path/to/.env.local +# Schema Migration Check +# --------------------- +# If true, runs database schema migration on startup to ensure the database +# is up to date with the latest PyRIT version. +# Set to false to skip the check (e.g., for read-only access, testing, or +# when managing migrations externally). +check_schema: true + # Silent Mode # ----------- # If true, suppresses print statements during initialization. diff --git a/build_scripts/enforce_alembic_revision_immutability.py b/build_scripts/enforce_alembic_revision_immutability.py new file mode 100644 index 000000000..1e1971542 --- /dev/null +++ b/build_scripts/enforce_alembic_revision_immutability.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Migration history must be immutable. This hook enforces that by preventing deletion or updates to migration scripts. 
+ +Checks both staged changes (local pre-commit) and the full branch diff against origin/main (CI). +""" + +import subprocess +import sys + +_VERSIONS_PATH = "pyrit/memory/alembic/versions/" + + +def _git(*args: str) -> str: + result = subprocess.run(["git", *args], capture_output=True, text=True) + return result.stdout.strip() + + +def _has_non_add_changes(diff_spec: list[str]) -> bool: + output = _git("diff", "--name-status", *diff_spec, "--", _VERSIONS_PATH) + return any(line and not line.startswith("A") for line in output.splitlines()) + + +def has_revision_violations() -> bool: + # Local pre-commit: check staged changes + if _has_non_add_changes(["--cached"]): + return True + + # CI: check full branch diff against origin/main + merge_base = _git("merge-base", "origin/main", "HEAD") + return bool(merge_base and _has_non_add_changes([f"{merge_base}...HEAD"])) + + +if __name__ == "__main__": + if has_revision_violations(): + print("[ERROR] Migration scripts can only be added, not modified or deleted.") + sys.exit(1) diff --git a/build_scripts/memory_migrations.py b/build_scripts/memory_migrations.py new file mode 100644 index 000000000..bb37584c1 --- /dev/null +++ b/build_scripts/memory_migrations.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import argparse +import sys +import tempfile +from pathlib import Path + +from alembic.util.exc import AutogenerateDiffsDetected +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine + +from pyrit.memory.migration import check_schema_migrations, generate_schema_migration, run_schema_migrations + +# ANSI color codes +_RED = "\033[91m" +_RESET = "\033[0m" + + +def _print_error(message: str) -> None: + """Print an error message in red to stderr.""" + print(f"{_RED}{message}{_RESET}", file=sys.stderr) + + +def _create_temp_engine() -> tuple[Engine, Path]: + """Create a temp SQLite database upgraded to head and return engine and path.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + tmp_path = Path(tmp.name) + engine = create_engine(f"sqlite:///{tmp_path}") + run_schema_migrations(engine=engine) + return engine, tmp_path + + +def _cmd_generate(*, message: str, force: bool = False) -> None: + """Generate a new Alembic revision from model changes.""" + engine, tmp_path = _create_temp_engine() + try: + generate_schema_migration(engine=engine, message=message, force=force) + print("Migration file generated. Review it carefully before committing.") + except RuntimeError as e: + _print_error(str(e)) + raise SystemExit(1) from e + finally: + engine.dispose() + tmp_path.unlink(missing_ok=True) + + +def _cmd_check() -> None: + """Verify all migrations apply cleanly and schema matches models.""" + engine, tmp_path = _create_temp_engine() + try: + check_schema_migrations(engine=engine) + except AutogenerateDiffsDetected as e: + _print_error(f"Migration check failed. Run 'generate' to create a migration. Error: {e}") + raise SystemExit(1) from e + finally: + engine.dispose() + tmp_path.unlink(missing_ok=True) + + +def _build_parser() -> argparse.ArgumentParser: + """Build the CLI argument parser.""" + parser = argparse.ArgumentParser( + description="PyRIT memory migration tool. 
Generate and validate migrations based on the current memory models." + ) + sub = parser.add_subparsers(dest="command", required=True) + + gen = sub.add_parser("generate", help="Generate a new migration from model changes.") + gen.add_argument("-m", "--message", required=True, help="Migration message.") + gen.add_argument("--force", action="store_true", help="Generate migration even if no changes detected.") + + sub.add_parser("check", help="Verify all migrations apply cleanly and add up to the current memory models.") + + return parser + + +def main() -> int: + """Dispatch the selected migration command.""" + args = _build_parser().parse_args() + + if args.command == "generate": + _cmd_generate(message=args.message, force=args.force) + elif args.command == "check": + _cmd_check() + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/doc/contributing/11_memory_models.md b/doc/contributing/11_memory_models.md new file mode 100644 index 000000000..ef215eeae --- /dev/null +++ b/doc/contributing/11_memory_models.md @@ -0,0 +1,97 @@ +# Memory Models & Migrations + +This guide covers how to work with PyRIT's memory models — where they live, how to add or update them, and how the migration system works. + +## Where Things Live + +| What | Path | +|---|---| +| ORM models (SQLAlchemy) | `pyrit/memory/memory_models.py` | +| Domain objects they map to | `pyrit/models/` (e.g. 
`MessagePiece`, `Score`, `Seed`, `AttackResult`, `ScenarioResult`) | +| Alembic migration environment | `pyrit/memory/alembic/env.py` | +| Migration revisions | `pyrit/memory/alembic/versions/` | +| Migration helpers | `pyrit/memory/migration.py` | +| CLI migration tool | `build_scripts/memory_migrations.py` | +| Schema diagram | `doc/code/memory/10_schema_diagram.md` | + +## Current Models + +All models inherit from the SQLAlchemy `Base` declarative class and live in `memory_models.py`: + +- **`PromptMemoryEntry`** — prompt/response data (`PromptMemoryEntries` table) +- **`ScoreEntry`** — evaluation results (`ScoreEntries` table) +- **`EmbeddingDataEntry`** — embeddings for semantic search (`EmbeddingData` table) +- **`SeedEntry`** — dataset prompts/templates (`SeedPromptEntries` table) +- **`AttackResultEntry`** — attack execution results (`AttackResultEntries` table) +- **`ScenarioResultEntry`** — scenario execution metadata (`ScenarioResultEntries` table) + +Each entry model has a corresponding domain object and conversion methods (e.g. `PromptMemoryEntry.__init__(entry: MessagePiece)` and `get_message_piece()`). + +## Adding or Updating a Model + +### 1. Edit the model + +Make your changes in `pyrit/memory/memory_models.py`. Follow these conventions: + +- Use `mapped_column()` with explicit types. +- Use `CustomUUID` for all UUID columns (handles cross-database compatibility). +- Add foreign keys where relationships exist. +- Include `pyrit_version` on new entry models. + +### 2. Generate a migration + +```bash +python build_scripts/memory_migrations.py generate -m "short description of change" +``` + +This creates a new revision file under `pyrit/memory/alembic/versions/`. **Review the generated file carefully** — auto-generated migrations may need manual adjustments (e.g. for data migrations or default values). + +### 3. 
Validate the migration + +```bash +python build_scripts/memory_migrations.py check +``` + +This verifies the schema produced by running all migrations matches the current models. The `memory-migrations-check` pre-commit hook (see below) and CI both run this check. + +### 4. Update the schema diagram + +If you changed the schema in a meaningful way (added a table, added a foreign key, etc.), update the Mermaid diagram in `doc/code/memory/10_schema_diagram.md`. + +## How Migrations Run at Startup + +When `initialize_pyrit_async()` is called with `check_schema=True` (the default), migrations run automatically: + +``` +initialize_pyrit_async() + → memory._ensure_schema_is_current() # pyrit/memory/memory_interface.py + → run_schema_migrations(engine=...) # pyrit/memory/migration.py + → alembic upgrade head +``` + +This means any new migration you add will be applied automatically the next time a user initializes PyRIT. The behavior depends on the database state: + +| Database state | What happens | +|---|---| +| **Fresh (no tables)** | All migrations apply from scratch | +| **Already versioned** | Only unapplied migrations run (idempotent) | +| **Legacy (tables exist, no version tracking)** | Validates schema matches models, stamps current version, then upgrades. Raises `RuntimeError` on mismatch to prevent data corruption | + +Migrations run inside a transaction (`engine.begin()`), so a failed migration rolls back cleanly. The version tracking table is `pyrit_memory_alembic_version`. + +Users can skip the startup migration by passing `check_schema=False` to `initialize_pyrit_async()`. + +## Important Rules + +### Migration revisions are immutable + +Once a migration revision is committed, it **must not be modified or deleted**. This is enforced by a pre-commit hook (`enforce_alembic_revision_immutability`). If you need to fix a migration, create a new revision instead. + +### Pre-commit hooks + +Two hooks run automatically when you touch memory-related files: + +1. 
**`enforce_alembic_revision_immutability`** — blocks modifications/deletions to existing revision files. +2. **`memory-migrations-check`** — runs `memory_migrations.py check` to verify the schema is in sync. + +The first hook triggers on changes to revision files under `pyrit/memory/alembic/versions/`; the second triggers on changes to `pyrit/memory/memory_models.py`, `pyrit/memory/migration.py`, and any file under `pyrit/memory/alembic/`. diff --git a/doc/myst.yml b/doc/myst.yml index c3fb071a2..2c995bd5c 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -53,6 +53,7 @@ project: - file: contributing/8_pre_commit.md - file: contributing/9_exception.md - file: contributing/10_release_process.md + - file: contributing/11_memory_models.md - file: gui/0_gui.md - file: scanner/0_scanner.md children: diff --git a/pyproject.toml b/pyproject.toml index 7aa3758c6..0d1994a98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ requires-python = ">=3.10, <3.14" dependencies = [ "aiofiles>=24,<25", + "alembic>=1.16.0", "appdirs>=1.4.0", "art>=6.5.0", "av>=14.0.0", @@ -199,6 +200,8 @@ include = ["pyrit", "pyrit.*"] [tool.setuptools.package-data] pyrit = [ "backend/frontend/**/*", + "memory/alembic/**/*", + "memory/alembic.ini", "py.typed" ] diff --git a/pyrit/memory/alembic/env.py b/pyrit/memory/alembic/env.py new file mode 100644 index 000000000..672b62e79 --- /dev/null +++ b/pyrit/memory/alembic/env.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from alembic import context +from sqlalchemy.engine import Connection + +from pyrit.memory.memory_models import Base +from pyrit.memory.migration import PYRIT_MEMORY_ALEMBIC_VERSION_TABLE + +config = context.config +connection: Connection | None = config.attributes.get("connection") +target_metadata = Base.metadata + +if connection is None: + raise RuntimeError("No connection found for Alembic migration") + +context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + version_table=PYRIT_MEMORY_ALEMBIC_VERSION_TABLE, +) +with context.begin_transaction(): + context.run_migrations() diff --git a/pyrit/memory/alembic/script.py.mako b/pyrit/memory/alembic/script.py.mako new file mode 100644 index 000000000..2b67b03e1 --- /dev/null +++ b/pyrit/memory/alembic/script.py.mako @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +${message}. + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = "${up_revision}" +down_revision: str | Sequence[str] | None = ${repr(down_revision).replace("'", '"')} +branch_labels: str | Sequence[str] | None = ${repr(branch_labels).replace("'", '"')} +depends_on: str | Sequence[str] | None = ${repr(depends_on).replace("'", '"')} + + +def upgrade() -> None: + """Apply this schema upgrade.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Revert this schema upgrade.""" + ${downgrades if downgrades else "pass"} diff --git a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py new file mode 100644 index 000000000..935b077e8 --- /dev/null +++ b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +initial schema. + +Revision ID: e373726d391b +Revises: +Create Date: 2026-04-17 10:03:04.066932 +""" + +import uuid +from collections.abc import Sequence +from typing import Any + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.sqlite import CHAR +from sqlalchemy.engine import Dialect +from sqlalchemy.types import TypeDecorator, Uuid + + +class _CustomUUID(TypeDecorator[uuid.UUID]): + """ + Frozen copy of CustomUUID kept here so this revision stays self-contained. + + This class is embedded in the migration script rather than imported to ensure + the migration remains reproducible and independent of future changes to the + main CustomUUID implementation in memory_models.py. Any future modifications + to CustomUUID must NOT affect this frozen version. 
+ """ + + impl = CHAR + cache_ok = True + + def load_dialect_impl(self, dialect: Dialect) -> Any: + if dialect.name == "sqlite": + return dialect.type_descriptor(CHAR(36)) + return dialect.type_descriptor(Uuid()) + + def process_bind_param(self, value: Any, dialect: Any) -> str | None: + return str(value) if value is not None else None + + def process_result_value(self, value: Any, dialect: Any) -> uuid.UUID | None: + if value is None: + return None + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) + + +# revision identifiers, used by Alembic. +revision: str = "e373726d391b" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # ### commands auto generated by Alembic and reviewed by author### + op.create_table( + "PromptMemoryEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.Column("conversation_id", sa.String(), nullable=False), + sa.Column("sequence", sa.INTEGER(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("labels", sa.JSON(), nullable=False), + sa.Column("prompt_metadata", sa.JSON(), nullable=False), + sa.Column("targeted_harm_categories", sa.JSON(), nullable=True), + sa.Column("converter_identifiers", sa.JSON(), nullable=True), + sa.Column("prompt_target_identifier", sa.JSON(), nullable=False), + sa.Column("attack_identifier", sa.JSON(), nullable=False), + sa.Column("response_error", sa.String(), nullable=True), + sa.Column("original_value_data_type", sa.String(), nullable=False), + sa.Column("original_value", sa.Unicode(), nullable=False), + sa.Column("original_value_sha256", sa.String(), nullable=True), + sa.Column("converted_value_data_type", sa.String(), nullable=False), + sa.Column("converted_value", sa.Unicode(), nullable=True), + sa.Column("converted_value_sha256", sa.String(), 
nullable=True), + sa.Column("original_prompt_id", _CustomUUID(), nullable=False), + sa.Column("pyrit_version", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "ScenarioResultEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("scenario_name", sa.String(), nullable=False), + sa.Column("scenario_description", sa.Unicode(), nullable=True), + sa.Column("scenario_version", sa.INTEGER(), nullable=False), + sa.Column("pyrit_version", sa.String(), nullable=False), + sa.Column("scenario_init_data", sa.JSON(), nullable=True), + sa.Column("objective_target_identifier", sa.JSON(), nullable=False), + sa.Column("objective_scorer_identifier", sa.JSON(), nullable=True), + sa.Column("scenario_run_state", sa.String(), nullable=False), + sa.Column("attack_results_json", sa.Unicode(), nullable=False), + sa.Column("labels", sa.JSON(), nullable=True), + sa.Column("number_tries", sa.INTEGER(), nullable=False), + sa.Column("completion_time", sa.DateTime(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "SeedPromptEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("value", sa.Unicode(), nullable=False), + sa.Column("value_sha256", sa.Unicode(), nullable=True), + sa.Column("data_type", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("dataset_name", sa.String(), nullable=True), + sa.Column("harm_categories", sa.JSON(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("authors", sa.JSON(), nullable=True), + sa.Column("groups", sa.JSON(), nullable=True), + sa.Column("source", sa.String(), nullable=True), + sa.Column("date_added", sa.DateTime(), nullable=False), + sa.Column("added_by", sa.String(), nullable=False), + sa.Column("prompt_metadata", sa.JSON(), nullable=True), + sa.Column("parameters", sa.JSON(), nullable=True), + sa.Column("prompt_group_id", 
_CustomUUID(), nullable=True), + sa.Column("sequence", sa.INTEGER(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("seed_type", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "EmbeddingData", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "embedding", + sa.ARRAY(sa.Float()).with_variant(sa.JSON(), "mssql").with_variant(sa.JSON(), "sqlite"), + nullable=True, + ), + sa.Column("embedding_type_name", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["id"], + ["PromptMemoryEntries.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "ScoreEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("score_value", sa.String(), nullable=False), + sa.Column("score_value_description", sa.String(), nullable=True), + sa.Column("score_type", sa.String(), nullable=False), + sa.Column("score_category", sa.JSON(), nullable=True), + sa.Column("score_rationale", sa.String(), nullable=True), + sa.Column("score_metadata", sa.JSON(), nullable=False), + sa.Column("scorer_class_identifier", sa.JSON(), nullable=False), + sa.Column("prompt_request_response_id", _CustomUUID(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("task", sa.String(), nullable=True), + sa.Column("objective", sa.String(), nullable=True), + sa.Column("pyrit_version", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["prompt_request_response_id"], + ["PromptMemoryEntries.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "AttackResultEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("conversation_id", sa.String(), nullable=False), + sa.Column("objective", sa.Unicode(), nullable=False), + sa.Column("attack_identifier", sa.JSON(), nullable=False), + sa.Column("atomic_attack_identifier", sa.JSON(), nullable=True), + sa.Column("objective_sha256", sa.String(), nullable=True), + sa.Column("last_response_id", 
_CustomUUID(), nullable=True), + sa.Column("last_score_id", _CustomUUID(), nullable=True), + sa.Column("executed_turns", sa.INTEGER(), nullable=False), + sa.Column("execution_time_ms", sa.INTEGER(), nullable=False), + sa.Column("outcome", sa.String(), nullable=False), + sa.Column("outcome_reason", sa.String(), nullable=True), + sa.Column("attack_metadata", sa.JSON(), nullable=True), + sa.Column("pruned_conversation_ids", sa.JSON(), nullable=True), + sa.Column("adversarial_chat_conversation_ids", sa.JSON(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("pyrit_version", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["last_response_id"], + ["PromptMemoryEntries.id"], + ), + sa.ForeignKeyConstraint( + ["last_score_id"], + ["ScoreEntries.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # ### commands auto generated by Alembic and reviewed by author ### + op.drop_table("AttackResultEntries") + op.drop_table("ScoreEntries") + op.drop_table("EmbeddingData") + op.drop_table("SeedPromptEntries") + op.drop_table("ScenarioResultEntries") + op.drop_table("PromptMemoryEntries") + # ### end Alembic commands ### diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index d93d2bd4a..a36109f2c 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -103,7 +103,6 @@ def __init__( self._enable_azure_authorization() self.SessionFactory = sessionmaker(bind=self.engine) - self._create_tables_if_not_exist() super().__init__() @@ -215,23 +214,6 @@ def provide_token(_dialect: Any, _conn_rec: Any, cargs: list[Any], cparams: dict # add the encoded token cparams["attrs_before"] = {self.SQL_COPT_SS_ACCESS_TOKEN: packed_azure_token} - def _create_tables_if_not_exist(self) -> None: - """ - Create all tables defined in the Base metadata, if they don't already exist in the database. 
- - Raises: - Exception: If there's an issue creating the tables in the database. - RuntimeError: If the engine is not initialized. - """ - try: - # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables - if self.engine is None: - raise RuntimeError("Engine is not initialized") - Base.metadata.create_all(self.engine, checkfirst=True) - except Exception as e: - logger.exception(f"Error during table creation: {e}") - raise - def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: """ Insert embedding data into memory storage. @@ -798,18 +780,3 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict session.rollback() logger.exception(f"Error updating entries: {e}") raise - - def reset_database(self) -> None: - """ - Drop and recreate existing tables. - - Raises: - RuntimeError: If the engine is not initialized. - """ - # Drop all existing tables - if self.engine is None: - raise RuntimeError("Engine is not initialized") - - Base.metadata.drop_all(self.engine) - # Recreate the tables - Base.metadata.create_all(self.engine, checkfirst=True) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index d020a4a0a..18600a322 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -34,6 +34,7 @@ ScoreEntry, SeedEntry, ) +from pyrit.memory.migration import reset_database, run_schema_migrations from pyrit.models import ( AttackResult, ConversationStats, @@ -119,6 +120,22 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None + def _ensure_schema_is_current(self) -> None: + """ + Upgrade the database schema to the latest Alembic revision. + + Raises: + Exception: If there's an issue applying schema migrations. + RuntimeError: If the engine is not initialized. 
+ """ + if self.engine is None: + raise RuntimeError("Engine is not initialized") + try: + run_schema_migrations(engine=self.engine) + except Exception as e: + logger.exception(f"Error during schema migration: {e}") + raise + def _build_identifier_filter_conditions( self, *, @@ -951,6 +968,17 @@ def update_prompt_metadata_by_conversation_id( conversation_id=conversation_id, update_fields={"prompt_metadata": prompt_metadata} ) + def reset_database(self) -> None: + """ + Drop and recreate all tables in the database. + + Raises: + RuntimeError: If the engine is not initialized. + """ + if self.engine is None: + raise RuntimeError("Engine is not initialized") + reset_database(engine=self.engine) + @abc.abstractmethod def dispose_engine(self) -> None: """ diff --git a/pyrit/memory/migration.py b/pyrit/memory/migration.py new file mode 100644 index 000000000..b67c631bc --- /dev/null +++ b/pyrit/memory/migration.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from pathlib import Path + +from alembic import command +from alembic.autogenerate.api import compare_metadata +from alembic.config import Config +from alembic.migration import MigrationContext +from alembic.util.exc import AutogenerateDiffsDetected +from sqlalchemy import MetaData, Table, inspect +from sqlalchemy.engine import Connection, Engine + +from pyrit.memory.memory_models import Base + +logger = logging.getLogger(__name__) + +PYRIT_MEMORY_ALEMBIC_VERSION_TABLE = "pyrit_memory_alembic_version" +_HEAD_REVISION = "head" +_MEMORY_TABLES = {table.name for table in Base.metadata.sorted_tables} + +_ERROR_UNVERSIONED_SCHEMA_COMPARISON_FAILED = ( + "Detected an unversioned legacy memory schema (memory tables exist, but " + "pyrit_memory_alembic_version is missing), " + "but failed to compare the existing schema against current models. 
" + "This could be transient, or you may need to repair or rebuild the database " + "before upgrading to this release." +) + +_ERROR_UNVERSIONED_SCHEMA_MISMATCH = ( + "Detected an unversioned legacy memory schema (memory tables exist, but " + "pyrit_memory_alembic_version is missing), " + "and it does not match current models. " + "Repair or rebuild the database before upgrading to this release." +) + + +def _include_name_for_memory_schema( + name: str | None, + type_: str, + parent_names: dict[str, str], +) -> bool: + """ + Restrict schema comparisons to PyRIT memory tables and their child objects. + + Args: + name (str | None): Name of the database object being considered. + type_ (str): SQLAlchemy object type (e.g., "table", "column", "index"). + parent_names (dict[str, str]): Parent-name context provided by Alembic. + + Returns: + bool: True when the object should be included in schema comparison. + """ + if type_ == "table": + return bool(name and name in _MEMORY_TABLES) + + table_name = parent_names.get("table_name") + if table_name: + return table_name in _MEMORY_TABLES + + return True + + +def _make_config(*, connection: Connection) -> Config: + """ + Build an Alembic config for the memory migration scripts. + + Args: + connection (Connection): Database connection for Alembic commands. + + Returns: + Config: Configured Alembic config object. + """ + script_location = Path(__file__).with_name("alembic") + + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + return config + + +def _validate_and_stamp_unversioned_memory_schema(*, config: Config, connection: Connection) -> None: + """ + Validate and stamp unversioned legacy memory schemas. + + Args: + config (Config): Alembic config bound to the current connection. + connection (Connection): Database connection to inspect. + + Raises: + RuntimeError: If an unversioned memory schema does not match models. 
+ """ + # Perform all inspection in one atomic call to avoid race conditions + inspector = inspect(connection) + table_names = set(inspector.get_table_names()) + + # If version table already exists, migration has been stamped + if PYRIT_MEMORY_ALEMBIC_VERSION_TABLE in table_names: + return + + # If no memory tables exist, this is a fresh database + if not _MEMORY_TABLES.intersection(table_names): + return + + # Unversioned memory schema detected; validate it matches current models + try: + migration_context = MigrationContext.configure( + connection=connection, + opts={"compare_type": True, "include_name": _include_name_for_memory_schema}, + ) + diffs = compare_metadata(migration_context, Base.metadata) + except Exception as e: + raise RuntimeError(_ERROR_UNVERSIONED_SCHEMA_COMPARISON_FAILED) from e + + if diffs: + raise RuntimeError(_ERROR_UNVERSIONED_SCHEMA_MISMATCH) + + logger.warning("Detected matching unversioned memory schema; stamping revision %s", _HEAD_REVISION) + command.stamp(config, _HEAD_REVISION) + + +def run_schema_migrations(*, engine: Engine) -> None: + """ + Upgrade the database schema to the latest Alembic revision. + + Args: + engine (Engine): SQLAlchemy engine bound to the target database. + + Raises: + Exception: If Alembic fails to apply migrations. + """ + with engine.begin() as connection: + config = _make_config(connection=connection) + _validate_and_stamp_unversioned_memory_schema(config=config, connection=connection) + command.upgrade(config, _HEAD_REVISION) + + +def check_schema_migrations(*, engine: Engine) -> None: + """ + Verify that the current schema matches the models. + + Creates a fresh connection, builds an Alembic config, and runs + ``alembic check`` to detect any unapplied model changes. + + Args: + engine (Engine): SQLAlchemy engine bound to the target database. + + Raises: + AutogenerateDiffsDetected: If schema does not match models. 
+ """ + with engine.begin() as connection: + config = _make_config(connection=connection) + command.check(config) + + +def generate_schema_migration(*, engine: Engine, message: str, force: bool = False) -> None: + """ + Generate a new Alembic revision from model changes. + + Args: + engine (Engine): SQLAlchemy engine upgraded to head. + message (str): Human-readable migration message. + force (bool): If True, generate even if no changes detected. Defaults to False. + + Raises: + RuntimeError: If no changes detected and force is False. + """ + with engine.begin() as connection: + config = _make_config(connection=connection) + + if force: + command.revision(config, autogenerate=True, message=message) + return + + try: + command.check(config) + except AutogenerateDiffsDetected: + command.revision(config, autogenerate=True, message=message) + return + + raise RuntimeError("No schema changes detected. Use force=True to generate an empty migration.") + + +def reset_database(*, engine: Engine) -> None: + """ + Drop all tables and recreate the database schema at the latest Alembic revision. + + This destroys all existing data. + + Args: + engine (Engine): SQLAlchemy engine bound to the target database. 
+ """ + logger.debug("Resetting database using Alembic migrations") + with engine.begin() as connection: + # Drop version table first (not part of Base.metadata) + inspector = inspect(connection) + if PYRIT_MEMORY_ALEMBIC_VERSION_TABLE in inspector.get_table_names(): + version_table = Table(PYRIT_MEMORY_ALEMBIC_VERSION_TABLE, MetaData(), autoload_with=connection) + version_table.drop(connection) + + # Drop all application tables defined in models + Base.metadata.drop_all(connection) + # Rebuild schema from migrations + run_schema_migrations(engine=engine) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index a4039c1b7..5547cafb6 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -80,7 +80,6 @@ def __init__( self.engine = self._create_engine(has_echo=verbose) self.SessionFactory = sessionmaker(bind=self.engine) - self._create_tables_if_not_exist() def _init_storage_io(self) -> None: # Handles disk-based storage for SQLite local memory. @@ -128,23 +127,6 @@ def _create_engine(self, *, has_echo: bool) -> Engine: logger.exception(f"Error creating the engine for the database: {e}") raise - def _create_tables_if_not_exist(self) -> None: - """ - Create all tables defined in the Base metadata, if they don't already exist in the database. - - Raises: - Exception: If there's an issue creating the tables in the database. - RuntimeError: If the engine is not initialized. - """ - try: - # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables - if self.engine is None: - raise RuntimeError("Engine is not initialized") - Base.metadata.create_all(self.engine, checkfirst=True) - except Exception as e: - logger.exception(f"Error during table creation: {e}") - raise - def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ Fetch all entries from the specified table and returns them as model instances. 
@@ -440,19 +422,6 @@ def get_session(self) -> Session: """ return self.SessionFactory() - def reset_database(self) -> None: - """ - Drop and recreates all tables in the database. - - Raises: - RuntimeError: If the engine is not initialized. - """ - if self.engine is None: - raise RuntimeError("Engine is not initialized") - - Base.metadata.drop_all(self.engine) - Base.metadata.create_all(self.engine) - def dispose_engine(self) -> None: """ Dispose the engine and close all connections. diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 6f720b303..32279982a 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -100,6 +100,7 @@ class ConfigurationLoader(YamlLoadable): initialization_scripts: Optional[list[str]] = None env_files: Optional[list[str]] = None silent: bool = False + check_schema: bool = True operator: Optional[str] = None operation: Optional[str] = None @@ -189,6 +190,7 @@ def load_with_overrides( initializers: Optional[Sequence[Union[str, dict[str, Any]]]] = None, initialization_scripts: Optional[Sequence[str]] = None, env_files: Optional[Sequence[str]] = None, + check_schema: Optional[bool] = None, ) -> "ConfigurationLoader": """ Load configuration with optional overrides. @@ -208,6 +210,7 @@ def load_with_overrides( initializers: Override for initializer list. initialization_scripts: Override for initialization script paths. env_files: Override for environment file paths. + check_schema: Override for schema migration check. True to run, False to skip. Returns: A merged ConfigurationLoader instance. @@ -227,6 +230,7 @@ def load_with_overrides( "initialization_scripts": None, # None = use defaults "env_files": None, # None = use defaults "silent": False, + "check_schema": True, } # 1. 
Try loading default config file if it exists @@ -245,6 +249,7 @@ def load_with_overrides( config_data["initialization_scripts"] = default_config.initialization_scripts config_data["env_files"] = default_config.env_files config_data["silent"] = default_config.silent + config_data["check_schema"] = default_config.check_schema if default_config.operator: config_data["operator"] = default_config.operator if default_config.operation: @@ -268,6 +273,7 @@ def load_with_overrides( config_data["initialization_scripts"] = explicit_config.initialization_scripts config_data["env_files"] = explicit_config.env_files config_data["silent"] = explicit_config.silent + config_data["check_schema"] = explicit_config.check_schema if explicit_config.operator: config_data["operator"] = explicit_config.operator if explicit_config.operation: @@ -293,6 +299,9 @@ def load_with_overrides( if env_files is not None: config_data["env_files"] = list(env_files) + if check_schema is not None: + config_data["check_schema"] = check_schema + return ConfigurationLoader.from_dict(config_data) @classmethod @@ -418,6 +427,7 @@ async def initialize_pyrit_async(self) -> None: initializers=resolved_initializers if resolved_initializers else None, env_files=resolved_env_files, silent=self.silent, + check_schema=self.check_schema, ) diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index ea34362c3..8789e1124 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -240,6 +240,7 @@ async def initialize_pyrit_async( initializers: Optional[Sequence["PyRITInitializer"]] = None, env_files: Optional[Sequence[pathlib.Path]] = None, silent: bool = False, + check_schema: bool = True, **memory_instance_kwargs: Any, ) -> None: """ @@ -258,6 +259,9 @@ async def initialize_pyrit_async( All paths must be valid pathlib.Path objects. silent (bool): If True, suppresses print statements about environment file loading. Defaults to False. 
+ check_schema (bool): If True, runs schema migration to ensure the database is up to date. + Set to False to bypass the migration check (e.g., for testing or read-only access). + Defaults to True. **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance. Raises: @@ -287,6 +291,9 @@ async def initialize_pyrit_async( raise ValueError( f"Memory database type '{memory_db_type}' is not a supported type {get_args(MemoryDatabaseType)}" ) + if check_schema: + memory._ensure_schema_is_current() + CentralMemory.set_memory_instance(memory) # Combine directly provided initializers with those loaded from scripts diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 4d9e056c5..9ffb02099 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch import pytest +from sqlalchemy import inspect, text from pyrit.memory import AzureSQLMemory, EmbeddingDataEntry, PromptMemoryEntry from pyrit.models import MessagePiece @@ -136,6 +137,28 @@ def test_default_embedding_raises(memory_interface: AzureSQLMemory): memory_interface.enable_embedding() +def test_reset_database_recreates_versioned_schema(memory_interface: AzureSQLMemory): + memory_interface.reset_database() + + inspector = inspect(memory_interface.engine) + table_names = set(inspector.get_table_names()) + + assert { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", + "pyrit_memory_alembic_version", + }.issubset(table_names) + + with memory_interface.engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + + def test_query_entries( memory_interface: AzureSQLMemory, sample_conversation_entries: MutableSequence[PromptMemoryEntry] ): @@ -439,11 
+462,11 @@ def decorator(fn): captured_fn(None, None, ["some_connection_string"], {}) -def test_create_tables_if_not_exist_raises_when_engine_none(): +def test_ensure_schema_is_current_raises_when_engine_none(): obj = AzureSQLMemory.__new__(AzureSQLMemory) obj.engine = None with pytest.raises(RuntimeError, match="Engine is not initialized"): - obj._create_tables_if_not_exist() + obj._ensure_schema_is_current() def test_reset_database_raises_when_engine_none(): diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py new file mode 100644 index 000000000..0140ce5b1 --- /dev/null +++ b/tests/unit/memory/test_migration.py @@ -0,0 +1,264 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import tempfile +import uuid +from pathlib import Path + +import pytest +from alembic import command +from alembic.config import Config +from alembic.script import ScriptDirectory +from alembic.util.exc import AutogenerateDiffsDetected +from sqlalchemy import create_engine, inspect, text + +from pyrit.memory.alembic.versions.e373726d391b_initial_schema import _CustomUUID +from pyrit.memory.migration import check_schema_migrations, generate_schema_migration, run_schema_migrations + + +def test_alembic_env_raises_when_no_connection(): + """Covers env.py line 15: RuntimeError when connection is None.""" + import importlib + import sys + from unittest.mock import MagicMock + + # Build a mock alembic context where config.attributes has no "connection" + mock_config = MagicMock() + mock_config.attributes = {} # .get("connection") → None + + mock_context = MagicMock() + mock_context.config = mock_config + + # Remove cached env module so reload runs module-level code + env_module_name = "pyrit.memory.alembic.env" + saved = sys.modules.pop(env_module_name, None) + + # Patch alembic.context to be our mock context + import alembic + + original_context = getattr(alembic, "context", None) + alembic.context = mock_context + + try: 
+ with pytest.raises(RuntimeError, match="No connection found for Alembic migration"): + importlib.import_module(env_module_name) + finally: + # Restore original state + alembic.context = original_context + sys.modules.pop(env_module_name, None) + if saved is not None: + sys.modules[env_module_name] = saved + + +def _get_alembic_head_revision(*, config: Config) -> str: + """Return the current Alembic head revision for the configured script location.""" + head_revision = ScriptDirectory.from_config(config).get_current_head() + if head_revision is None: + raise RuntimeError("No Alembic head revision found for memory migrations.") + + return head_revision + + +def test_custom_uuid_process_bind_param_with_none(): + """Test _CustomUUID.process_bind_param with None value.""" + uuid_type = _CustomUUID() + result = uuid_type.process_bind_param(None, None) + assert result is None + + +def test_custom_uuid_process_bind_param_with_uuid(): + """Test _CustomUUID.process_bind_param with UUID value.""" + uuid_type = _CustomUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_bind_param(test_uuid, None) + assert result == str(test_uuid) + + +def test_custom_uuid_process_result_value_with_none(): + """Test _CustomUUID.process_result_value with None value.""" + uuid_type = _CustomUUID() + result = uuid_type.process_result_value(None, None) + assert result is None + + +def test_custom_uuid_process_result_value_with_uuid(): + """Test _CustomUUID.process_result_value with UUID value.""" + uuid_type = _CustomUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_result_value(test_uuid, None) + assert result == test_uuid + + +def test_custom_uuid_process_result_value_with_string(): + """Test _CustomUUID.process_result_value with string value.""" + uuid_type = _CustomUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_result_value(str(test_uuid), None) + assert result == test_uuid + + +def test_custom_uuid_load_dialect_impl_sqlite(): + """Test 
_CustomUUID.load_dialect_impl for SQLite dialect.""" + from sqlalchemy.dialects import sqlite + + uuid_type = _CustomUUID() + dialect = sqlite.dialect() + result = uuid_type.load_dialect_impl(dialect) + assert result is not None + + +def test_custom_uuid_load_dialect_impl_postgresql(): + """Test _CustomUUID.load_dialect_impl for PostgreSQL dialect.""" + from sqlalchemy.dialects import postgresql + + uuid_type = _CustomUUID() + dialect = postgresql.dialect() + result = uuid_type.load_dialect_impl(dialect) + assert result is not None + + +def test_run_schema_migrations_applies_head_revision(): + """Test that run_schema_migrations applies the current head revision.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "offline-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + expected_head = _get_alembic_head_revision(config=config) + + run_schema_migrations(engine=engine) + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version == expected_head + finally: + engine.dispose() + + +def test_migration_online_mode(): + """ + Test that online migration configuration is valid. + This tests the run_migrations_online path in env.py. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "online-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + expected_head = _get_alembic_head_revision(config=config) + + command.upgrade(config, "head") + + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version == expected_head + finally: + engine.dispose() + + +def test_migration_script_metadata(): + """Test that the initial migration script has correct metadata.""" + from pyrit.memory.alembic.versions import e373726d391b_initial_schema + + assert e373726d391b_initial_schema.revision == "e373726d391b" + assert e373726d391b_initial_schema.down_revision is None + assert e373726d391b_initial_schema.branch_labels is None + assert e373726d391b_initial_schema.depends_on is None + + +def test_migration_downgrade_creates_proper_structure(): + """ + Test that downgrade function doesn't corrupt the database. + This indirectly tests the downgrade path in e373726d391b_initial_schema.py. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "downgrade-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + + command.upgrade(config, "head") + + tables_before = set(inspect(connection).get_table_names()) + assert len(tables_before) > 0 + + command.downgrade(config, "base") + + tables_after = set(inspect(connection).get_table_names()) + assert "pyrit_memory_alembic_version" not in tables_after or len(tables_after) == 1 + finally: + engine.dispose() + + +def test_check_schema_migrations_calls_alembic_check(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "check-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + # First apply migrations so schema matches models + run_schema_migrations(engine=engine) + # Now check should succeed (no diffs) + check_schema_migrations(engine=engine) + finally: + engine.dispose() + + +def test_generate_schema_migration_force_creates_revision(): + from unittest.mock import patch + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "gen-force-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + with patch("pyrit.memory.migration.command.revision") as mock_revision: + generate_schema_migration(engine=engine, message="force empty", force=True) + mock_revision.assert_called_once() + finally: + engine.dispose() + + +def test_generate_schema_migration_no_changes_raises(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "gen-nochange-test.db") 
+ engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + with pytest.raises(RuntimeError, match="No schema changes detected"): + generate_schema_migration(engine=engine, message="should fail") + finally: + engine.dispose() + + +def test_generate_schema_migration_with_diffs_creates_revision(): + from unittest.mock import patch + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "gen-diffs-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + exc = AutogenerateDiffsDetected.__new__(AutogenerateDiffsDetected) + with ( + patch("pyrit.memory.migration.command.check", side_effect=exc), + patch("pyrit.memory.migration.command.revision") as mock_revision, + ): + generate_schema_migration(engine=engine, message="with diffs") + mock_revision.assert_called_once() + finally: + engine.dispose() diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index 71ed421e8..7a34cd319 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -4,19 +4,22 @@ import io import logging import os +import tempfile import uuid from collections.abc import Sequence from datetime import timezone -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest -from sqlalchemy import ARRAY, DateTime, Integer, String, inspect +from sqlalchemy import ARRAY, DateTime, Integer, String, create_engine, inspect, text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.sqlite import CHAR, JSON from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.sql.sqltypes import NullType -from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry +from pyrit.memory.memory_models import Base, EmbeddingDataEntry, PromptMemoryEntry +from pyrit.memory.migration import run_schema_migrations +from pyrit.memory.sqlite_memory import SQLiteMemory 
from pyrit.models import MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget @@ -156,6 +159,180 @@ def test_embedding_data_column_types(sqlite_instance): ], f"Unexpected type for 'embedding' column: {column_types['embedding']}" +def test_run_schema_migrations_stamps_matching_unversioned_legacy_database(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + finally: + engine.dispose() + + +def test_run_schema_migrations_stamps_unversioned_legacy_database_with_extra_tables(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + with engine.begin() as connection: + connection.execute( + text( + """ + CREATE TABLE "SharedAuditLog" ( + id INTEGER PRIMARY KEY, + event_name VARCHAR NOT NULL + ) + """ + ) + ) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "SharedAuditLog" in table_names + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + finally: + engine.dispose() + + +def test_run_schema_migrations_fails_synthetic_unversioned_schema_with_drift(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, 
"legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + with engine.begin() as connection: + connection.execute(text('DROP TABLE "ScoreEntries"')) + + with pytest.raises(RuntimeError, match="unversioned legacy memory schema"): + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" not in table_names + finally: + engine.dispose() + + +def test_run_schema_migrations_fails_pre_alembic_like_schema() -> None: + """ + Validate strict behavior for unsupported legacy schemas. + + This schema shape intentionally resembles an older pre-Alembic layout where + newer columns (e.g. pyrit_version, converter_identifiers) are absent. + Such databases are intentionally unsupported and must fail migration checks. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + connection.execute( + text( + """ + CREATE TABLE "PromptMemoryEntries" ( + id CHAR(36) NOT NULL, + role VARCHAR NOT NULL, + conversation_id VARCHAR NOT NULL, + sequence INTEGER NOT NULL, + timestamp DATETIME NOT NULL, + labels JSON NOT NULL, + prompt_metadata JSON NOT NULL, + prompt_target_identifier JSON NOT NULL, + attack_identifier JSON NOT NULL, + original_value_data_type VARCHAR NOT NULL, + original_value VARCHAR NOT NULL, + original_value_sha256 VARCHAR, + converted_value_data_type VARCHAR NOT NULL, + converted_value VARCHAR, + converted_value_sha256 VARCHAR, + original_prompt_id CHAR(36) NOT NULL, + PRIMARY KEY (id) + ) + """ + ) + ) + + with pytest.raises(RuntimeError, match="unversioned legacy memory schema"): + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" not in table_names + finally: + engine.dispose() + + +def 
test_run_schema_migrations_isolates_foreign_alembic_version_table(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + with engine.begin() as connection: + connection.execute(text('CREATE TABLE "alembic_version" (version_num VARCHAR(32) NOT NULL)')) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "alembic_version" in table_names + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + finally: + engine.dispose() + + +def test_reset_database_recreates_schema(sqlite_instance): + sqlite_instance.reset_database() + + inspector = inspect(sqlite_instance.engine) + table_names = set(inspector.get_table_names()) + + assert { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", + "pyrit_memory_alembic_version", + }.issubset(table_names) + + with sqlite_instance.engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + + +def test_reset_database_keeps_foreign_alembic_version_table(sqlite_instance): + with sqlite_instance.engine.begin() as connection: + connection.execute(text('CREATE TABLE "alembic_version" (version_num VARCHAR(32) NOT NULL)')) + + sqlite_instance.reset_database() + + table_names = set(inspect(sqlite_instance.engine).get_table_names()) + assert "alembic_version" in table_names + assert "pyrit_memory_alembic_version" in table_names + + @pytest.mark.asyncio() async def test_insert_entry(sqlite_instance): session = sqlite_instance.get_session() @@ -701,19 +878,80 @@ def 
test_create_engine_uses_static_pool_for_in_memory(sqlite_instance): assert isinstance(sqlite_instance.engine.pool, StaticPool) -def test_create_tables_raises_when_engine_none(): - from pyrit.memory import SQLiteMemory - - obj = SQLiteMemory.__new__(SQLiteMemory) - obj.engine = None - with pytest.raises(RuntimeError, match="Engine is not initialized"): - obj._create_tables_if_not_exist() - - -def test_reset_database_raises_when_engine_none(): - from pyrit.memory import SQLiteMemory - - obj = SQLiteMemory.__new__(SQLiteMemory) - obj.engine = None - with pytest.raises(RuntimeError, match="Engine is not initialized"): - obj.reset_database() +def test_run_schema_migrations_early_return_with_existing_version_table(): + """ + Test that migration early-returns when the version table already exists. + This tests the line 57 return in _validate_and_stamp_unversioned_memory_schema. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "versioned-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version_after = connection.execute( + text("SELECT version_num FROM pyrit_memory_alembic_version") + ).scalar_one() + assert version_after == version + finally: + engine.dispose() + + +def test_run_schema_migrations_no_memory_tables(): + """ + Test that migration early-returns when no memory tables exist. 
+ This tests the line 60 return in _validate_and_stamp_unversioned_memory_schema. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "empty-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", + "pyrit_memory_alembic_version", + }.issubset(table_names) + finally: + engine.dispose() + + +def test_ensure_schema_is_current_with_schema_migration_exception(): + """ + Test that _ensure_schema_is_current properly handles and re-raises exceptions. + This tests lines 132-134 in memory_interface.py. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "error-memory.db") + + memory = SQLiteMemory(db_path=db_path) + + try: + with patch("pyrit.memory.memory_interface.run_schema_migrations") as mock_migrate: + mock_migrate.side_effect = RuntimeError("Mock migration error") + + with pytest.raises(RuntimeError, match="Mock migration error"): + memory._ensure_schema_is_current() + finally: + memory.dispose_engine() diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index 76776bf51..fbdc80db1 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -531,3 +531,25 @@ def test_load_with_overrides_preserves_silent_from_config_file(self, mock_defaul assert config.silent is True finally: config_path.unlink() + + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + def test_load_with_overrides_check_schema_from_default_config(self, mock_default_path): + """Test that check_schema is loaded from the default config file.""" + mock_default_path.exists.return_value = True + + with mock.patch.object(ConfigurationLoader, 
"from_yaml_file") as mock_from_yaml: + fake_config = ConfigurationLoader(memory_db_type="sqlite", check_schema=False) + mock_from_yaml.return_value = fake_config + + config = ConfigurationLoader.load_with_overrides() + + assert config.check_schema is False + + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + def test_load_with_overrides_check_schema_override(self, mock_default_path): + """Test that check_schema override takes precedence.""" + mock_default_path.exists.return_value = False + + config = ConfigurationLoader.load_with_overrides(check_schema=False) + + assert config.check_schema is False diff --git a/uv.lock b/uv.lock index 21952cbdc..165dc224d 100644 --- a/uv.lock +++ b/uv.lock @@ -4899,6 +4899,7 @@ version = "0.13.0.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, + { name = "alembic" }, { name = "appdirs" }, { name = "art" }, { name = "av" }, @@ -5025,6 +5026,7 @@ requires-dist = [ { name = "accelerate", marker = "extra == 'all'", specifier = ">=1.7.0" }, { name = "accelerate", marker = "extra == 'gcg'", specifier = ">=1.7.0" }, { name = "aiofiles", specifier = ">=24,<25" }, + { name = "alembic", specifier = ">=1.16.0" }, { name = "appdirs", specifier = ">=1.4.0" }, { name = "art", specifier = ">=6.5.0" }, { name = "av", specifier = ">=14.0.0" },