/**
 * Recall the stored memories most relevant to `query`.
 *
 * Steps, in order:
 *  1. Embed `query` via `this.embed`.
 *  2. Run a KNN `FT.SEARCH` against `${name}:mem:idx`, filtered by the given
 *     scope (`threadId` / `agentId` / `namespace`) and `tags`. The search asks
 *     for `k * RECALL_OVERFETCH` candidates; the list is cut back to `k` only
 *     after filtering and re-ranking.
 *  3. Drop candidates whose distance field is missing, blank, non-numeric, or
 *     greater than the threshold, and candidates whose composite score is not
 *     finite (e.g. malformed `importance` / `created_at`).
 *  4. Sort the survivors by composite score, highest first, and return the
 *     first `k`.
 *
 * `options.reinforce` is not consulted by this method.
 *
 * @param query - Text to embed and search for.
 * @param options - Scope and tag filters plus per-call overrides for `k`,
 *   the distance `threshold`, and the scoring `weights`.
 * @returns Up to `k` hits ordered by descending `score`. Each hit's
 *   `similarity` carries the raw KNN distance (lower is closer).
 * @throws RangeError if `options.k` is negative or not a finite number.
 */
async recall(query: string, options: RecallOptions = {}): Promise<MemoryHit[]> {
  const requestedK = options.k ?? DEFAULT_RECALL_K;
  // An unchecked k would be interpolated straight into the query ("KNN 0",
  // "KNN -32", "KNN 9.2") and a negative k would make slice(0, k) trim from
  // the end instead of capping the result.
  if (!Number.isFinite(requestedK) || requestedK < 0) {
    throw new RangeError(`recall: k must be a finite, non-negative number (got ${requestedK})`);
  }
  const k = Math.floor(requestedK);
  if (k === 0) {
    // Nothing was asked for: skip the embedding call and the search entirely.
    return [];
  }

  const threshold = options.threshold ?? this.defaultThreshold;
  const weights = options.weights ?? this.weights;
  const fetchK = k * RECALL_OVERFETCH;
  const tags = options.tags ?? [];
  const scope = {
    threadId: options.threadId,
    agentId: options.agentId,
    namespace: options.namespace,
  };

  const vector = await this.embed(query);
  const queryString = buildRecallQuery(fetchK, scope, tags);
  const raw = await this.client.call(
    'FT.SEARCH',
    `${this.name}:mem:idx`,
    queryString,
    'PARAMS',
    '2',
    'vec',
    encodeFloat32(vector),
    'LIMIT',
    '0',
    String(fetchK),
    'DIALECT',
    '2',
  );

  // One timestamp for the whole batch so every candidate is aged against the same instant.
  const now = Date.now();
  const hits: MemoryHit[] = [];
  for (const hit of parseFtSearchResponse(raw)) {
    const rawScore = hit.fields[SCORE_FIELD];
    // Number('') and Number('  ') are 0, which would read as a perfect match,
    // so blank values are rejected before conversion.
    if (rawScore === undefined || rawScore.trim() === '') {
      continue;
    }
    const distance = Number(rawScore);
    if (!Number.isFinite(distance) || distance > threshold) {
      continue;
    }
    const item = parseMemoryItem(this.name, hit);
    const ageSeconds = (now - item.createdAt) / 1000;
    const score = compositeScore({
      similarity: similarityFromDistance(distance),
      ageSeconds,
      importance: item.importance,
      weights,
      halfLifeSeconds: this.halfLifeSeconds,
    });
    // A NaN score would make the comparator below inconsistent; drop the hit instead.
    if (!Number.isFinite(score)) {
      continue;
    }
    // `similarity` intentionally carries the raw distance (see MemoryHit).
    hits.push({ item, similarity: distance, score });
  }

  hits.sort((a, b) => b.score - a.score);
  return hits.slice(0, k);
}
/** One document in a fake FT.SEARCH reply: its key plus its field/value pairs. */
interface Row {
  key: string;
  fields: Record<string, string>;
}

/**
 * Build a fake FT.SEARCH reply from `rows`.
 *
 * The result is laid out as `[count, key1, [field, value, ...], key2, [...], ...]`,
 * with the count given as a string.
 */
function searchReply(rows: Row[]): unknown[] {
  const out: unknown[] = [String(rows.length)];
  for (const row of rows) {
    // Each document contributes its key followed by one flat field/value array.
    out.push(row.key, Object.entries(row.fields).flat());
  }
  return out;
}

// Captured once at module load so every fixture row shares the same timestamps.
const now = Date.now();

/**
 * Default hash fields for a stored memory, with `over` applied on top.
 * Keys in `over` replace the defaults; extra keys (e.g. `__score`) are added.
 */
function baseFields(over: Record<string, string>): Record<string, string> {
  return {
    content: 'c',
    importance: '0.5',
    tags: '',
    created_at: String(now),
    last_accessed_at: String(now),
    access_count: '0',
    ...over,
  };
}
reply : 'OK')); + const store = new MemoryStore({ client, name: 'mem', embedFn }); + + const hits = await store.recall('what does the user prefer', { + k: 2, + threshold: 1, + threadId: 't1', + tags: ['x'], + }); + + expect(embedFn).toHaveBeenCalledWith('what does the user prefer'); + const search = client.call.mock.calls.find((args) => args[0] === 'FT.SEARCH'); + expect(search?.[1]).toBe('mem:mem:idx'); + // internal k widened to k*4 = 8 + expect(search?.[2]).toBe('(@threadId:{t1} @tags:{x})=>[KNN 8 @vector $vec AS __score]'); + expect(search).toContain('8'); + + expect(hits).toHaveLength(2); + expect(hits[0].item.id).toBe('a'); + expect(hits[0].item.content).toBe('closer'); + expect(hits[0].similarity).toBe(0.1); + expect(hits[0].score).toBeGreaterThan(hits[1].score); + }); + + it('drops candidates beyond the distance threshold', async () => { + const reply = searchReply([ + { key: 'mem:mem:a', fields: baseFields({ __score: '0.1' }) }, + { key: 'mem:mem:b', fields: baseFields({ __score: '0.9' }) }, + ]); + const client = mockClient((command) => (command === 'FT.SEARCH' ? reply : 'OK')); + const store = new MemoryStore({ client, name: 'mem', embedFn: fakeEmbed(8) }); + + const hits = await store.recall('q', { k: 5, threshold: 0.3 }); + + expect(hits.map((h) => h.item.id)).toEqual(['a']); + }); + + it('drops candidates whose distance score is missing or non-numeric', async () => { + const reply = searchReply([ + { key: 'mem:mem:a', fields: baseFields({ __score: '0.1' }) }, + { key: 'mem:mem:b', fields: baseFields({}) }, + ]); + const client = mockClient((command) => (command === 'FT.SEARCH' ? 
reply : 'OK')); + const store = new MemoryStore({ client, name: 'mem', embedFn: fakeEmbed(8) }); + + const hits = await store.recall('q', { k: 5, threshold: 1 }); + + expect(hits.map((h) => h.item.id)).toEqual(['a']); + }); + + it('drops a candidate whose distance score is empty (not treated as 0)', async () => { + const reply = searchReply([ + { key: 'mem:mem:a', fields: baseFields({ __score: '0.1' }) }, + { key: 'mem:mem:b', fields: baseFields({ __score: ' ' }) }, + ]); + const client = mockClient((command) => (command === 'FT.SEARCH' ? reply : 'OK')); + const store = new MemoryStore({ client, name: 'mem', embedFn: fakeEmbed(8) }); + + const hits = await store.recall('q', { k: 5, threshold: 1 }); + + expect(hits.map((h) => h.item.id)).toEqual(['a']); + }); + + it('drops a candidate whose composite score is NaN (malformed importance)', async () => { + const reply = searchReply([ + { key: 'mem:mem:a', fields: baseFields({ __score: '0.1' }) }, + { key: 'mem:mem:b', fields: baseFields({ __score: '0.1', importance: 'not-a-number' }) }, + ]); + const client = mockClient((command) => (command === 'FT.SEARCH' ? 
reply : 'OK')); + const store = new MemoryStore({ client, name: 'mem', embedFn: fakeEmbed(8) }); + + const hits = await store.recall('q', { k: 5, threshold: 1 }); + + expect(hits.map((h) => h.item.id)).toEqual(['a']); + }); +}); diff --git a/packages/agent-memory/src/__tests__/buildRecallQuery.test.ts b/packages/agent-memory/src/__tests__/buildRecallQuery.test.ts new file mode 100644 index 00000000..2d771f6e --- /dev/null +++ b/packages/agent-memory/src/__tests__/buildRecallQuery.test.ts @@ -0,0 +1,20 @@ +import { describe, it, expect } from 'vitest'; +import { buildRecallQuery } from '../buildRecallQuery'; + +describe('buildRecallQuery', () => { + it('builds a bare KNN query when there are no filters', () => { + expect(buildRecallQuery(32, {}, [])).toBe('*=>[KNN 32 @vector $vec AS __score]'); + }); + + it('filters by scope and tags with AND semantics', () => { + expect(buildRecallQuery(8, { threadId: 't1', namespace: 'user:1' }, ['pref'])).toBe( + '(@threadId:{t1} @namespace:{user\\:1} @tags:{pref})=>[KNN 8 @vector $vec AS __score]', + ); + }); + + it('escapes scope and tag values', () => { + expect(buildRecallQuery(8, { agentId: 'a:b' }, ['x y'])).toBe( + '(@agentId:{a\\:b} @tags:{x\\ y})=>[KNN 8 @vector $vec AS __score]', + ); + }); +}); diff --git a/packages/agent-memory/src/__tests__/compositeScore.test.ts b/packages/agent-memory/src/__tests__/compositeScore.test.ts new file mode 100644 index 00000000..68b2e631 --- /dev/null +++ b/packages/agent-memory/src/__tests__/compositeScore.test.ts @@ -0,0 +1,72 @@ +import { describe, it, expect } from 'vitest'; +import { compositeScore } from '../compositeScore'; + +const W = { similarity: 0.6, recency: 0.25, importance: 0.15 }; +const HALF = 604800; // 7 days + +describe('compositeScore', () => { + it('decays recency to ~0.5 at one half-life', () => { + const score = compositeScore({ + similarity: 0, + importance: 0, + ageSeconds: HALF, + weights: { similarity: 0, recency: 1, importance: 0 }, + halfLifeSeconds: HALF, 
+ }); + expect(score).toBeCloseTo(0.5, 5); + }); + + it('combines weighted similarity, recency, and importance', () => { + const score = compositeScore({ + similarity: 1, + importance: 1, + ageSeconds: 0, + weights: W, + halfLifeSeconds: HALF, + }); + expect(score).toBeCloseTo(1, 5); + }); + + it('ranks an identical recent match above a distant one', () => { + const identical = compositeScore({ + similarity: 1, + importance: 0.5, + ageSeconds: 0, + weights: W, + halfLifeSeconds: HALF, + }); + const distant = compositeScore({ + similarity: 0.2, + importance: 0.5, + ageSeconds: 0, + weights: W, + halfLifeSeconds: HALF, + }); + expect(identical).toBeGreaterThan(distant); + }); + + it('lets recency promote a recent-but-weaker item over an old-but-closer one', () => { + const recentWeaker = compositeScore({ + similarity: 0.6, + importance: 0.5, + ageSeconds: 0, + weights: W, + halfLifeSeconds: HALF, + }); + const oldCloser = compositeScore({ + similarity: 0.8, + importance: 0.5, + ageSeconds: HALF * 5, + weights: W, + halfLifeSeconds: HALF, + }); + expect(recentWeaker).toBeGreaterThan(oldCloser); + }); + + it('breaks ties by importance', () => { + const base = { similarity: 0.5, ageSeconds: 0, weights: W, halfLifeSeconds: HALF }; + const high = compositeScore({ ...base, importance: 0.9 }); + const low = compositeScore({ ...base, importance: 0.1 }); + expect(high).toBeGreaterThan(low); + }); +}); diff --git a/packages/agent-memory/src/buildRecallQuery.ts b/packages/agent-memory/src/buildRecallQuery.ts new file mode 100644 index 00000000..2702570b --- /dev/null +++ b/packages/agent-memory/src/buildRecallQuery.ts @@ -0,0 +1,23 @@ +import { escapeTag } from '@betterdb/valkey-search-kit'; +import type { MemoryScope } from './types'; + +export const SCORE_FIELD = '__score'; +export const VECTOR_FIELD = 'vector'; + +export function buildRecallQuery(k: number, scope: MemoryScope, tags: string[]): string { + const clauses: string[] = []; + if (scope.threadId !== undefined) { + 
/** Relative weights applied to each component of the composite recall score. */
export interface RecallWeights {
  similarity: number;
  recency: number;
  importance: number;
}

/** Inputs to `compositeScore`. */
export interface CompositeScoreParams {
  /** 0..1, mapped from cosine distance. */
  similarity: number;
  /** Age of the memory in seconds; negative values are treated as 0. */
  ageSeconds: number;
  /** 0..1 */
  importance: number;
  weights: RecallWeights;
  /** Age at which recency has decayed to 0.5. */
  halfLifeSeconds: number;
}

/**
 * Half-life decay in 0..1: 1 at age 0, 0.5 at one half-life.
 * NaN inputs yield NaN so callers can detect and discard malformed records.
 */
function recencyDecay(ageSeconds: number, halfLifeSeconds: number): number {
  if (Number.isNaN(ageSeconds) || Number.isNaN(halfLifeSeconds)) {
    return NaN;
  }
  // A timestamp ahead of the local clock gives a negative age; unclamped, the
  // exponent turns positive and recency exceeds 1 without bound.
  const age = Math.max(0, ageSeconds);
  if (halfLifeSeconds <= 0) {
    // Limit of the decay as the half-life shrinks to zero. This also avoids
    // the 0/0 = NaN case and the growth a negative half-life would produce.
    return age === 0 ? 1 : 0;
  }
  return Math.exp((-Math.LN2 * age) / halfLifeSeconds);
}

/**
 * Weighted blend of semantic similarity, recency, and importance.
 * Recency is a true half-life decay: 0.5 at one halfLifeSeconds.
 *
 * @param params - Component values, their weights, and the recency half-life.
 * @returns `weights.similarity * similarity + weights.recency * recency +
 *   weights.importance * importance`. The result is NaN when any contributing
 *   value is NaN.
 */
export function compositeScore(params: CompositeScoreParams): number {
  const { similarity, importance, weights } = params;
  const recency = recencyDecay(params.ageSeconds, params.halfLifeSeconds);
  return (
    weights.similarity * similarity +
    weights.recency * recency +
    weights.importance * importance
  );
}

/**
 * Map cosine distance (0..2, lower = closer) to a 0..1 similarity score.
 *
 * Distances outside 0..2 are clamped so the result never leaves 0..1.
 * NaN is passed through unchanged.
 */
export function similarityFromDistance(distance: number): number {
  return Math.min(1, Math.max(0, 1 - distance / 2));
}
'0', 10), + }; + + if (fields.source !== undefined) { + item.source = fields.source; + } + if (fields.threadId !== undefined) { + item.threadId = fields.threadId; + } + if (fields.agentId !== undefined) { + item.agentId = fields.agentId; + } + if (fields.namespace !== undefined) { + item.namespace = fields.namespace; + } + + return item; +} diff --git a/packages/agent-memory/src/types.ts b/packages/agent-memory/src/types.ts index a54f6251..2a697e4f 100644 --- a/packages/agent-memory/src/types.ts +++ b/packages/agent-memory/src/types.ts @@ -1,3 +1,5 @@ +import type { RecallWeights } from './compositeScore'; + export type EmbedFn = (text: string) => Promise; export interface MemoryStoreClient { @@ -15,3 +17,35 @@ export interface RememberOptions extends MemoryScope { tags?: string[]; source?: string; } + +export interface MemoryItem extends MemoryScope { + id: string; + content: string; + importance: number; + tags: string[]; + source?: string; + createdAt: number; + lastAccessedAt: number; + accessCount: number; +} + +export interface RecallOptions extends MemoryScope { + k?: number; + threshold?: number; + tags?: string[]; + weights?: RecallWeights; + reinforce?: boolean; +} + +export interface MemoryHit { + item: MemoryItem; + /** + * Raw KNN vector **distance** (cosine), not a similarity: lower means closer + * (a perfect match approaches 0). Despite the field name, do not assume + * higher is better — sort ascending if ranking by this alone. The composite + * `score` (higher is better) is the field to rank recall results by. + */ + similarity: number; + /** Composite recall score (similarity + recency + importance); higher is better. */ + score: number; +}