diff --git a/src/hotmem/db.py b/src/hotmem/db.py index 2f24069..b27600f 100644 --- a/src/hotmem/db.py +++ b/src/hotmem/db.py @@ -20,6 +20,7 @@ from __future__ import annotations import math +import re import sqlite3 import struct from pathlib import Path @@ -47,8 +48,32 @@ ); CREATE INDEX IF NOT EXISTS idx_memories_identifier ON memories(identifier); CREATE INDEX IF NOT EXISTS idx_memories_content_hash ON memories(content_hash); + +CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5( + fact_text, + content='memories', + content_rowid='rowid', + tokenize='porter unicode61' +); + +CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN + INSERT INTO memories_fts(rowid, fact_text) VALUES (new.rowid, new.fact_text); +END; + +CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, fact_text) + VALUES('delete', old.rowid, old.fact_text); +END; + +CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, fact_text) + VALUES('delete', old.rowid, old.fact_text); + INSERT INTO memories_fts(rowid, fact_text) VALUES (new.rowid, new.fact_text); +END; """ +_FTS_TOKEN_RE = re.compile(r"[\w]+") + def _cosine_similarity(blob_a: bytes | None, blob_b: bytes | None) -> float | None: """SQLite UDF: cosine similarity between two packed float32 blobs.""" @@ -70,6 +95,12 @@ def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool: return any(row["name"] == column for row in rows) +def _fts_query(query: str) -> str: + """Convert free text into a safe FTS5 prefix query.""" + terms = _FTS_TOKEN_RE.findall(query.lower()) + return " ".join(f"{term}*" for term in terms) + + class MemoryDB: """SQLite-backed memory store with cosine similarity UDF.""" @@ -79,9 +110,11 @@ def __init__(self, db_path: str | Path) -> None: self._conn.row_factory = sqlite3.Row self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA recursive_triggers=ON") self._conn.create_function("cosine_sim", 2, _cosine_similarity) self._conn.executescript(_SCHEMA) self._migrate() + self._conn.execute("INSERT INTO memories_fts(memories_fts) VALUES('rebuild')") _trace.info("init", "database opened", detail={"path": self.db_path}) def _migrate(self) -> None: @@ -147,6 +180,27 @@ def search_with_cosine(self, query_embedding: bytes) -> list[dict[str, Any]]: ).fetchall() return [dict(r) for r in rows] + def fts_search(self, query: str) -> list[dict[str, Any]]: + """Return full-text matches with raw BM25 scores.""" + fts_query = _fts_query(query) + if not fts_query: + return [] + + rows = self._conn.execute( + """SELECT m.id, m.identifier, m.fact_text, m.importance, m.metadata_json, + m.source, bm25(memories_fts) AS bm25_score + FROM memories_fts + JOIN memories AS m ON m.rowid = memories_fts.rowid + WHERE memories_fts MATCH ? + AND ( + m.ttl_seconds IS NULL + OR (strftime('%s', 'now') - strftime('%s', m.created_at)) < m.ttl_seconds + ) + ORDER BY bm25_score ASC""", + (fts_query,), + ).fetchall() + return [dict(r) for r in rows] + def count(self) -> int: """Return total number of stored memories.""" row = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone() diff --git a/src/hotmem/search.py b/src/hotmem/search.py index fcd692b..87425ae 100644 --- a/src/hotmem/search.py +++ b/src/hotmem/search.py @@ -2,7 +2,7 @@ Purpose: Given a query, embed it, retrieve candidates from the DB, apply hybrid scoring - (cosine + keyword overlap + importance), and return LLM-ready message objects. + (cosine + FTS5 BM25 + importance), and return LLM-ready message objects. Interface: search_memories(db, query, top_k, max_chars?) -> list[MessageObject] @@ -23,18 +23,22 @@ # Scoring weights W_COSINE = 0.6 -W_KEYWORD = 0.2 +W_FTS = 0.2 W_IMPORTANCE = 0.2 -def _keyword_overlap(query: str, text: str) -> float: - """Compute Jaccard-like keyword overlap between query and text.""" - q_words = set(query.lower().split()) - t_words = set(text.lower().split()) - if not q_words: - return 0.0 - overlap = q_words & t_words - return len(overlap) / len(q_words) +def _normalize_bm25(rows: list[dict[str, Any]]) -> dict[str, float]: + """Convert raw BM25 scores into 0..1 scores where 1.0 is best.""" + if not rows: + return {} + + scores = [float(row["bm25_score"]) for row in rows] + best = min(scores) + worst = max(scores) + if best == worst: + return {row["id"]: 1.0 for row in rows} + + return {row["id"]: 1.0 - ((float(row["bm25_score"]) - best) / (worst - best)) for row in rows} def search_memories( @@ -55,17 +59,16 @@ def search_memories( # Get all candidates with cosine scores from DB candidates = db.search_with_cosine(query_blob) + fts_scores = _normalize_bm25(db.fts_search(query)) # Apply hybrid scoring scored = [] for row in candidates: cosine_score = row.get("cosine_score") or 0.0 - keyword_score = _keyword_overlap(query, row["fact_text"]) + fts_score = fts_scores.get(row["id"], 0.0) importance = row.get("importance", 0.5) - final_score = ( - W_COSINE * cosine_score + W_KEYWORD * keyword_score + W_IMPORTANCE * importance - ) + final_score = W_COSINE * cosine_score + W_FTS * fts_score + W_IMPORTANCE * importance scored.append({**row, "final_score": final_score}) diff --git a/tests/test_db.py b/tests/test_db.py index 240b50a..f936e6e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -46,6 +46,34 @@ def test_cosine_search(tmp_db: MemoryDB): assert results[0]["fact_text"] == "the quick brown fox" +def test_fts_search(tmp_db: MemoryDB): + vec = embed_text("invoice validation required") + blob = pack_embedding(vec) + tmp_db.insert( + id="fts1", + identifier="test", + fact_text="invoice validation required", + embedding=blob, + ) + + results = tmp_db.fts_search("invoice valid") + + assert len(results) == 1 + assert results[0]["id"] == "fts1" + assert "bm25_score" in results[0] + + +def test_fts_search_updates_on_replace(tmp_db: MemoryDB): + vec = embed_text("old invoice text") + blob = pack_embedding(vec) + tmp_db.insert(id="same", identifier="test", fact_text="old invoice text", embedding=blob) + tmp_db.insert(id="same", identifier="test", fact_text="new contract text", embedding=blob) + + assert tmp_db.fts_search("old invoice") == [] + results = tmp_db.fts_search("new contract") + assert [r["id"] for r in results] == ["same"] + + def test_all_rows(tmp_db: MemoryDB): vec = embed_text("fact") blob = pack_embedding(vec) diff --git a/tests/test_search.py b/tests/test_search.py index a16bc43..5fd0958 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -47,6 +47,15 @@ def test_search_ranking_uses_importance(tmp_db: MemoryDB): assert results[0]["memory_id"] == "high" +def test_search_ranking_uses_fts(tmp_db: MemoryDB): + _add_fact(tmp_db, "exact", "duplicate invoice risk for vendor x", importance=0.1) + _add_fact(tmp_db, "other", "payment terms are net 30", importance=1.0) + + results = search_memories(tmp_db, "duplicate invoice", top_k=2) + + assert results[0]["memory_id"] == "exact" + + def test_search_empty_db(tmp_db: MemoryDB): results = search_memories(tmp_db, "anything", top_k=5) assert results == []