# ---------------------------------------------------------------------------
# LoRAConfig — new safety-subspace fields (Sprint 17)
# ---------------------------------------------------------------------------


def test_lora_config_safety_lora_rank_default():
    """The frozen safety adapter is disabled (rank 0) by default."""
    assert LoRAConfig().safety_lora_rank == 0


def test_lora_config_safety_subspace_path_default():
    """No safety delta checkpoint is configured by default."""
    assert LoRAConfig().safety_subspace_path is None


def test_lora_config_enable_splora_audit_default():
    """The post-hoc audit is opt-in."""
    assert LoRAConfig().enable_splora_audit is False


def test_lora_config_splora_threshold_default():
    """The default audit threshold is 0.15."""
    assert LoRAConfig().splora_threshold == pytest.approx(0.15)


def test_lora_config_safety_fields_set():
    """Every safety field round-trips through the constructor."""
    config = LoRAConfig(
        safety_lora_rank=1,
        safety_subspace_path="/tmp/safety.pt",
        enable_splora_audit=True,
        splora_threshold=0.25,
    )
    assert config.safety_lora_rank == 1
    assert config.safety_subspace_path == "/tmp/safety.pt"
    assert config.enable_splora_audit is True
    assert config.splora_threshold == pytest.approx(0.25)


# ---------------------------------------------------------------------------
# LoRAConfig — base fields unchanged
# ---------------------------------------------------------------------------


def test_lora_config_base_fields_unchanged():
    """Adding the safety fields must not disturb the pre-existing defaults."""
    config = LoRAConfig()
    assert config.r == 8
    assert config.lora_alpha == 32
    assert config.lora_dropout == pytest.approx(0.1)
    assert config.target_modules == ["q_proj", "v_proj"]
    assert config.bias == "none"


def test_lora_config_independent_instances():
    """Two configs differ only in the fields that were set differently."""
    small = LoRAConfig(r=4)
    large = LoRAConfig(r=16)
    assert small.r != large.r
    assert small.safety_lora_rank == large.safety_lora_rank


# ---------------------------------------------------------------------------
# LoRAFinetuner — construction and properties
# ---------------------------------------------------------------------------


def test_finetuner_default_construction():
    """A finetuner built without arguments exposes the default configs."""
    finetuner = LoRAFinetuner()
    assert finetuner.lora_config.r == 8
    assert finetuner.training_config.num_epochs == 3


def test_finetuner_safety_fields_accessible():
    """Safety fields passed via LoRAConfig are readable from the finetuner."""
    finetuner = LoRAFinetuner(
        lora_config=LoRAConfig(safety_lora_rank=1, enable_splora_audit=True)
    )
    assert finetuner.lora_config.safety_lora_rank == 1
    assert finetuner.lora_config.enable_splora_audit is True


def test_finetuner_config_summary_includes_safety_fields():
    """config_summary() reports the rank and the audit flag under "lora"."""
    finetuner = LoRAFinetuner(
        lora_config=LoRAConfig(safety_lora_rank=1, enable_splora_audit=True)
    )
    lora_section = finetuner.config_summary()["lora"]
    assert "safety_lora_rank" in lora_section
    assert lora_section["safety_lora_rank"] == 1
    assert "enable_splora_audit" in lora_section
    assert lora_section["enable_splora_audit"] is True


def test_finetuner_config_summary_safety_subspace_path():
    """config_summary() reports the configured checkpoint path."""
    finetuner = LoRAFinetuner(lora_config=LoRAConfig(safety_subspace_path="/tmp/s.pt"))
    assert finetuner.config_summary()["lora"]["safety_subspace_path"] == "/tmp/s.pt"


# ---------------------------------------------------------------------------
# LoRAFinetuner.train() — import guard (no torch available)
# ---------------------------------------------------------------------------


def test_train_raises_import_error_without_torch():
    """train() surfaces a toki[hf] hint when the optional deps are missing."""
    # A ``None`` entry in sys.modules makes ``import <name>`` raise ImportError.
    blocked = {"torch": None, "transformers": None, "datasets": None}
    with patch.dict(sys.modules, blocked):
        finetuner = LoRAFinetuner()
        with pytest.raises(ImportError, match=r"toki\[hf\]"):
            finetuner.train(MagicMock(), MagicMock(), prompts=["test"])


# ---------------------------------------------------------------------------
# LoRAFinetuner.train() — safety hook integration (mocked torch)
# ---------------------------------------------------------------------------


def _make_mock_env():
    """Build a minimal mock environment for LoRAFinetuner.train().

    Returns a ``(torch, transformers, datasets, tokenizer)`` tuple of mocks.
    The Trainer mock reports ``training_loss=0.25`` and ``global_step=10``.
    """
    torch_stub = MagicMock()
    no_grad_ctx = torch_stub.no_grad.return_value
    no_grad_ctx.__enter__ = MagicMock(return_value=None)
    no_grad_ctx.__exit__ = MagicMock(return_value=False)

    train_output = MagicMock()
    train_output.training_loss = 0.25
    train_output.global_step = 10

    trainer_stub = MagicMock()
    trainer_stub.train.return_value = train_output

    transformers_stub = MagicMock()
    transformers_stub.Trainer.return_value = trainer_stub
    transformers_stub.TrainingArguments = MagicMock()
    transformers_stub.DataCollatorForLanguageModeling = MagicMock()

    tokenizer_stub = MagicMock()
    tokenizer_stub.return_value = {"input_ids": [[1, 2, 3]]}

    hf_dataset = MagicMock()
    hf_dataset.map.return_value = hf_dataset  # .map() chains back onto itself

    datasets_stub = MagicMock()
    datasets_stub.Dataset.from_dict.return_value = hf_dataset

    return torch_stub, transformers_stub, datasets_stub, tokenizer_stub


def _mocked_hf_modules():
    """Return a ``sys.modules`` overlay built from :func:`_make_mock_env`."""
    torch_stub, transformers_stub, datasets_stub, _ = _make_mock_env()
    return {
        "torch": torch_stub,
        "transformers": transformers_stub,
        "datasets": datasets_stub,
    }


def test_train_returns_lora_train_result():
    """train() wraps the Trainer output in a LoRATrainResult."""
    with patch.dict(sys.modules, _mocked_hf_modules()):
        finetuner = LoRAFinetuner()
        model = MagicMock()
        model.named_parameters.return_value = []
        outcome = finetuner.train(model, MagicMock(), prompts=["p1", "p2"])

    assert isinstance(outcome, LoRATrainResult)
    assert outcome.training_loss == pytest.approx(0.25)
    assert outcome.num_steps == 10
    assert outcome.splora_audit is None


def test_train_no_safety_fields_no_audit():
    """With every safety default (disabled) no audit result is attached."""
    with patch.dict(sys.modules, _mocked_hf_modules()):
        finetuner = LoRAFinetuner(lora_config=LoRAConfig())
        model = MagicMock()
        model.named_parameters.return_value = []
        outcome = finetuner.train(model, MagicMock(), prompts=["p"])

    assert outcome.splora_audit is None


def test_train_raises_without_prompts_or_dataset():
    """train() rejects a call that supplies neither prompts nor a dataset."""
    with patch.dict(sys.modules, _mocked_hf_modules()):
        finetuner = LoRAFinetuner()
        with pytest.raises(ValueError, match="dataset or prompts"):
            finetuner.train(MagicMock(), MagicMock())
# ---------------------------------------------------------------------------
# SafetyLoRAConfig
# ---------------------------------------------------------------------------


def test_config_defaults():
    """All safety features are disabled by default."""
    cfg = SafetyLoRAConfig()
    assert cfg.safety_lora_rank == 0
    assert cfg.safety_subspace_path is None
    assert cfg.enable_splora_audit is False
    assert cfg.splora_threshold == pytest.approx(0.15)


def test_config_custom_values():
    """Every field round-trips through the constructor."""
    cfg = SafetyLoRAConfig(
        safety_lora_rank=1,
        safety_subspace_path="/tmp/safety.pt",
        enable_splora_audit=True,
        splora_threshold=0.2,
    )
    assert cfg.safety_lora_rank == 1
    assert cfg.safety_subspace_path == "/tmp/safety.pt"
    assert cfg.enable_splora_audit is True
    assert cfg.splora_threshold == pytest.approx(0.2)


def test_config_rank_zero_disabled():
    cfg = SafetyLoRAConfig(safety_lora_rank=0)
    assert cfg.safety_lora_rank == 0


def test_config_fields_independent():
    a = SafetyLoRAConfig(safety_lora_rank=1)
    b = SafetyLoRAConfig(safety_lora_rank=0)
    assert a.safety_lora_rank != b.safety_lora_rank
    assert a.enable_splora_audit == b.enable_splora_audit


# ---------------------------------------------------------------------------
# SploraAuditResult
# ---------------------------------------------------------------------------


def test_audit_result_passed_no_flagged():
    r = SploraAuditResult(flagged_layers=[], max_ediem=0.05, passed=True, threshold=0.15)
    assert r.passed is True
    assert r.flagged_layers == []


def test_audit_result_failed_with_flagged():
    r = SploraAuditResult(
        flagged_layers=["model.layer.0.weight"],
        max_ediem=0.3,
        passed=False,
        threshold=0.15,
    )
    assert r.passed is False
    assert len(r.flagged_layers) == 1


def test_audit_result_to_dict():
    r = SploraAuditResult(
        flagged_layers=["a", "b"], max_ediem=0.2, passed=False, threshold=0.15
    )
    d = r.to_dict()
    assert d["passed"] is False
    assert d["flagged_layers"] == ["a", "b"]
    assert d["max_ediem"] == pytest.approx(0.2)
    assert d["threshold"] == pytest.approx(0.15)


def test_audit_result_frozen():
    """Assigning to a field of the frozen dataclass must be rejected."""
    # Local import keeps this test self-contained; FrozenInstanceError is the
    # exception frozen dataclasses raise on assignment. Catching the
    # bare ``Exception`` would also pass on unrelated failures.
    from dataclasses import FrozenInstanceError

    r = SploraAuditResult(flagged_layers=[], max_ediem=0.0, passed=True, threshold=0.15)
    with pytest.raises(FrozenInstanceError):
        r.passed = False  # type: ignore[misc]


def test_audit_result_max_ediem_propagated():
    r = SploraAuditResult(flagged_layers=[], max_ediem=0.08, passed=True, threshold=0.15)
    assert r.max_ediem == pytest.approx(0.08)


# ---------------------------------------------------------------------------
# LoRATrainResult
# ---------------------------------------------------------------------------


def test_lora_train_result_fields():
    r = LoRATrainResult(training_loss=0.42, num_steps=100)
    assert r.training_loss == pytest.approx(0.42)
    assert r.num_steps == 100
    assert r.splora_audit is None


def test_lora_train_result_with_audit():
    audit = SploraAuditResult(flagged_layers=[], max_ediem=0.01, passed=True, threshold=0.15)
    r = LoRATrainResult(training_loss=0.1, num_steps=50, splora_audit=audit)
    assert r.splora_audit is audit
    assert r.splora_audit.passed is True


def test_lora_train_result_splora_none_by_default():
    r = LoRATrainResult(training_loss=0.5, num_steps=10)
    assert r.splora_audit is None


# ---------------------------------------------------------------------------
# load_safety_subspace — import guard
# ---------------------------------------------------------------------------


def test_load_safety_subspace_import_error():
    # A ``None`` entry in sys.modules makes ``import torch`` raise ImportError.
    with patch.dict(sys.modules, {"torch": None}):
        with pytest.raises(ImportError, match=r"toki\[hf\]"):
            load_safety_subspace("any_path.pt")


def test_load_safety_subspace_import_error_message():
    with patch.dict(sys.modules, {"torch": None}):
        with pytest.raises(ImportError, match="pip install toki"):
            load_safety_subspace("any_path.pt")


def test_load_safety_subspace_file_not_found(tmp_path):
    mock_torch = MagicMock()
    with patch.dict(sys.modules, {"torch": mock_torch}):
        with pytest.raises(FileNotFoundError, match="not found"):
            load_safety_subspace(str(tmp_path / "nonexistent.pt"))


def test_load_safety_subspace_returns_state_dict(tmp_path):
    """The object returned by torch.load is passed through unchanged."""
    fake_state = {"weight": MagicMock()}
    mock_torch = MagicMock()
    mock_torch.load.return_value = fake_state
    checkpoint = tmp_path / "safety.pt"
    checkpoint.write_bytes(b"fake")  # only existence matters; torch.load is mocked
    with patch.dict(sys.modules, {"torch": mock_torch}):
        result = load_safety_subspace(str(checkpoint))
    assert result is fake_state


# ---------------------------------------------------------------------------
# freeze_safety_adapter — import guard and no-op
# ---------------------------------------------------------------------------


def test_freeze_safety_adapter_noop_when_none():
    mock_model = MagicMock()
    freeze_safety_adapter(mock_model, None)
    mock_model.named_parameters.assert_not_called()


def test_freeze_safety_adapter_import_error():
    with patch.dict(sys.modules, {"torch": None}):
        with pytest.raises(ImportError, match=r"toki\[hf\]"):
            freeze_safety_adapter(MagicMock(), {"layer": MagicMock()})


def test_freeze_safety_adapter_applies_matching_tensors():
    """A delta whose key matches (after prefix stripping) freezes the parameter."""
    # Build minimal mock model and torch
    mock_param = MagicMock()
    mock_param.device = "cpu"
    mock_param.shape = (4, 4)
    mock_param.data = MagicMock()

    mock_model = MagicMock()
    mock_model.named_parameters.return_value = [("base_model.model.weight", mock_param)]

    mock_delta = MagicMock()
    mock_delta.shape = (4, 4)
    mock_delta.to.return_value = mock_delta

    mock_torch = MagicMock()
    mock_torch.no_grad.return_value.__enter__ = MagicMock(return_value=None)
    mock_torch.no_grad.return_value.__exit__ = MagicMock(return_value=False)

    with patch.dict(sys.modules, {"torch": mock_torch}):
        freeze_safety_adapter(mock_model, {"weight": mock_delta})

    mock_param.requires_grad_.assert_called_once_with(False)


def test_freeze_safety_adapter_skips_non_matching_keys():
    mock_model = MagicMock()
    mock_model.named_parameters.return_value = [("model.weight", MagicMock())]

    mock_torch = MagicMock()
    mock_torch.no_grad.return_value.__enter__ = MagicMock(return_value=None)
    mock_torch.no_grad.return_value.__exit__ = MagicMock(return_value=False)

    with patch.dict(sys.modules, {"torch": mock_torch}):
        # safety_delta has different key — no match → no freeze
        freeze_safety_adapter(mock_model, {"completely_different_key": MagicMock()})

    # named_parameters() hands back the same list object, so these are the
    # very parameter mocks the function iterated over.
    for _, param in mock_model.named_parameters():
        param.requires_grad_.assert_not_called()


# ---------------------------------------------------------------------------
# splora_audit — import guard and basic behaviour
# ---------------------------------------------------------------------------


def test_splora_audit_import_error():
    with patch.dict(sys.modules, {"torch": None}):
        with pytest.raises(ImportError, match=r"toki\[hf\]"):
            splora_audit(MagicMock(), {})


def test_splora_audit_returns_audit_result():
    mock_torch = MagicMock()
    mock_model = MagicMock()
    mock_model.named_parameters.return_value = []

    with patch.dict(sys.modules, {"torch": mock_torch}):
        result = splora_audit(mock_model, {})

    assert isinstance(result, SploraAuditResult)
    assert result.passed is True
    assert result.flagged_layers == []


def test_splora_audit_empty_base_state_passes():
    mock_torch = MagicMock()
    mock_model = MagicMock()
    mock_model.named_parameters.return_value = []

    with patch.dict(sys.modules, {"torch": mock_torch}):
        result = splora_audit(mock_model, {}, threshold=0.15)

    assert result.passed is True
    assert result.max_ediem == pytest.approx(0.0)


def test_splora_audit_threshold_propagated():
    mock_torch = MagicMock()
    mock_model = MagicMock()
    mock_model.named_parameters.return_value = []

    with patch.dict(sys.modules, {"torch": mock_torch}):
        result = splora_audit(mock_model, {}, threshold=0.3)

    assert result.threshold == pytest.approx(0.3)
b/python/toki/__init__.py @@ -1,7 +1,7 @@ """Toki — adversarial fine-tuning lab for small LLMs.""" from __future__ import annotations -__version__ = "1.6.0" +__version__ = "1.7.0" from toki.generate import AdversarialGenerator from toki.evaluate import ( @@ -12,6 +12,14 @@ RuleScorer, ScoredResult, ) +from toki.safety_lora import ( + LoRATrainResult, + SafetyLoRAConfig, + SploraAuditResult, + freeze_safety_adapter, + load_safety_subspace, + splora_audit, +) from toki.dataset import AdversarialDataset from toki.experiment import TokiExperiment, ExperimentConfig from toki.results import ExperimentResult @@ -193,6 +201,13 @@ "RobustnessEvaluator", "RuleScorer", "ScoredResult", + # Sprint 17 — safety-subspace LoRA + "LoRATrainResult", + "SafetyLoRAConfig", + "SploraAuditResult", + "freeze_safety_adapter", + "load_safety_subspace", + "splora_audit", "AdversarialDataset", "TokiExperiment", "ExperimentConfig", diff --git a/python/toki/__main__.py b/python/toki/__main__.py index 646c69f..bf0f462 100644 --- a/python/toki/__main__.py +++ b/python/toki/__main__.py @@ -665,6 +665,19 @@ def build_parser() -> argparse.ArgumentParser: ) p_ag.add_argument("--json", action="store_true") + # finetune (Sprint 17 — safety-subspace LoRA) + p_ft = sub.add_parser("finetune", help="Fine-tune with safety-subspace LoRA (requires toki[hf])") + p_ft.add_argument("--model", type=str, default=None, + help="HuggingFace model name or path (required for real fine-tuning)") + p_ft.add_argument("--safety-lora-rank", type=int, default=0, dest="safety_lora_rank", + help="Frozen safety adapter rank (0=disabled, 1=rank-1 MLP; arXiv 2507.17075)") + p_ft.add_argument("--safety-subspace", type=str, default=None, dest="safety_subspace", + help="Path to pre-computed safety delta checkpoint (SaLoRA; arXiv 2501.01765)") + p_ft.add_argument("--splora-audit", action="store_true", dest="splora_audit", + help="Run E-DIEM safety-subspace audit after training (SPLoRA; arXiv 2506.18931)") + 
def cmd_finetune(args) -> None:
    """Handle the ``finetune`` subcommand.

    Builds a ``LoRAConfig`` from the parsed safety-subspace flags. Without
    ``--model`` it only prints the resulting configuration and returns. With
    ``--model`` it calls ``LoRAFinetuner.prepare_model`` and reports that the
    model was loaded; no training is started from the CLI.

    Raises
    ------
    SystemExit
        When ``prepare_model`` raises ``ImportError`` (optional ``toki[hf]``
        dependencies missing).
    """
    # Deferred import, as in the other cmd_* handlers of this module.
    from toki.finetune import LoRAConfig, LoRAFinetuner

    cfg = LoRAConfig(
        safety_lora_rank=args.safety_lora_rank,
        safety_subspace_path=args.safety_subspace,
        enable_splora_audit=args.splora_audit,
        splora_threshold=args.splora_threshold,
    )
    ft = LoRAFinetuner(lora_config=cfg)
    summary = ft.config_summary()

    if args.model is None:
        lora = summary["lora"]
        print("Safety-subspace LoRA configuration:")
        print(f" safety_lora_rank: {lora['safety_lora_rank']}")
        print(f" safety_subspace: {lora['safety_subspace_path'] or '(none)'}")
        print(f" enable_splora_audit: {lora['enable_splora_audit']}")
        # Read from cfg so the --splora-threshold flag is echoed regardless of
        # which keys config_summary() chooses to expose.
        print(f" splora_threshold: {cfg.splora_threshold}")
        print("\nProvide --model to run actual fine-tuning (requires toki[hf]).")
        return

    try:
        # The loaded objects are not used further: the CLI only verifies that
        # the model can be prepared.
        _model, _tokenizer = ft.prepare_model(args.model)
    except ImportError as exc:
        raise SystemExit(f"Fine-tuning requires toki[hf]: {exc}") from exc

    print(f"Model loaded: {args.model}")
    print("Run ft.train(model, tokenizer, prompts=[...]) to fine-tune.")
@dataclass
class LoRAConfig:
    """Configuration for LoRA adapters.

    Safety-subspace fields (Sprint 17 — arXiv 2501.01765 / 2506.18931 / 2507.17075):

    safety_lora_rank:
        Rank for the frozen safety adapter applied before task training.
        0 = disabled (default). 1 = rank-1 MLP up_proj (arXiv 2507.17075).
    safety_subspace_path:
        Path to a pre-computed safety delta checkpoint. When set, loaded and
        applied as a frozen shift before LoRA task training (SaLoRA).
    enable_splora_audit:
        When True, run E-DIEM safety-subspace audit after training completes.
        Result is attached to the returned LoRATrainResult.
    splora_threshold:
        E-DIEM distance threshold for flagging unsafe weight updates (default 0.15).

    Raises
    ------
    ValueError
        If ``safety_lora_rank`` is negative, or ``splora_threshold`` is NaN
        or negative.
    """

    r: int = 8  # LoRA rank
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    target_modules: list[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    # Safety-subspace fields (Sprint 17)
    safety_lora_rank: int = 0
    safety_subspace_path: Optional[str] = None
    enable_splora_audit: bool = False
    splora_threshold: float = 0.15

    def __post_init__(self) -> None:
        """Reject safety settings that would make the audit meaningless."""
        import math

        if self.safety_lora_rank < 0:
            raise ValueError(
                f"safety_lora_rank must be >= 0, got {self.safety_lora_rank!r}"
            )
        # Every comparison against NaN is False, so a NaN threshold could
        # never flag a layer and the audit would always pass.
        if math.isnan(self.splora_threshold) or self.splora_threshold < 0:
            raise ValueError(
                "splora_threshold must be a non-negative number, "
                f"got {self.splora_threshold!r}"
            )
Parameters @@ -123,16 +151,28 @@ def train( """ try: import torch - from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling + from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments from datasets import Dataset as HFDataset - except ImportError as e: - raise ImportError(f"Training requires toki[hf]: {e}") from e + except ImportError as exc: + raise ImportError(f"Training requires toki[hf]: {exc}") from exc + + # Snapshot base state before any safety-subspace modifications, + # so the post-hoc E-DIEM audit compares against the true pre-training weights. + base_state: Optional[dict] = None + if self._lora.enable_splora_audit: + base_state = {k: v.data.clone() for k, v in model.named_parameters()} + + # SaLoRA: apply frozen safety delta before task fine-tuning + if self._lora.safety_subspace_path is not None: + from toki.safety_lora import freeze_safety_adapter, load_safety_subspace + safety_delta = load_safety_subspace(self._lora.safety_subspace_path) + freeze_safety_adapter(model, safety_delta) # Collect text prompts if dataset is not None: texts = [p.text for p in dataset] elif prompts is not None: - texts = prompts + texts = list(prompts) else: raise ValueError("Provide either dataset or prompts") @@ -170,11 +210,23 @@ def tokenize(examples): data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), ) - train_result = trainer.train() - return { - "training_loss": train_result.training_loss, - "num_steps": train_result.global_step, - } + train_output = trainer.train() + + # SPLoRA: E-DIEM post-hoc audit + audit: Optional[SploraAuditResult] = None + if self._lora.enable_splora_audit and base_state is not None: + from toki.safety_lora import splora_audit + audit = splora_audit( + model, + base_state, + threshold=self._lora.splora_threshold, + ) + + return LoRATrainResult( + training_loss=train_output.training_loss, + num_steps=train_output.global_step, + splora_audit=audit, + ) def 
logger = logging.getLogger(__name__)

# Name prefix stripped from model parameter names before they are looked up
# in a safety-delta state dict.
_WRAPPER_PREFIX = "base_model.model."


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------


@dataclass
class SafetyLoRAConfig:
    """Safety-preserving extension fields for LoRAConfig.

    Attributes
    ----------
    safety_lora_rank:
        Rank for the frozen safety adapter (0 = disabled). Rank-1 on middle
        up_proj layers is sufficient with zero reasoning tax (arXiv 2507.17075).
    safety_subspace_path:
        Path to a pre-computed safety delta checkpoint (.pt file). When set,
        the delta is applied and those parameters are frozen before task
        fine-tuning begins (SaLoRA approach, arXiv 2501.01765).
    enable_splora_audit:
        When True, run E-DIEM safety-subspace audit after training completes
        (SPLoRA post-hoc check, arXiv 2506.18931).
    splora_threshold:
        E-DIEM normalised Frobenius distance above which a layer is flagged.
        Default 0.15 follows the SPLoRA paper's reported threshold.
    """

    safety_lora_rank: int = 0
    safety_subspace_path: Optional[str] = None
    enable_splora_audit: bool = False
    splora_threshold: float = 0.15


# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class SploraAuditResult:
    """Result of an E-DIEM safety-subspace audit (SPLoRA, arXiv 2506.18931).

    Attributes
    ----------
    flagged_layers:
        Parameter names whose weight-update E-DIEM distance exceeded threshold
        (or was not a finite number).
    max_ediem:
        Highest E-DIEM distance observed across all checked layers;
        ``math.inf`` when any layer produced a non-finite distance.
    passed:
        True when no layers were flagged (all updates within safety subspace).
    threshold:
        The E-DIEM threshold used for this audit.
    """

    flagged_layers: List[str]
    max_ediem: float
    passed: bool
    threshold: float

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict (``flagged_layers`` is copied)."""
        return {
            "flagged_layers": list(self.flagged_layers),
            "max_ediem": self.max_ediem,
            "passed": self.passed,
            "threshold": self.threshold,
        }


# ---------------------------------------------------------------------------
# LoRATrainResult
# ---------------------------------------------------------------------------


@dataclass
class LoRATrainResult:
    """Return value of LoRAFinetuner.train().

    Attributes
    ----------
    training_loss:
        Final training loss from the HF Trainer.
    num_steps:
        Total training steps completed.
    splora_audit:
        E-DIEM safety-subspace audit result, or None when
        enable_splora_audit=False (the default).
    """

    training_loss: float
    num_steps: int
    splora_audit: Optional[SploraAuditResult] = None


# ---------------------------------------------------------------------------
# SaLoRA — load and freeze safety subspace
# ---------------------------------------------------------------------------


def load_safety_subspace(path: str) -> Dict[str, Any]:
    """Load a safety delta checkpoint from disk.

    Parameters
    ----------
    path:
        Filesystem path to a PyTorch checkpoint (.pt / .bin) produced by
        saving a safety adapter's state dict.

    Returns
    -------
    The loaded state dict, mapped to CPU.

    Raises
    ------
    ImportError
        When torch is absent.
    FileNotFoundError
        When the checkpoint file does not exist.
    TypeError
        When the checkpoint does not contain a dict.
    """
    try:
        import torch
    except ImportError as exc:
        raise ImportError(
            "load_safety_subspace requires toki[hf]: pip install toki[hf]"
        ) from exc

    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Safety subspace checkpoint not found: {path}")

    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(str(p), map_location="cpu", weights_only=True)
    if not isinstance(state, dict):
        # A bare tensor or list would only fail later, inside
        # freeze_safety_adapter, with a much less helpful error.
        raise TypeError(
            "Safety subspace checkpoint must contain a state dict, "
            f"got {type(state).__name__}: {path}"
        )
    logger.debug("load_safety_subspace: loaded %d tensors from %s", len(state), path)
    return state


def freeze_safety_adapter(
    model: Any,
    safety_delta: Optional[Dict[str, Any]],
) -> None:
    """Freeze safety-alignment parameters in the model.

    When *safety_delta* is ``None`` this is a complete no-op — no model
    modification occurs and no imports are attempted.

    Otherwise, each model parameter name has a leading ``"base_model.model."``
    stripped and is looked up in *safety_delta*. On a match with identical
    shape the delta is added to the parameter data in-place and the parameter
    is marked ``requires_grad=False``, so task fine-tuning cannot overwrite
    it (SaLoRA, arXiv 2501.01765). Shape mismatches are logged and skipped.
    A WARNING is logged when a non-empty delta matches no parameter at all.

    Parameters
    ----------
    model:
        Any ``nn.Module`` (typically a peft.PeftModel from prepare_model()).
    safety_delta:
        Parameter state dict from :func:`load_safety_subspace`. Keys are
        parameter names; values are tensors.

    Raises
    ------
    ImportError
        When torch is absent and *safety_delta* is not ``None``.
    """
    if safety_delta is None:
        return

    try:
        import torch
    except ImportError as exc:
        raise ImportError(
            "freeze_safety_adapter requires toki[hf]: pip install toki[hf]"
        ) from exc

    matched = 0
    for name, param in model.named_parameters():
        # Strip the wrapper prefix only at the start of the name; a blanket
        # str.replace would also rewrite names that merely contain it.
        key = name.removeprefix(_WRAPPER_PREFIX)
        if key not in safety_delta:
            continue
        delta = safety_delta[key]
        # Check the shape before paying for a device transfer.
        if delta.shape != param.shape:
            logger.warning(
                "freeze_safety_adapter: shape mismatch for %r (%s vs %s), skipping",
                name,
                tuple(param.shape),
                tuple(delta.shape),
            )
            continue
        delta = delta.to(param.device)
        with torch.no_grad():
            param.data += delta
        param.requires_grad_(False)
        matched += 1

    if matched == 0 and len(safety_delta) > 0:
        # The caller asked for a safety shift and none was applied; this must
        # not disappear at DEBUG level.
        logger.warning(
            "freeze_safety_adapter: none of the %d safety-delta tensors matched "
            "a model parameter; no safety shift was applied",
            len(safety_delta),
        )

    logger.debug(
        "freeze_safety_adapter: froze %d / %d safety-delta tensors",
        matched,
        len(safety_delta),
    )


# ---------------------------------------------------------------------------
# SPLoRA — E-DIEM post-hoc audit
# ---------------------------------------------------------------------------


def _ediem(base_tensor: Any, fine_tuned_tensor: Any) -> float:
    """Compute simplified E-DIEM distance between two tensors.

    E-DIEM (Empirical Dimension-Insensitive Evidence Metric) from SPLoRA
    (arXiv 2506.18931) measures how much a fine-tuned weight update shifts
    the representation away from the safety subspace.

    We use the normalised Frobenius norm of the weight delta,
    ``||fine_tuned - base||_F / ||base||_F``, as a practical approximation
    when a pre-computed safety-subspace projection is unavailable.

    Returns 0.0 when torch is absent (safe no-op path) and when the base
    norm is below 1e-8.
    """
    try:
        import torch as _torch  # noqa: F401
    except ImportError:
        return 0.0

    tuned = fine_tuned_tensor.float()
    # Move the snapshot to the live tensor's device so a snapshot kept
    # elsewhere (e.g. on CPU) can still be compared.
    base = base_tensor.float().to(tuned.device)
    delta = tuned - base
    base_norm = base.norm(p="fro").item()
    # NOTE(review): a parameter that started at (near) zero always scores 0.0
    # here, however far it moved — confirm this is acceptable for
    # zero-initialised adapter weights.
    if base_norm < 1e-8:
        return 0.0
    return delta.norm(p="fro").item() / base_norm


def splora_audit(
    model: Any,
    base_state: Dict[str, Any],
    *,
    threshold: float = 0.15,
) -> SploraAuditResult:
    """Run an E-DIEM safety-subspace audit of a fine-tuned model.

    Compares each parameter of the fine-tuned model against the pre-training
    base state. Flags layers where the normalised Frobenius distance exceeds
    *threshold*, and layers whose distance is NaN or infinite (for example
    diverged weights), which would otherwise pass every comparison.

    Logs a WARNING for every flagged layer so the CI output is actionable.
    Logs a WARNING when no matching layers are found (misconfigured base_state);
    the result still reports ``passed=True`` in that case.

    Parameters
    ----------
    model:
        Fine-tuned model (peft.PeftModel or any nn.Module).
    base_state:
        Pre-training parameter state dict; keys = parameter names,
        values = tensors captured before trainer.train().
    threshold:
        E-DIEM distance above which a layer is flagged.

    Raises
    ------
    ImportError
        When torch is absent.
    ValueError
        When *threshold* is NaN (no distance could ever exceed it).
    """
    import math

    try:
        import torch as _torch  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "splora_audit requires toki[hf]: pip install toki[hf]"
        ) from exc

    if math.isnan(threshold):
        raise ValueError("splora_audit threshold must not be NaN")

    flagged: List[str] = []
    max_dist = 0.0
    checked = 0

    current: Dict[str, Any] = dict(model.named_parameters())

    for name, base_tensor in base_state.items():
        if name not in current:
            continue
        dist = _ediem(base_tensor, current[name].data)
        finite = math.isfinite(dist)
        if not finite:
            max_dist = math.inf
        elif dist > max_dist:
            max_dist = dist
        # ``nan > threshold`` is False, so non-finite distances are flagged
        # explicitly rather than left to the comparison.
        if not finite or dist > threshold:
            flagged.append(name)
            logger.warning(
                "SPLoRA audit: %r E-DIEM=%.4f exceeds threshold=%.4f",
                name,
                dist,
                threshold,
            )
        checked += 1

    if checked == 0:
        logger.warning(
            "SPLoRA audit: no matching layers found between model and base_state "
            "— verify that base_state was captured from the same model"
        )

    return SploraAuditResult(
        flagged_layers=flagged,
        max_ediem=max_dist,
        passed=len(flagged) == 0,
        threshold=threshold,
    )