Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/hotmem/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

import math
import re
import sqlite3
import struct
from pathlib import Path
Expand Down Expand Up @@ -47,8 +48,32 @@
);
CREATE INDEX IF NOT EXISTS idx_memories_identifier ON memories(identifier);
CREATE INDEX IF NOT EXISTS idx_memories_content_hash ON memories(content_hash);

CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
fact_text,
content='memories',
content_rowid='rowid',
tokenize='porter unicode61'
);

CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, fact_text) VALUES (new.rowid, new.fact_text);
END;

CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, fact_text)
VALUES('delete', old.rowid, old.fact_text);
END;

CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, fact_text)
VALUES('delete', old.rowid, old.fact_text);
INSERT INTO memories_fts(rowid, fact_text) VALUES (new.rowid, new.fact_text);
END;
"""

# Matches each run of word characters; _fts_query uses it to pull the
# individual search terms out of free text.
_FTS_TOKEN_RE = re.compile(r"[\w]+")


def _cosine_similarity(blob_a: bytes | None, blob_b: bytes | None) -> float | None:
"""SQLite UDF: cosine similarity between two packed float32 blobs."""
Expand All @@ -70,6 +95,12 @@ def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool:
return any(row["name"] == column for row in rows)


def _fts_query(query: str) -> str:
    """Build an FTS5 MATCH expression of prefix terms from free text.

    The text is lower-cased, split into word-character runs, and every run
    gets a trailing ``*``. The terms are joined with single spaces; text with
    no word characters produces an empty string.
    """
    lowered = query.lower()
    prefix_terms = [found.group(0) + "*" for found in _FTS_TOKEN_RE.finditer(lowered)]
    return " ".join(prefix_terms)


class MemoryDB:
"""SQLite-backed memory store with cosine similarity UDF."""

Expand All @@ -79,9 +110,11 @@ def __init__(self, db_path: str | Path) -> None:
self._conn.row_factory = sqlite3.Row
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA synchronous=NORMAL")
self._conn.execute("PRAGMA recursive_triggers=ON")
self._conn.create_function("cosine_sim", 2, _cosine_similarity)
self._conn.executescript(_SCHEMA)
self._migrate()
self._conn.execute("INSERT INTO memories_fts(memories_fts) VALUES('rebuild')")
_trace.info("init", "database opened", detail={"path": self.db_path})

def _migrate(self) -> None:
Expand Down Expand Up @@ -147,6 +180,27 @@ def search_with_cosine(self, query_embedding: bytes) -> list[dict[str, Any]]:
).fetchall()
return [dict(r) for r in rows]

def fts_search(self, query: str) -> list[dict[str, Any]]:
    """Look up memories by full-text match and attach their raw BM25 score.

    The free-text query is converted with ``_fts_query``; if nothing usable
    remains, no SQL is executed and an empty list is returned. Rows are kept
    only when ``ttl_seconds`` is NULL or the age computed from ``created_at``
    is below ``ttl_seconds``. Results are sorted by ``bm25_score`` ascending.
    """
    match_expr = _fts_query(query)
    if match_expr:
        sql = """SELECT m.id, m.identifier, m.fact_text, m.importance, m.metadata_json,
m.source, bm25(memories_fts) AS bm25_score
FROM memories_fts
JOIN memories AS m ON m.rowid = memories_fts.rowid
WHERE memories_fts MATCH ?
AND (
m.ttl_seconds IS NULL
OR (strftime('%s', 'now') - strftime('%s', m.created_at)) < m.ttl_seconds
)
ORDER BY bm25_score ASC"""
        cursor = self._conn.execute(sql, (match_expr,))
        # Plain dicts so callers do not depend on the connection's row type.
        return [dict(record) for record in cursor.fetchall()]
    return []

def count(self) -> int:
"""Return total number of stored memories."""
row = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()
Expand Down
31 changes: 17 additions & 14 deletions src/hotmem/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Purpose:
Given a query, embed it, retrieve candidates from the DB, apply hybrid scoring
(cosine + keyword overlap + importance), and return LLM-ready message objects.
(cosine + FTS5 BM25 + importance), and return LLM-ready message objects.

Interface:
search_memories(db, query, top_k, max_chars?) -> list[MessageObject]
Expand All @@ -23,18 +23,22 @@

# Scoring weights: multipliers applied to each score component when
# search_memories combines them into final_score.
W_COSINE = 0.6
W_KEYWORD = 0.2
W_FTS = 0.2
W_IMPORTANCE = 0.2


def _keyword_overlap(query: str, text: str) -> float:
"""Compute Jaccard-like keyword overlap between query and text."""
q_words = set(query.lower().split())
t_words = set(text.lower().split())
if not q_words:
return 0.0
overlap = q_words & t_words
return len(overlap) / len(q_words)
def _normalize_bm25(rows: list[dict[str, Any]]) -> dict[str, float]:
"""Convert raw BM25 scores into 0..1 scores where 1.0 is best."""
if not rows:
return {}

scores = [float(row["bm25_score"]) for row in rows]
best = min(scores)
worst = max(scores)
if best == worst:
return {row["id"]: 1.0 for row in rows}

return {row["id"]: 1.0 - ((float(row["bm25_score"]) - best) / (worst - best)) for row in rows}


def search_memories(
Expand All @@ -55,17 +59,16 @@ def search_memories(

# Get all candidates with cosine scores from DB
candidates = db.search_with_cosine(query_blob)
fts_scores = _normalize_bm25(db.fts_search(query))

# Apply hybrid scoring
scored = []
for row in candidates:
cosine_score = row.get("cosine_score") or 0.0
keyword_score = _keyword_overlap(query, row["fact_text"])
fts_score = fts_scores.get(row["id"], 0.0)
importance = row.get("importance", 0.5)

final_score = (
W_COSINE * cosine_score + W_KEYWORD * keyword_score + W_IMPORTANCE * importance
)
final_score = W_COSINE * cosine_score + W_FTS * fts_score + W_IMPORTANCE * importance

scored.append({**row, "final_score": final_score})

Expand Down
28 changes: 28 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,34 @@ def test_cosine_search(tmp_db: MemoryDB):
assert results[0]["fact_text"] == "the quick brown fox"


def test_fts_search(tmp_db: MemoryDB):
    """A stored fact is found by a two-term query and carries a bm25_score."""
    fact = "invoice validation required"
    packed = pack_embedding(embed_text(fact))
    tmp_db.insert(
        id="fts1",
        identifier="test",
        fact_text=fact,
        embedding=packed,
    )

    matches = tmp_db.fts_search("invoice valid")

    assert len(matches) == 1
    only_match = matches[0]
    assert only_match["id"] == "fts1"
    assert "bm25_score" in only_match


def test_fts_search_updates_on_replace(tmp_db: MemoryDB):
    """After a second insert under the same id, only the new text is searchable."""
    packed = pack_embedding(embed_text("old invoice text"))
    for text in ("old invoice text", "new contract text"):
        tmp_db.insert(id="same", identifier="test", fact_text=text, embedding=packed)

    assert tmp_db.fts_search("old invoice") == []
    hit_ids = [hit["id"] for hit in tmp_db.fts_search("new contract")]
    assert hit_ids == ["same"]


def test_all_rows(tmp_db: MemoryDB):
vec = embed_text("fact")
blob = pack_embedding(vec)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def test_search_ranking_uses_importance(tmp_db: MemoryDB):
assert results[0]["memory_id"] == "high"


def test_search_ranking_uses_fts(tmp_db: MemoryDB):
    """The fact containing the query words ranks first despite lower importance."""
    facts = (
        ("exact", "duplicate invoice risk for vendor x", 0.1),
        ("other", "payment terms are net 30", 1.0),
    )
    for memory_id, text, weight in facts:
        _add_fact(tmp_db, memory_id, text, importance=weight)

    ranked = search_memories(tmp_db, "duplicate invoice", top_k=2)

    top_hit = ranked[0]
    assert top_hit["memory_id"] == "exact"


def test_search_empty_db(tmp_db: MemoryDB):
results = search_memories(tmp_db, "anything", top_k=5)
assert results == []
Expand Down
Loading