Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions packages/retrieval/src/__tests__/ft-create.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,21 @@ describe('keyPrefix', () => {
expect(() => keyPrefix('')).toThrow(/Index name must not be empty/);
});
});

describe('buildFtCreateArgs reserved field names', () => {
it('rejects a schema field named __score', () => {
const schema: RetrievalSchema = {
fields: { __score: { type: 'tag' } },
vector: { metric: 'cosine', algorithm: 'hnsw', dims: 4 },
};
expect(() => buildFtCreateArgs('docs', schema)).toThrow(/reserved/i);
});

it('rejects a schema field named __text', () => {
const schema: RetrievalSchema = {
fields: { __text: { type: 'text' } },
vector: { metric: 'cosine', algorithm: 'hnsw', dims: 4 },
};
expect(() => buildFtCreateArgs('docs', schema)).toThrow(/reserved/i);
});
});
48 changes: 48 additions & 0 deletions packages/retrieval/src/__tests__/ft-search.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { describe, it, expect } from 'vitest';
import { buildFtSearchQuery } from '../ft-search';
import type { RetrievalSchema } from '../schema';

const schema: RetrievalSchema = {
fields: {
source: { type: 'tag' },
title: { type: 'text' },
updated: { type: 'numeric' },
},
vector: { metric: 'cosine', algorithm: 'hnsw', dims: 4 },
};

describe('buildFtSearchQuery', () => {
it('emits a bare KNN query with no filter', () => {
expect(buildFtSearchQuery(schema, 10)).toBe('*=>[KNN 10 @embedding $vec AS __score]');
});

it('wraps a single TAG filter clause', () => {
expect(buildFtSearchQuery(schema, 5, { source: 'docs' })).toBe(
'(@source:{docs})=>[KNN 5 @embedding $vec AS __score]',
);
});

it('joins TAG and NUMERIC clauses with AND semantics', () => {
expect(buildFtSearchQuery(schema, 5, { source: 'docs', updated: 1717200000 })).toBe(
'(@source:{docs} @updated:[1717200000 1717200000])=>[KNN 5 @embedding $vec AS __score]',
);
});

it('escapes TAG filter values', () => {
expect(buildFtSearchQuery(schema, 5, { source: 'a:b c' })).toBe(
'(@source:{a\\:b\\ c})=>[KNN 5 @embedding $vec AS __score]',
);
});

it('throws for a filter on an unknown field', () => {
expect(() => buildFtSearchQuery(schema, 5, { missing: 'x' })).toThrow(/unknown/i);
});

it('throws for a filter on a TEXT field', () => {
expect(() => buildFtSearchQuery(schema, 5, { title: 'x' })).toThrow(/text/i);
});

it('throws when a NUMERIC filter value is not a number', () => {
expect(() => buildFtSearchQuery(schema, 5, { updated: 'recent' })).toThrow(/numeric/i);
});
});
231 changes: 231 additions & 0 deletions packages/retrieval/src/__tests__/query.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import { describe, it, expect, vi } from 'vitest';
import { encodeFloat32 } from '@betterdb/valkey-search-kit';
import { Retriever } from '../retriever';
import type { RetrievalSchema } from '../schema';
import type { QueryHit } from '../retriever';

const schema: RetrievalSchema = {
fields: { source: { type: 'tag' }, updated: { type: 'numeric' } },
vector: { metric: 'cosine', algorithm: 'hnsw', dims: 4 },
};

interface Row {
key: string;
fields: Record<string, string>;
}

function searchReply(rows: Row[]): unknown[] {
const out: unknown[] = [String(rows.length)];
for (const row of rows) {
out.push(row.key);
const flat: string[] = [];
for (const [field, value] of Object.entries(row.fields)) {
flat.push(field, value);
}
out.push(flat);
}
return out;
}

describe('Retriever query', () => {
it('embeds the text, runs FT.SEARCH, and maps rows to hits', async () => {
const vec = [0.1, 0.2, 0.3, 0.4];
const embedFn = vi.fn(async () => vec);
const reply = searchReply([
{
key: 'docs:doc:1',
fields: {
source: 'docs',
updated: '1717200000',
__text: 'hello world',
__score: '0.12',
embedding: 'rawbytes',
},
},
]);
const call = vi.fn(async () => reply);
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn });

const hits = await retriever.query({ text: 'hi', k: 10, filter: { source: 'docs' } });

expect(embedFn).toHaveBeenCalledWith('hi');
expect(call).toHaveBeenCalledWith(
'FT.SEARCH',
'docs:idx',
'(@source:{docs})=>[KNN 10 @embedding $vec AS __score]',
'PARAMS',
'2',
'vec',
encodeFloat32(vec),
'LIMIT',
'0',
'10',
'DIALECT',
'2',
);
const expected: QueryHit[] = [
{
id: 'doc:1',
score: 0.12,
text: 'hello world',
fields: { source: 'docs', updated: '1717200000' },
},
];
expect(hits).toEqual(expected);
});

it('uses a precomputed vector and does not call embedFn', async () => {
const vec = [0.5, 0.5, 0.5, 0.5];
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn });

await retriever.query({ vector: vec, k: 5 });

expect(embedFn).not.toHaveBeenCalled();
expect(call).toHaveBeenCalledWith(
'FT.SEARCH',
'docs:idx',
'*=>[KNN 5 @embedding $vec AS __score]',
'PARAMS',
'2',
'vec',
encodeFloat32(vec),
'LIMIT',
'0',
'5',
'DIALECT',
'2',
);
});

it('throws when both text and vector are provided', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn });

await expect(retriever.query({ text: 'a', vector: [1, 2, 3, 4], k: 5 })).rejects.toThrow(
/both/i,
);

expect(call).not.toHaveBeenCalled();
});

it('throws when neither text nor vector is provided', async () => {
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema });

await expect(retriever.query({ k: 5 })).rejects.toThrow(/text or/i);

expect(call).not.toHaveBeenCalled();
});

it('returns an empty array when FT.SEARCH yields no hits', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn });

const hits = await retriever.query({ text: 'x', k: 5 });

expect(hits).toEqual([]);
});

it('reorders hits via rerankFn when hybrid is rerank', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const reply = searchReply([
{ key: 'docs:a', fields: { __text: 'first', __score: '0.9', source: 'docs' } },
{ key: 'docs:b', fields: { __text: 'second', __score: '0.8', source: 'docs' } },
]);
const call = vi.fn(async () => reply);
const rerankFn = vi.fn(async (_queryText: string, hits: QueryHit[]) => [...hits].reverse());
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn, rerankFn });

const hits = await retriever.query({ text: 'q', k: 5, hybrid: 'rerank' });

const passedHits = rerankFn.mock.calls[0][1];
expect(passedHits).toEqual([
{ id: 'a', score: 0.9, text: 'first', fields: { source: 'docs' } },
{ id: 'b', score: 0.8, text: 'second', fields: { source: 'docs' } },
]);
expect(hits.map((h) => h.id)).toEqual(['b', 'a']);
});

it('throws for hybrid rerank without a rerankFn', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn });

await expect(retriever.query({ text: 'q', k: 5, hybrid: 'rerank' })).rejects.toThrow(
/rerankFn/,
);

expect(call).not.toHaveBeenCalled();
});

it('throws for hybrid rerank without text', async () => {
const rerankFn = vi.fn(async (_q: string, hits: QueryHit[]) => hits);
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema, rerankFn });

await expect(retriever.query({ vector: [1, 2, 3, 4], k: 5, hybrid: 'rerank' })).rejects.toThrow(
/text/i,
);

expect(call).not.toHaveBeenCalled();
});

it('throws when k is not a positive integer', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema, embedFn });

await expect(retriever.query({ text: 'x', k: 0 })).rejects.toThrow(/positive integer/i);

expect(call).not.toHaveBeenCalled();
});

it('throws when a precomputed vector has the wrong dimension', async () => {
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema });

await expect(retriever.query({ vector: [1, 2], k: 5 })).rejects.toThrow(/dimension/i);

expect(call).not.toHaveBeenCalled();
});

it('rejects a precomputed vector that mismatches the inferred (cached) dimension', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const noDims: RetrievalSchema = {
fields: { source: { type: 'tag' } },
vector: { metric: 'cosine', algorithm: 'hnsw' },
};
const call = vi.fn(async (command: string) => {
if (command === 'FT.INFO') {
throw new Error("Unknown index name 'docs:idx'");
}
return searchReply([]);
});
const retriever = new Retriever({ client: { call }, name: 'docs', schema: noDims, embedFn });

await retriever.createIndex();

await expect(retriever.query({ vector: [1, 2], k: 5 })).rejects.toThrow(/dimension/i);
const searchCalls = call.mock.calls.filter((args) => args[0] === 'FT.SEARCH');
expect(searchCalls).toHaveLength(0);
});

it('rejects a precomputed vector against inferred dims before the index is created', async () => {
const embedFn = vi.fn(async () => [0, 0, 0, 0]);
const noDims: RetrievalSchema = {
fields: { source: { type: 'tag' } },
vector: { metric: 'cosine', algorithm: 'hnsw' },
};
const call = vi.fn(async () => searchReply([]));
const retriever = new Retriever({ client: { call }, name: 'docs', schema: noDims, embedFn });

await expect(retriever.query({ vector: [1, 2], k: 5 })).rejects.toThrow(/dimension/i);

const searchCalls = call.mock.calls.filter((args) => args[0] === 'FT.SEARCH');
expect(searchCalls).toHaveLength(0);
});
});
18 changes: 18 additions & 0 deletions packages/retrieval/src/__tests__/upsert-delete.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,24 @@ describe('Retriever upsert', () => {
expect(hsetCalls).toHaveLength(0);
});

it('rejects an entry field named __score', async () => {
const embedFn = vi.fn(fakeEmbed(4));
const call = vi.fn(async () => 'OK');
const retriever = new Retriever({
client: { call },
name: 'docs',
schema: schemaWithDims,
embedFn,
});

await expect(
retriever.upsert([{ id: 'doc:1', text: 'x', fields: { __score: 'oops' } }]),
).rejects.toThrow(/reserved/i);

const hsetCalls = call.mock.calls.filter((args) => args[0] === 'HSET');
expect(hsetCalls).toHaveLength(0);
});

it('probes embedFn once and caches dims across multiple entries', async () => {
const embedFn = vi.fn(fakeEmbed(8));
const call = vi.fn(async () => 'OK');
Expand Down
4 changes: 4 additions & 0 deletions packages/retrieval/src/fields.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export const TEXT_FIELD = '__text';
export const SCORE_FIELD = '__score';

export const RESERVED_FIELD_NAMES: readonly string[] = [TEXT_FIELD, SCORE_FIELD];
4 changes: 4 additions & 0 deletions packages/retrieval/src/ft-create.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { RetrievalSchema, FtCapabilities, FieldSpec, VectorSpec } from './schema';
import { RESERVED_FIELD_NAMES } from './fields';

const HNSW_DEFAULTS = {
m: 16,
Expand Down Expand Up @@ -29,6 +30,9 @@ function validateFieldNames(fields: Record<string, FieldSpec>, vectorFieldName:
`Field name '${name}' collides with the vector field name '${vectorFieldName}'`,
);
}
if (RESERVED_FIELD_NAMES.includes(name)) {
throw new Error(`Field name '${name}' is reserved and cannot be used in the schema`);
}
}
}

Expand Down
41 changes: 41 additions & 0 deletions packages/retrieval/src/ft-search.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { escapeTag } from '@betterdb/valkey-search-kit';
import type { RetrievalSchema } from './schema';
import { resolveVectorFieldName } from './ft-create';
import { SCORE_FIELD } from './fields';

export type QueryFilter = Record<string, string | number>;

function buildFilterClause(field: string, value: string | number, schema: RetrievalSchema): string {
const spec = schema.fields[field];
if (spec === undefined) {
throw new Error(`Cannot filter on unknown field '${field}'`);
}
if (spec.type === 'tag') {
return `@${field}:{${escapeTag(String(value))}}`;
}
if (spec.type === 'numeric') {
if (typeof value !== 'number') {
throw new Error(`Numeric filter on field '${field}' requires a number, got: ${typeof value}`);
}
return `@${field}:[${value} ${value}]`;
}
throw new Error(
`Cannot filter on TEXT field '${field}'; only tag and numeric fields are filterable`,
);
}

export function buildFtSearchQuery(
schema: RetrievalSchema,
k: number,
filter?: QueryFilter,
): string {
const vectorField = resolveVectorFieldName(schema.vector);
const clauses: string[] = [];
if (filter !== undefined) {
for (const [field, value] of Object.entries(filter)) {
clauses.push(buildFilterClause(field, value, schema));
}
}
const filterExpr = clauses.length > 0 ? `(${clauses.join(' ')})` : '*';
return `${filterExpr}=>[KNN ${k} @${vectorField} $vec AS ${SCORE_FIELD}]`;
}
Loading
Loading