# --- `shelfai config` sub-app: chunk-defaults management --------------------

config_app = typer.Typer(name="config", help="Manage chunk defaults and validation.")
app.add_typer(config_app)


def _chunk_defaults_layer_paths(
    start: Path, agent_dir: Optional[Path]
) -> list[tuple[Path, str]]:
    """Return the (path, label) override layers that apply, repo layer first."""
    from shelfai.core.config import ChunkDefaultsConfig

    layers: list[tuple[Path, str]] = []
    repo_root = ChunkDefaultsConfig.find_repo_root(start=start)
    if repo_root:
        layers.append((repo_root / ".chunk-defaults.yaml", "repo"))
    if agent_dir:
        layers.append((Path(agent_dir) / "chunk.yaml", "agent"))
    return layers


def _read_chunk_defaults_layer(path: Path, label: str) -> Optional[dict]:
    """Validate and parse a single chunk-defaults layer file.

    Returns the parsed mapping, or None when the file is absent or invalid.
    Invalid layers are reported on stderr and skipped rather than aborting.
    """
    if not path.exists():
        return None

    is_valid, errors = validate_config_file(str(path))
    if not is_valid:
        _warn_invalid_chunk_defaults(path, label, errors)
        return None

    try:
        return yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except Exception as exc:
        _warn_invalid_chunk_defaults(path, label, [f"Failed to load YAML: {exc}"])
        return None


def _load_chunk_defaults(start: Path, agent_dir: Optional[Path] = None):
    """Load typed chunk defaults: built-ins, then repo, then agent overrides."""
    from shelfai.core.config import ChunkDefaultsConfig

    config = ChunkDefaultsConfig.default()
    for path, label in _chunk_defaults_layer_paths(start, agent_dir):
        config = _merge_chunk_defaults_layer(config, path, label=label)
    return config


def _merge_chunk_defaults_layer(base, path: Path, label: str):
    """Merge one layer into a typed config; fall back to *base* on any failure."""
    from shelfai.core.config import ChunkDefaultsConfig

    raw = _read_chunk_defaults_layer(path, label)
    if raw is None:
        return base

    try:
        # NOTE(review): relies on the private ChunkDefaultsConfig._merge;
        # consider promoting it to a public API on the config class.
        return ChunkDefaultsConfig._merge(base, raw)
    except Exception as exc:
        _warn_invalid_chunk_defaults(path, label, [f"Failed to merge chunk defaults: {exc}"])
        return base


def _warn_invalid_chunk_defaults(path: Path, label: str, errors: list[str]) -> None:
    """Emit a stderr warning that an invalid config layer is being ignored."""
    console_err.print(
        f"[yellow]Warning: invalid {label} chunk defaults at {path}. "
        "Using built-in defaults for that layer.[/yellow]"
    )
    for error in errors:
        console_err.print(f"[yellow] - {error}[/yellow]")


def _load_effective_chunk_defaults_dict(start: Path, agent_dir: Optional[Path] = None) -> dict:
    """Compute the effective configuration as a plain dict (for `config show`)."""
    config = get_default_config()
    for path, label in _chunk_defaults_layer_paths(start, agent_dir):
        config = _merge_raw_chunk_defaults_layer(config, path, label=label)
    return config


def _merge_raw_chunk_defaults_layer(base: dict, path: Path, label: str) -> dict:
    """Deep-merge one raw layer dict over *base*; invalid layers are skipped."""
    raw = _read_chunk_defaults_layer(path, label)
    if raw is None:
        return base
    return _deep_merge_dicts(base, raw)


def _deep_merge_dicts(base: dict, overrides: dict) -> dict:
    """Recursively merge *overrides* into a copy of *base* (overrides win)."""
    merged = dict(base)
    for key, value in overrides.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = _deep_merge_dicts(merged[key], value)
        else:
            merged[key] = value
    return merged


def _resolve_chunk_defaults_path(path: Optional[str]) -> Path:
    """Resolve a user-supplied path (or CWD) to a .chunk-defaults.yaml file."""
    resolved = Path(path).expanduser() if path else Path.cwd() / ".chunk-defaults.yaml"
    if resolved.is_dir():
        resolved = resolved / ".chunk-defaults.yaml"
    return resolved.resolve()


def _resolve_chunk_defaults_context(path: Optional[str]) -> tuple[Path, Optional[Path]]:
    """Map a CLI path argument to (repo-search start dir, optional agent dir)."""
    if path is None:
        return Path.cwd(), None

    resolved = Path(path).expanduser().resolve()
    if resolved.is_dir():
        return resolved, None
    # Pointing at a chunk.yaml also selects its directory as the agent layer;
    # any other file (including .chunk-defaults.yaml) just sets the start dir.
    if resolved.name == "chunk.yaml":
        return resolved.parent, resolved.parent
    return resolved.parent, None


@config_app.command("validate")
def config_validate(
    path: Optional[str] = typer.Argument(
        None, help="Path to .chunk-defaults.yaml or a directory containing it"
    ),
):
    """Validate a chunk defaults file."""
    config_path = _resolve_chunk_defaults_path(path)
    is_valid, errors = validate_config_file(str(config_path))

    if is_valid:
        console.print(f"[green]Valid config:[/green] {config_path}")
        return

    console.print(f"[red]Invalid config:[/red] {config_path}")
    for error in errors:
        console.print(f" - {error}")
    raise typer.Exit(1)


@config_app.command("init")
def config_init(
    path: Optional[str] = typer.Argument(
        None, help="Directory or file path where .chunk-defaults.yaml should be created"
    ),
):
    """Generate a commented .chunk-defaults.yaml template."""
    config_path = _resolve_chunk_defaults_path(path)
    if config_path.exists():
        # Refuse to clobber an existing config; non-zero exit so scripts notice.
        console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
        raise typer.Exit(1)

    config_path.parent.mkdir(parents=True, exist_ok=True)
    config_path.write_text(generate_config_template(), encoding="utf-8")
    console.print(f"[green]Wrote template config to {config_path}[/green]")


@config_app.command("show")
def config_show(
    path: Optional[str] = typer.Argument(
        None, help="Directory or config file to use as the override source"
    ),
):
    """Show the effective chunk config."""
    start, agent_dir = _resolve_chunk_defaults_context(path)
    effective = _load_effective_chunk_defaults_dict(start=start, agent_dir=agent_dir)
    console.print(yaml.safe_dump(effective, sort_keys=False, default_flow_style=False))
effective = _load_effective_chunk_defaults_dict(start=start, agent_dir=agent_dir) + console.print(yaml.safe_dump(effective, sort_keys=False, default_flow_style=False)) def _chunk_version_store(shelf_root: Path): diff --git a/src/shelfai/core/config_schema.py b/src/shelfai/core/config_schema.py new file mode 100644 index 0000000..a30fc58 --- /dev/null +++ b/src/shelfai/core/config_schema.py @@ -0,0 +1,410 @@ +"""JSON Schema for .chunk-defaults.yaml validation.""" + +from __future__ import annotations + +import copy +from pathlib import Path +from typing import Any + +import yaml + +CONFIG_SCHEMA = { + "type": "object", + "properties": { + "always_load": { + "type": "array", + "items": {"type": "string"}, + }, + "chunking": { + "type": "object", + "properties": { + "min_lines": {"type": "integer", "minimum": 1}, + "max_tokens": {"type": "integer", "minimum": 100}, + "heading_level": {"type": "integer", "minimum": 1, "maximum": 6}, + "strategy": {"type": "string", "enum": ["heading", "semantic", "hybrid"]}, + "review_threshold": {"type": "integer", "minimum": 0}, + "model": {"type": "string"}, + }, + "additionalProperties": False, + }, + "classification": { + "type": "object", + "properties": { + "use_transformer": {"type": "boolean"}, + "min_confidence": {"type": "number", "minimum": 0, "maximum": 1}, + "confidence_threshold": {"type": "number", "minimum": 0, "maximum": 1}, + "custom_keywords": { + "type": "object", + "additionalProperties": { + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + "additionalProperties": False, + }, + "telemetry": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "db_path": {"type": "string"}, + "export_format": {"type": "string", "enum": ["jsonl", "sqlite"]}, + }, + "additionalProperties": False, + }, + "tokens": { + "type": "object", + "properties": { + "model": {"type": "string"}, + "max_budget": {"type": "integer", "minimum": 0}, + "warn_threshold": {"type": "integer", "minimum": 0}, + }, + 
"additionalProperties": False, + }, + "cache": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "max_entries": {"type": "integer", "minimum": 1}, + "ttl_seconds": {"type": "integer", "minimum": 0}, + }, + "additionalProperties": False, + }, + }, + "additionalProperties": False, +} + +DEFAULT_CONFIG = { + "always_load": ["soul", "rules", "read-order"], + "classification": { + "use_transformer": False, + "min_confidence": 0.4, + "confidence_threshold": 0.4, + "custom_keywords": {}, + }, + "chunking": { + "min_lines": 150, + "max_tokens": 4000, + "heading_level": 2, + "strategy": "hybrid", + "review_threshold": 80, + "model": "claude-sonnet-4-6", + }, + "telemetry": { + "enabled": True, + "db_path": "shelf/patterns.db", + "export_format": "jsonl", + }, + "tokens": { + "model": "claude-sonnet-4-6", + "max_budget": 0, + "warn_threshold": 0, + }, + "cache": { + "enabled": True, + "max_entries": 1000, + "ttl_seconds": 86400, + }, +} + + +def get_default_config() -> dict: + """Return the full default configuration.""" + return copy.deepcopy(DEFAULT_CONFIG) + + +def validate_config(config: dict) -> tuple[bool, list[str]]: + """ + Validate a config dict against the schema. + Returns (is_valid, list_of_error_messages). + Uses jsonschema if available, falls back to manual validation. 
+ """ + if config is None: + config = {} + if not isinstance(config, dict): + return False, ["Config must be a mapping/object."] + + try: + from jsonschema import Draft7Validator + except Exception: + return _manual_validate(config) + + validator = Draft7Validator(CONFIG_SCHEMA) + errors = sorted(validator.iter_errors(config), key=_jsonschema_error_sort_key) + if not errors: + return True, [] + return False, [_format_jsonschema_error(error) for error in errors] + + +def validate_config_file(path: str) -> tuple[bool, list[str]]: + """Load YAML file and validate.""" + file_path = Path(path) + if not file_path.exists(): + return False, [f"Config file not found: {file_path}"] + + try: + with file_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) or {} + except Exception as exc: + return False, [f"Failed to load YAML from {file_path}: {exc}"] + + return validate_config(config) + + +def generate_config_template() -> str: + """Generate a commented .chunk-defaults.yaml template.""" + defaults = get_default_config() + sections = [ + ("always_load", "Chunks that should always be included."), + ("classification", "Keyword and transformer classification settings."), + ("chunking", "Heuristics for chunk creation and router generation."), + ("telemetry", "Chunk telemetry and export settings."), + ("tokens", "Optional token budget configuration."), + ("cache", "Optional cache settings for chunk operations."), + ] + + parts: list[str] = [ + "# ShelfAI chunk defaults", + "# Generated by `shelfai config init`.", + "# Fields you omit inherit the built-in defaults.", + "", + ] + for key, comment in sections: + parts.append(f"# {comment}") + parts.append( + yaml.safe_dump({key: defaults[key]}, sort_keys=False, default_flow_style=False).rstrip() + ) + parts.append("") + + return "\n".join(parts).rstrip() + "\n" + + +def _jsonschema_error_sort_key(error: Any) -> tuple[int, str]: + path = ".".join(str(part) for part in error.path) + return (len(error.path), 
path) + + +def _format_jsonschema_error(error: Any) -> str: + path = ".".join(str(part) for part in error.path) or "config" + validator = getattr(error, "validator", "") + value = getattr(error, "validator_value", None) + + if validator == "type": + return f"{path} must be a {_type_label(value)}" + if validator == "enum": + return f"{path} must be one of: {', '.join(str(item) for item in value)}" + if validator == "minimum": + return f"{path} must be >= {value}" + if validator == "maximum": + return f"{path} must be <= {value}" + if validator == "additionalProperties": + return f"Unknown key in {path}: {getattr(error, 'message', 'not allowed')}" + return f"{path}: {getattr(error, 'message', 'invalid value')}" + + +def _type_label(expected: Any) -> str: + mapping = { + "object": "object", + "array": "array", + "string": "string", + "integer": "integer", + "number": "number", + "boolean": "boolean", + } + if isinstance(expected, list): + return " or ".join(mapping.get(item, str(item)) for item in expected) + return mapping.get(expected, str(expected)) + + +def _manual_validate(config: dict) -> tuple[bool, list[str]]: + errors: list[str] = [] + allowed_top_level = set(CONFIG_SCHEMA["properties"]) + + for key in sorted(set(config) - allowed_top_level): + errors.append(f"Unknown top-level key: {key}") + + for key in allowed_top_level: + if key in config: + errors.extend(_validate_section(key, config[key])) + + return (len(errors) == 0, errors) + + +def _validate_section(section: str, value: Any) -> list[str]: + if section == "always_load": + return _validate_string_list(value, section) + if section == "classification": + return _validate_classification(value) + if section == "chunking": + return _validate_chunking(value) + if section == "telemetry": + return _validate_telemetry(value) + if section == "tokens": + return _validate_tokens(value) + if section == "cache": + return _validate_cache(value) + return [] + + +def _validate_string_list(value: Any, path: str) -> 
list[str]: + if not isinstance(value, list): + return [f"{path} must be an array of strings"] + errors = [] + for index, item in enumerate(value): + if not isinstance(item, str): + errors.append(f"{path}[{index}] must be a string") + return errors + + +def _validate_classification(value: Any) -> list[str]: + if not isinstance(value, dict): + return ["classification must be an object"] + + allowed = {"use_transformer", "min_confidence", "confidence_threshold", "custom_keywords"} + errors = [f"Unknown key: classification.{key}" for key in sorted(set(value) - allowed)] + + if "use_transformer" in value and not isinstance(value["use_transformer"], bool): + errors.append("classification.use_transformer must be a boolean") + + for key in ("min_confidence", "confidence_threshold"): + if key in value: + if not _is_number(value[key]): + errors.append(f"classification.{key} must be a number between 0 and 1") + elif not 0 <= float(value[key]) <= 1: + errors.append(f"classification.{key} must be between 0 and 1") + + if "custom_keywords" in value: + custom_keywords = value["custom_keywords"] + if not isinstance(custom_keywords, dict): + errors.append("classification.custom_keywords must be an object") + else: + for keyword_category, keywords in sorted(custom_keywords.items()): + if not isinstance(keywords, list): + errors.append( + f"classification.custom_keywords.{keyword_category} must be an array of strings" + ) + continue + for index, keyword in enumerate(keywords): + if not isinstance(keyword, str): + errors.append( + f"classification.custom_keywords.{keyword_category}[{index}] must be a string" + ) + + return errors + + +def _validate_chunking(value: Any) -> list[str]: + if not isinstance(value, dict): + return ["chunking must be an object"] + + allowed = {"min_lines", "max_tokens", "heading_level", "strategy", "review_threshold", "model"} + errors = [f"Unknown key: chunking.{key}" for key in sorted(set(value) - allowed)] + + if "min_lines" in value: + if not 
_is_integer(value["min_lines"]): + errors.append("chunking.min_lines must be an integer") + elif value["min_lines"] < 1: + errors.append("chunking.min_lines must be >= 1") + + if "max_tokens" in value: + if not _is_integer(value["max_tokens"]): + errors.append("chunking.max_tokens must be an integer") + elif value["max_tokens"] < 100: + errors.append("chunking.max_tokens must be >= 100") + + if "heading_level" in value: + if not _is_integer(value["heading_level"]): + errors.append("chunking.heading_level must be an integer") + elif not 1 <= value["heading_level"] <= 6: + errors.append("chunking.heading_level must be between 1 and 6") + + if "strategy" in value: + if not isinstance(value["strategy"], str): + errors.append("chunking.strategy must be a string") + elif value["strategy"] not in {"heading", "semantic", "hybrid"}: + errors.append("chunking.strategy must be one of: heading, semantic, hybrid") + + if "review_threshold" in value: + if not _is_integer(value["review_threshold"]): + errors.append("chunking.review_threshold must be an integer") + elif value["review_threshold"] < 0: + errors.append("chunking.review_threshold must be >= 0") + + if "model" in value and not isinstance(value["model"], str): + errors.append("chunking.model must be a string") + + return errors + + +def _validate_telemetry(value: Any) -> list[str]: + if not isinstance(value, dict): + return ["telemetry must be an object"] + + allowed = {"enabled", "db_path", "export_format"} + errors = [f"Unknown key: telemetry.{key}" for key in sorted(set(value) - allowed)] + + if "enabled" in value and not isinstance(value["enabled"], bool): + errors.append("telemetry.enabled must be a boolean") + if "db_path" in value and not isinstance(value["db_path"], str): + errors.append("telemetry.db_path must be a string") + if "export_format" in value: + if not isinstance(value["export_format"], str): + errors.append("telemetry.export_format must be a string") + elif value["export_format"] not in {"jsonl", 
"sqlite"}: + errors.append("telemetry.export_format must be one of: jsonl, sqlite") + + return errors + + +def _validate_tokens(value: Any) -> list[str]: + if not isinstance(value, dict): + return ["tokens must be an object"] + + allowed = {"model", "max_budget", "warn_threshold"} + errors = [f"Unknown key: tokens.{key}" for key in sorted(set(value) - allowed)] + + if "model" in value and not isinstance(value["model"], str): + errors.append("tokens.model must be a string") + if "max_budget" in value: + if not _is_integer(value["max_budget"]): + errors.append("tokens.max_budget must be an integer") + elif value["max_budget"] < 0: + errors.append("tokens.max_budget must be >= 0") + if "warn_threshold" in value: + if not _is_integer(value["warn_threshold"]): + errors.append("tokens.warn_threshold must be an integer") + elif value["warn_threshold"] < 0: + errors.append("tokens.warn_threshold must be >= 0") + + return errors + + +def _validate_cache(value: Any) -> list[str]: + if not isinstance(value, dict): + return ["cache must be an object"] + + allowed = {"enabled", "max_entries", "ttl_seconds"} + errors = [f"Unknown key: cache.{key}" for key in sorted(set(value) - allowed)] + + if "enabled" in value and not isinstance(value["enabled"], bool): + errors.append("cache.enabled must be a boolean") + if "max_entries" in value: + if not _is_integer(value["max_entries"]): + errors.append("cache.max_entries must be an integer") + elif value["max_entries"] < 1: + errors.append("cache.max_entries must be >= 1") + if "ttl_seconds" in value: + if not _is_integer(value["ttl_seconds"]): + errors.append("cache.ttl_seconds must be an integer") + elif value["ttl_seconds"] < 0: + errors.append("cache.ttl_seconds must be >= 0") + + return errors + + +def _is_integer(value: Any) -> bool: + return isinstance(value, int) and not isinstance(value, bool) + + +def _is_number(value: Any) -> bool: + return (isinstance(value, (int, float)) and not isinstance(value, bool)) diff --git 
"""Tests for chunk defaults validation and config CLI commands."""

from typer.testing import CliRunner
import yaml

from shelfai.cli.main import app
from shelfai.core.config_schema import (
    generate_config_template,
    get_default_config,
    validate_config,
    validate_config_file,
)


# Shared Typer CLI runner for the `config` sub-commands.
runner = CliRunner()


def test_valid_minimal_config():
    """An empty config (all defaults inherited) validates cleanly."""
    is_valid, errors = validate_config({})

    assert is_valid
    assert errors == []


def test_valid_full_config():
    """The built-in defaults themselves satisfy the schema."""
    config = get_default_config()

    is_valid, errors = validate_config(config)

    assert is_valid
    assert errors == []


def test_invalid_min_lines_zero():
    """chunking.min_lines has a minimum of 1."""
    config = {"chunking": {"min_lines": 0}}

    is_valid, errors = validate_config(config)

    assert not is_valid
    assert any("chunking.min_lines" in error for error in errors)


def test_invalid_unknown_section():
    """Unknown top-level sections are rejected."""
    config = {"mystery": {"enabled": True}}

    is_valid, errors = validate_config(config)

    assert not is_valid
    assert any("Unknown top-level key: mystery" in error for error in errors)


def test_invalid_unknown_nested_key():
    """Unknown keys nested inside a known section are rejected."""
    config = {"chunking": {"min_lines": 10, "extra": True}}

    is_valid, errors = validate_config(config)

    assert not is_valid
    assert any("Unknown key: chunking.extra" in error for error in errors)


def test_invalid_strategy_value():
    """chunking.strategy must come from the fixed enum."""
    config = {"chunking": {"strategy": "magic"}}

    is_valid, errors = validate_config(config)

    assert not is_valid
    assert any("chunking.strategy" in error for error in errors)


def test_invalid_confidence_range():
    """Confidence values must fall within [0, 1]."""
    config = {"classification": {"confidence_threshold": 1.5}}

    is_valid, errors = validate_config(config)

    assert not is_valid
    assert any("classification.confidence_threshold" in error for error in errors)


def test_validate_file(tmp_path):
    """validate_config_file accepts a well-formed partial YAML config."""
    config_path = tmp_path / ".chunk-defaults.yaml"
    config_path.write_text(
        yaml.safe_dump(
            {
                "chunking": {"min_lines": 75, "strategy": "heading"},
                "classification": {"use_transformer": True},
            },
            sort_keys=False,
        ),
        encoding="utf-8",
    )

    is_valid, errors = validate_config_file(str(config_path))

    assert is_valid
    assert errors == []


def test_generate_template():
    """The generated template parses as YAML and validates against the schema."""
    template = generate_config_template()
    parsed = yaml.safe_load(template)

    assert isinstance(parsed, dict)
    is_valid, errors = validate_config(parsed)
    assert is_valid
    assert errors == []


def test_default_config():
    """Defaults are schema-valid and include the expected always_load set."""
    defaults = get_default_config()

    is_valid, errors = validate_config(defaults)

    assert is_valid
    assert errors == []
    assert defaults["always_load"] == ["soul", "rules", "read-order"]


def test_error_messages_helpful():
    """Error messages name the offending path and the expected type."""
    config = {
        "chunking": {"min_lines": "many"},
        "classification": {"custom_keywords": {"rules": [1]}},
    }

    is_valid, errors = validate_config(config)

    assert not is_valid
    assert any("chunking.min_lines" in error and "integer" in error for error in errors)
    assert any(
        "classification.custom_keywords.rules[0]" in error and "string" in error
        for error in errors
    )


def test_config_init_and_show(tmp_path):
    """`config init` writes a template and `config show` renders the defaults."""
    init_result = runner.invoke(app, ["config", "init", str(tmp_path)])
    assert init_result.exit_code == 0

    config_path = tmp_path / ".chunk-defaults.yaml"
    assert config_path.exists()

    show_result = runner.invoke(app, ["config", "show", str(tmp_path)])
    assert show_result.exit_code == 0
    assert "min_lines: 150" in show_result.output
    assert "always_load:" in show_result.output


def test_show_effective_config(tmp_path):
    """`config show` reflects repo-level overrides merged onto the defaults."""
    config_path = tmp_path / ".chunk-defaults.yaml"
    config_path.write_text(
        "chunking:\n"
        "  min_lines: 42\n"
        "classification:\n"
        "  use_transformer: true\n",
        encoding="utf-8",
    )

    result = runner.invoke(app, ["config", "show", str(tmp_path)])

    assert result.exit_code == 0
    assert "min_lines: 42" in result.output
    assert "use_transformer: true" in result.output.lower()


def test_config_validate_cli(tmp_path):
    """`config validate` exits non-zero and names the failing key."""
    config_path = tmp_path / ".chunk-defaults.yaml"
    config_path.write_text("chunking:\n  min_lines: 0\n", encoding="utf-8")

    result = runner.invoke(app, ["config", "validate", str(tmp_path)])

    assert result.exit_code == 1
    assert "Invalid config" in result.output
    assert "chunking.min_lines" in result.output