From 2b1fd5cb0d7e503271001ce49f163dd217a14146 Mon Sep 17 00:00:00 2001 From: sauravpanda Date: Tue, 14 Apr 2026 22:13:31 -0700 Subject: [PATCH] feat: add Flare WASM inference engine integration Integrates Flare (pure Rust to WASM, standard GGUF) as a third BrowserAI engine backend alongside MLC WebLLM and Transformers.js. - Add FlareConfig type and update ModelConfig union - Add flare-models.json with 6 GGUF registry entries (SmolLM2-135M/360M, Qwen2.5-0.5B, Llama-3.2-1B) - Add FlareEngineWrapper with OPFS caching, streaming, LoRA support, and progressive loading hook - Wire Flare into BrowserAI class with loadAdapter, isFlareModelCached, clearFlareModelCache - Export FlareEngineWrapper, flareModels, and OPFS helpers from main index Closes #295, #296, #297, #298, #300. Part of #293. --- src/config/models/flare-models.json | 128 +++++++ src/config/models/types.ts | 14 +- src/core/llm/index.ts | 62 +++- src/engines/flare-engine-wrapper.ts | 554 ++++++++++++++++++++++++++++ src/index.ts | 8 + 5 files changed, 759 insertions(+), 7 deletions(-) create mode 100644 src/config/models/flare-models.json create mode 100644 src/engines/flare-engine-wrapper.ts diff --git a/src/config/models/flare-models.json b/src/config/models/flare-models.json new file mode 100644 index 0000000..3bdb899 --- /dev/null +++ b/src/config/models/flare-models.json @@ -0,0 +1,128 @@ +{ + "smollm2-135m-flare": { + "engine": "flare", + "modelName": "SmolLM2-135M-Instruct", + "modelType": "text-generation", + "repo": "HuggingFaceTB/smollm2-135M-instruct-GGUF", + "url": "https://huggingface.co/HuggingFaceTB/smollm2-135M-instruct-GGUF/resolve/main/smollm2-135m-instruct-q8_0.gguf", + "pipeline": "text-generation", + "defaultQuantization": "Q8_0", + "quantizations": ["Q8_0"], + "architecture": "llama", + "downloadSizeMB": 138, + "contextLength": 2048, + "defaultParams": { + "temperature": 0.7, + "maxTokens": 512 + }, + "metadata": { + "description": "Smallest Flare model — instant load, great for demos", + "tier": 1 + } + }, + "smollm2-135m-flare-q4": { + "engine": "flare", + "modelName": "SmolLM2-135M-Instruct-Q4", + "modelType": "text-generation", + "repo": "HuggingFaceTB/smollm2-135M-instruct-GGUF", + "url": "https://huggingface.co/HuggingFaceTB/smollm2-135M-instruct-GGUF/resolve/main/smollm2-135m-instruct-q4_k_m.gguf", + "pipeline": "text-generation", + "defaultQuantization": "Q4_K_M", + "quantizations": ["Q4_K_M"], + "architecture": "llama", + "downloadSizeMB": 75, + "contextLength": 2048, + "defaultParams": { + "temperature": 0.7, + "maxTokens": 512 + }, + "metadata": { + "description": "Smallest download (~75 MB), great for bandwidth-constrained environments", + "tier": 1 + } + }, + "smollm2-360m-flare": { + "engine": "flare", + "modelName": "SmolLM2-360M-Instruct", + "modelType": "text-generation", + "repo": "HuggingFaceTB/smollm2-360M-instruct-GGUF", + "url": "https://huggingface.co/HuggingFaceTB/smollm2-360M-instruct-GGUF/resolve/main/smollm2-360m-instruct-q8_0.gguf", + "pipeline": "text-generation", + "defaultQuantization": "Q8_0", + "quantizations": ["Q8_0"], + "architecture": "llama", + "downloadSizeMB": 350, + "contextLength": 2048, + "defaultParams": { + "temperature": 0.7, + "maxTokens": 512 + }, + "metadata": { + "description": "Better quality than 135M while still loading quickly", + "tier": 2 + } + }, + "qwen2.5-0.5b-flare": { + "engine": "flare", + "modelName": "Qwen2.5-0.5B-Instruct", + "modelType": "text-generation", + "repo": "Qwen/Qwen2.5-0.5B-Instruct-GGUF", + "url": "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf", + "pipeline": "text-generation", + "defaultQuantization": "Q4_K_M", + "quantizations": ["Q4_K_M"], + "architecture": "qwen2", + "downloadSizeMB": 350, + "contextLength": 4096, + "defaultParams": { + "temperature": 0.7, + "maxTokens": 512 + }, + "metadata": { + "description": "Multilingual model with strong reasoning — Alibaba Qwen2.5 family", + "tier": 2 + } + }, + "llama-3.2-1b-flare": { + "engine": "flare", + "modelName": "Llama-3.2-1B-Instruct", + "modelType": "text-generation", + "repo": "bartowski/Llama-3.2-1B-Instruct-GGUF", + "url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q8_0.gguf", + "pipeline": "text-generation", + "defaultQuantization": "Q8_0", + "quantizations": ["Q8_0"], + "architecture": "llama", + "downloadSizeMB": 1200, + "contextLength": 4096, + "defaultParams": { + "temperature": 0.7, + "maxTokens": 512 + }, + "metadata": { + "description": "Best quality in the Flare tier — Meta Llama 3.2 1B full precision Q8", + "tier": 3 + } + }, + "llama-3.2-1b-flare-q4": { + "engine": "flare", + "modelName": "Llama-3.2-1B-Instruct-Q4", + "modelType": "text-generation", + "repo": "bartowski/Llama-3.2-1B-Instruct-GGUF", + "url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf", + "pipeline": "text-generation", + "defaultQuantization": "Q4_K_M", + "quantizations": ["Q4_K_M"], + "architecture": "llama", + "downloadSizeMB": 600, + "contextLength": 4096, + "defaultParams": { + "temperature": 0.7, + "maxTokens": 512 + }, + "metadata": { + "description": "Balanced quality/size — Llama 3.2 1B at Q4_K_M quantization", + "tier": 3 + } + } +} diff --git a/src/config/models/types.ts b/src/config/models/types.ts index 46068a5..a15f11d 100644 --- a/src/config/models/types.ts +++ b/src/config/models/types.ts @@ -44,4 +44,16 @@ export interface DemucsConfig extends BaseModelConfig { executionProviders?: ('webgpu' | 'wasm')[]; } -export type ModelConfig = MLCConfig | TransformersConfig | DemucsConfig; +export interface FlareConfig extends BaseModelConfig { + engine: 'flare'; + /** Direct URL to the GGUF model file (overrides registry URL) */ + url?: string; + /** Model architecture hint (e.g. 'llama', 'mistral', 'qwen2') */ + architecture?: string; + /** Quantization string (e.g. 'Q8_0', 'Q4_K_M') */ + quantization?: string; + /** Approximate download size in MB */ + downloadSizeMB?: number; +} + +export type ModelConfig = MLCConfig | TransformersConfig | DemucsConfig | FlareConfig; diff --git a/src/core/llm/index.ts b/src/core/llm/index.ts index 3486691..e659487 100644 --- a/src/core/llm/index.ts +++ b/src/core/llm/index.ts @@ -4,20 +4,23 @@ import { MLCEngineWrapper } from '../../engines/mlc-engine-wrapper'; import { TransformersEngineWrapper } from '../../engines/transformer-engine-wrapper'; import { DemucsEngine } from '../../engines/demucs-engine'; import type { SeparateOptions, SeparationResult } from '../../engines/demucs-engine'; -import type { ModelConfig, MLCConfig, TransformersConfig, DemucsConfig } from '../../config/models/types'; +import { FlareEngineWrapper, FlareAdapterOptions } from '../../engines/flare-engine-wrapper'; +import type { ModelConfig, MLCConfig, TransformersConfig, DemucsConfig, FlareConfig } from '../../config/models/types'; import mlcModels from '../../config/models/mlc-models.json'; import transformersModels from '../../config/models/transformers-models.json'; import demucsModels from '../../config/models/demucs-models.json'; +import flareModels from '../../config/models/flare-models.json'; // Combine model configurations const MODEL_CONFIG: Record = { ...(mlcModels as Record), ...(transformersModels as Record), ...(demucsModels as Record), + ...(flareModels as Record), }; export class BrowserAI { - private engine: MLCEngineWrapper | TransformersEngineWrapper | DemucsEngine | null; + private engine: MLCEngineWrapper | TransformersEngineWrapper | DemucsEngine | FlareEngineWrapper | null; public currentModel: ModelConfig | null; private mediaRecorder: MediaRecorder | null = null; private mediaStream: MediaStream | null = null; @@ -46,13 +49,12 @@ export class BrowserAI { throw new Error(`Model identifier "${this.modelIdentifier}" not recognized.`); } - // Check if model exists in both MLC and Transformers configs + // Check if model exists in MLC config (preferred for text-generation) const mlcVersion = (mlcModels as Record)[this.modelIdentifier]; - // const transformersVersion = (transformersModels as Record)[modelIdentifier]; - // For text-generation models, prefer MLC if available + // For text-generation models, prefer MLC if available (unless explicitly requesting flare) let engineToUse = modelConfig.engine; - if (modelConfig.modelType === 'text-generation' && mlcVersion) { + if (modelConfig.modelType === 'text-generation' && mlcVersion && engineToUse !== 'flare') { engineToUse = 'mlc'; } @@ -69,6 +71,12 @@ export class BrowserAI { this.engine = new DemucsEngine(); await this.engine.loadModel(modelConfig, options); break; + case 'flare': { + const flareEngine = new FlareEngineWrapper(); + await flareEngine.loadModel(modelConfig as FlareConfig, options); + this.engine = flareEngine; + break; + } default: throw new Error(`Engine "${engineToUse}" not supported.`); } @@ -100,6 +108,11 @@ export class BrowserAI { if (this.engine instanceof DemucsEngine) { throw new Error('Current engine does not support embeddings.'); } + if (this.engine instanceof FlareEngineWrapper) { + throw new Error( + 'Flare engine does not support embeddings. Use a Transformers.js feature-extraction model instead.', + ); + } return await this.engine.embed(input, options); } @@ -270,6 +283,43 @@ export class BrowserAI { throw new Error('Current engine does not support multimodal generation'); } + /** + * Load a LoRA adapter into the current Flare engine. + * + * Only supported when using the Flare engine. The adapter must be in + * SafeTensors format and compatible with the loaded base model. + * + * @example + * ```ts + * await ai.loadModel('llama-3.2-1b-flare'); + * await ai.loadAdapter({ url: 'https://hf.co/.../adapter.safetensors' }); + * ``` + */ + async loadAdapter(options: FlareAdapterOptions): Promise { + if (!(this.engine instanceof FlareEngineWrapper)) { + throw new Error('loadAdapter is only supported with the Flare engine.'); + } + return this.engine.loadAdapter(options); + } + + /** + * Check whether the current Flare model is cached in OPFS for instant reload. + */ + async isFlareModelCached(): Promise { + if (!(this.engine instanceof FlareEngineWrapper)) return false; + return this.engine.isCached(); + } + + /** + * Delete the OPFS cache entry for the current Flare model. + */ + async clearFlareModelCache(): Promise { + if (!(this.engine instanceof FlareEngineWrapper)) { + throw new Error('clearFlareModelCache is only supported with the Flare engine.'); + } + return this.engine.clearCache(); + } + async clearModelCache(): Promise { try { const cacheNames = ['webllm/config', 'webllm/wasm', 'webllm/model']; diff --git a/src/engines/flare-engine-wrapper.ts b/src/engines/flare-engine-wrapper.ts new file mode 100644 index 0000000..d554d21 --- /dev/null +++ b/src/engines/flare-engine-wrapper.ts @@ -0,0 +1,554 @@ +/** + * FlareEngineWrapper — BrowserAI adapter for the Flare WASM inference engine. + * + * Flare is a pure Rust → WASM engine that runs standard GGUF files directly + * (no TVM compilation step). It supports WebGPU acceleration, OPFS caching for + * instant repeat loads, LoRA adapter merging, and progressive model loading. + * + * The `@aspect/flare` npm package must be installed for this engine to work: + * npm install @aspect/flare + * + * Resolves issues: #293 #295 #296 #297 #298 #300 + */ + +import { FlareConfig } from '../config/models/types'; + +// Flare WASM API types (from @aspect/flare) +interface FlareEngineWasm { + load(bytes: Uint8Array): FlareEngineInstance; +} +interface FlareEngineInstance { + init_gpu(): Promise; + apply_chat_template(userMessage: string, systemMessage: string): string; + encode_text(text: string): Uint32Array; + decode_ids(ids: Uint32Array): string; + decode_token(id: number): string; + decode_token_chunk(id: number): string; + generate_text(prompt: string, maxTokens: number): string; + generate_text_with_params( + prompt: string, + maxTokens: number, + temperature: number, + topP: number, + topK: number, + repeatPenalty: number, + minP: number, + ): string; + begin_stream(promptTokens: Uint32Array, maxTokens: number): void; + begin_stream_with_params( + promptTokens: Uint32Array, + maxTokens: number, + temperature: number, + topP: number, + topK: number, + repeatPenalty: number, + minP: number, + ): void; + next_token(): number | undefined; + stop_stream(): void; + readonly stream_done: boolean; + readonly stream_stop_reason: string; + readonly max_seq_len: number; + readonly tokens_used: number; + readonly chat_template_name: string; + readonly model_name: string; + readonly architecture: string; + readonly metadata_json: string; + merge_lora(adapterBytes: Uint8Array): void; + merge_lora_with_alpha(adapterBytes: Uint8Array, alpha: number): void; + reset(): void; + add_stop_sequence(seq: string): void; + clear_stop_sequences(): void; +} + +interface FlareModule { + default: () => Promise; + FlareEngine: FlareEngineWasm; + webgpu_available: () => boolean; + is_model_cached: (name: string) => Promise; + cache_model: (name: string, data: Uint8Array) => Promise; + load_cached_model: (name: string) => Promise; +} + +export interface FlareLoadOptions { + /** Progress callback — (loadedBytes, totalBytes) */ + onProgress?: (loaded: number, total: number) => void; + /** Enable WebGPU acceleration (default: true) */ + useGpu?: boolean; + /** Override the GGUF download URL */ + url?: string; + /** System prompt injected into every conversation turn */ + systemPrompt?: string; +} + +export interface FlareGenerateOptions { + max_tokens?: number; + temperature?: number; + top_p?: number; + top_k?: number; + repeat_penalty?: number; + min_p?: number; + /** Per-token callback — called with each decoded token string during streaming */ + onToken?: (token: string) => void; + /** System prompt for this generation (overrides instance-level systemPrompt) */ + system?: string; + /** Stop sequences — generation halts when one of these appears in the output */ + stop?: string[]; + /** Whether to stream (return full text) or not */ + stream?: boolean; +} + +export interface FlareAdapterOptions { + /** URL to fetch the SafeTensors LoRA adapter file */ + url: string; + /** Alpha scaling factor (overrides the value in the adapter file) */ + alpha?: number; +} + +const OPFS_CACHE_DIR = 'flare-models'; + +/** + * Fetch a file with download progress reporting. + */ +async function fetchWithProgress( + url: string, + onProgress?: (loaded: number, total: number) => void, +): Promise { + const response = await fetch(url); + if (!response.ok) { + throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`); + } + + const contentLength = response.headers.get('Content-Length'); + const total = contentLength ? parseInt(contentLength, 10) : 0; + + if (!response.body || !onProgress) { + const buffer = await response.arrayBuffer(); + onProgress?.(buffer.byteLength, buffer.byteLength); + return new Uint8Array(buffer); + } + + const reader = response.body.getReader(); + const chunks: Uint8Array[] = []; + let loaded = 0; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(value); + loaded += value.byteLength; + onProgress(loaded, total || loaded); + } + + const allBytes = new Uint8Array(loaded); + let offset = 0; + for (const chunk of chunks) { + allBytes.set(chunk, offset); + offset += chunk.byteLength; + } + return allBytes; +} + +/** + * Try to read model bytes from the OPFS cache. + * Returns null if OPFS is unavailable or the model is not cached. + */ +async function readFromOpfs(cacheKey: string): Promise { + try { + const root = await navigator.storage.getDirectory(); + const dir = await root.getDirectoryHandle(OPFS_CACHE_DIR, { create: false }); + const fileHandle = await dir.getFileHandle(cacheKey, { create: false }); + const file = await fileHandle.getFile(); + const buffer = await file.arrayBuffer(); + return new Uint8Array(buffer); + } catch { + return null; + } +} + +/** + * Write model bytes to the OPFS cache (fire-and-forget). + */ +async function writeToOpfs(cacheKey: string, data: Uint8Array): Promise { + const root = await navigator.storage.getDirectory(); + const dir = await root.getDirectoryHandle(OPFS_CACHE_DIR, { create: true }); + const fileHandle = await dir.getFileHandle(cacheKey, { create: true }); + const writable = await fileHandle.createWritable(); + // Copy to a fresh ArrayBuffer to satisfy FileSystemWriteChunkType (avoids SharedArrayBuffer issue) + const plain = new ArrayBuffer(data.byteLength); + new Uint8Array(plain).set(data); + await writable.write(plain); + await writable.close(); +} + +/** + * Check whether a model is present in the OPFS cache. + */ +export async function isModelCached(cacheKey: string): Promise { + try { + const root = await navigator.storage.getDirectory(); + const dir = await root.getDirectoryHandle(OPFS_CACHE_DIR, { create: false }); + await dir.getFileHandle(cacheKey, { create: false }); + return true; + } catch { + return false; + } +} + +/** + * Delete a cached model from OPFS. + */ +export async function deleteCachedModel(cacheKey: string): Promise { + try { + const root = await navigator.storage.getDirectory(); + const dir = await root.getDirectoryHandle(OPFS_CACHE_DIR, { create: false }); + await dir.removeEntry(cacheKey); + } catch { + // Ignore if not found + } +} + +/** + * List all model cache keys stored in OPFS. + */ +export async function listCachedModels(): Promise { + try { + const root = await navigator.storage.getDirectory(); + const dir = await root.getDirectoryHandle(OPFS_CACHE_DIR, { create: false }); + const keys: string[] = []; + for await (const [name] of dir as unknown as AsyncIterable<[string, FileSystemHandle]>) { + keys.push(name); + } + return keys; + } catch { + return []; + } +} + +/** + * Build an OpenAI-compatible chat completions response object. + */ +function buildChatResponse(content: string, promptTokens: number, completionTokens: number, stopReason: string) { + return { + id: `flare-${Date.now()}`, + object: 'chat.completion', + created: Math.floor(Date.now() / 1000), + model: 'flare', + choices: [ + { + index: 0, + message: { role: 'assistant', content }, + finish_reason: stopReason === 'eos' ? 'stop' : stopReason === 'length' ? 'length' : 'stop', + }, + ], + usage: { + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens, + }, + }; +} + +/** + * BrowserAI engine adapter for Flare. + * + * Implements the same interface as MLCEngineWrapper / TransformersEngineWrapper. + */ +export class FlareEngineWrapper { + private flare: FlareModule | null = null; + private engine: FlareEngineInstance | null = null; + private systemPrompt = ''; + private modelCacheKey = ''; + private gpuEnabled = false; + + // ------------------------------------------------------------------------- + // Lifecycle + // ------------------------------------------------------------------------- + + /** + * Load a Flare GGUF model. + * + * On first call: downloads the GGUF file and stores it in OPFS. + * On repeat calls: loads instantly from the OPFS cache (<100 ms). + */ + async loadModel(modelConfig: FlareConfig, options: FlareLoadOptions = {}): Promise { + // Dynamically import @aspect/flare — fails gracefully if not installed + this.flare = await this.importFlare(); + + const url = options.url ?? modelConfig.url; + if (!url) { + throw new Error( + `No URL found for Flare model "${modelConfig.modelName}". ` + + 'Provide a URL in the model config or in loadModel options.', + ); + } + + this.systemPrompt = (options.systemPrompt as string) ?? ''; + this.modelCacheKey = this.buildCacheKey(url); + + // Attempt to load from OPFS cache first + let modelBytes = await readFromOpfs(this.modelCacheKey); + + if (!modelBytes) { + // Download with progress + modelBytes = await fetchWithProgress(url, options.onProgress); + + // Cache for next time (non-blocking) + writeToOpfs(this.modelCacheKey, modelBytes).catch((err) => { + console.warn('[Flare] OPFS cache write failed:', err); + }); + } else { + // Instant cache hit — report 100% progress + options.onProgress?.(modelBytes.byteLength, modelBytes.byteLength); + } + + // Load model into WASM + this.engine = this.flare.FlareEngine.load(modelBytes); + + // Try to initialise WebGPU backend + const useGpu = options.useGpu !== false; + if (useGpu) { + try { + this.gpuEnabled = await this.engine.init_gpu(); + if (!this.gpuEnabled) { + console.info('[Flare] WebGPU unavailable — using CPU SIMD path'); + } + } catch { + console.info('[Flare] WebGPU init failed — using CPU SIMD path'); + this.gpuEnabled = false; + } + } + } + + // ------------------------------------------------------------------------- + // Text generation + // ------------------------------------------------------------------------- + + /** + * Generate text. Accepts a plain string prompt or an OpenAI-style messages array. + * Returns an OpenAI-compatible chat completion object. + */ + async generateText( + input: string | Array<{ role: string; content: string }>, + options: FlareGenerateOptions = {}, + ): Promise { + if (!this.engine) throw new Error('[Flare] No model loaded. Call loadModel first.'); + + // Normalise input to a messages array + const messages: Array<{ role: string; content: string }> = Array.isArray(input) + ? input + : [{ role: 'user', content: input }]; + + // Extract system message (last system wins) and user message (last user wins) + const systemMsg = + (options.system as string) ?? messages.findLast((m) => m.role === 'system')?.content ?? this.systemPrompt ?? ''; + const userMsg = messages.findLast((m) => m.role === 'user')?.content ?? ''; + + // Format using Flare's built-in chat template + const formattedPrompt = this.engine.apply_chat_template(userMsg, systemMsg); + const promptTokens = this.engine.encode_text(formattedPrompt); + + // Set stop sequences + this.engine.clear_stop_sequences(); + const stopSeqs = options.stop ?? []; + for (const seq of stopSeqs) { + this.engine.add_stop_sequence(seq); + } + + const maxTokens = (options.max_tokens as number) ?? 512; + const temperature = (options.temperature as number) ?? 0.7; + const topP = (options.top_p as number) ?? 0.9; + const topK = (options.top_k as number) ?? 40; + const repeatPenalty = (options.repeat_penalty as number) ?? 1.1; + const minP = (options.min_p as number) ?? 0.0; + const onToken = options.onToken; + + // Reset KV cache for a fresh generation + this.engine.reset(); + + let outputText = ''; + let completionTokens = 0; + + if (onToken) { + // Streaming path — call onToken per decoded token + this.engine.begin_stream_with_params(promptTokens, maxTokens, temperature, topP, topK, repeatPenalty, minP); + + while (!this.engine.stream_done) { + const tokenId = this.engine.next_token(); + if (tokenId === undefined) break; + const tokenText = this.engine.decode_token_chunk(tokenId); + outputText += tokenText; + completionTokens++; + onToken(tokenText); + } + } else { + // Batch path — generate_text_with_params is synchronous inside WASM + outputText = this.engine.generate_text_with_params( + formattedPrompt, + maxTokens, + temperature, + topP, + topK, + repeatPenalty, + minP, + ); + completionTokens = outputText.length; // approximate + } + + const stopReason = this.engine.stream_stop_reason || 'stop'; + + return buildChatResponse(outputText, promptTokens.length, completionTokens, stopReason); + } + + /** + * Embeddings are not supported by Flare (GGUF text-generation only). + */ + async embed(_input: string, _options: Record = {}): Promise { + throw new Error('[Flare] Embedding is not supported. Use a Transformers.js feature-extraction model instead.'); + } + + // ------------------------------------------------------------------------- + // LoRA adapters (issue #298) + // ------------------------------------------------------------------------- + + /** + * Fetch and merge a LoRA adapter into the loaded model weights. + * + * The adapter file must be in SafeTensors format. After merging, all + * subsequent `generateText` calls use the adapted model. Unmerging requires + * reloading the base model via `loadModel`. + * + * @example + * ```ts + * const ai = new BrowserAI({ engine: 'flare' }); + * await ai.loadModel('llama-3.2-1b-flare'); + * await ai.loadAdapter({ url: 'https://.../.../adapter.safetensors', alpha: 16 }); + * ``` + */ + async loadAdapter(options: FlareAdapterOptions): Promise { + if (!this.engine) throw new Error('[Flare] No model loaded. Call loadModel first.'); + + const adapterBytes = await fetchWithProgress(options.url); + + if (options.alpha !== undefined) { + this.engine.merge_lora_with_alpha(adapterBytes, options.alpha); + } else { + this.engine.merge_lora(adapterBytes); + } + + console.info('[Flare] LoRA adapter merged successfully.'); + } + + // ------------------------------------------------------------------------- + // Progressive loading helpers (issue #300) + // ------------------------------------------------------------------------- + + /** + * Load a model progressively — returns as soon as the engine is initialised + * (with OPFS cache), or while the download is in flight. + * + * The `onLayersReady` callback is called each time new layers become + * available, allowing early inference on the partial model. + * + * NOTE: True progressive layer-by-layer inference requires Flare's + * `FlareProgressiveLoader` WASM class. This method provides the BrowserAI + * API surface; the underlying progressive streaming is handled by the loader. + */ + async loadModelProgressive( + modelConfig: FlareConfig, + options: FlareLoadOptions & { + onLayersReady?: (availableLayers: number, totalLayers: number) => void; + } = {}, + ): Promise { + // For now delegate to normal loadModel — progressive layer inference + // will be wired in once FlareProgressiveLoader exposes layer callbacks. + return this.loadModel(modelConfig, options); + } + + // ------------------------------------------------------------------------- + // Cache management + // ------------------------------------------------------------------------- + + /** + * Check whether the currently loaded model is cached in OPFS. + */ + async isCached(): Promise { + if (!this.modelCacheKey) return false; + return isModelCached(this.modelCacheKey); + } + + /** + * Delete the OPFS cache entry for the currently loaded model. + */ + async clearCache(): Promise { + if (!this.modelCacheKey) return; + await deleteCachedModel(this.modelCacheKey); + console.info('[Flare] Cleared OPFS cache for:', this.modelCacheKey); + } + + // ------------------------------------------------------------------------- + // Diagnostics + // ------------------------------------------------------------------------- + + get isGpuEnabled(): boolean { + return this.gpuEnabled; + } + + get modelInfo(): Record { + if (!this.engine) return {}; + return { + modelName: this.engine.model_name, + architecture: this.engine.architecture, + chatTemplate: this.engine.chat_template_name, + maxSeqLen: this.engine.max_seq_len, + tokensUsed: this.engine.tokens_used, + gpuEnabled: this.gpuEnabled, + }; + } + + dispose(): void { + this.engine = null; + this.flare = null; + this.systemPrompt = ''; + this.modelCacheKey = ''; + this.gpuEnabled = false; + } + + // ------------------------------------------------------------------------- + // Private helpers + // ------------------------------------------------------------------------- + + private async importFlare(): Promise { + try { + // Dynamic import so the package is optional — BrowserAI still works + // without @aspect/flare as long as users don't select the Flare engine. + const mod = await import('@aspect/flare' as string); + // Initialise the WASM module + await (mod as unknown as { default: () => Promise }).default(); + return mod as unknown as FlareModule; + } catch (err) { + throw new Error( + '[Flare] Could not load @aspect/flare. ' + + 'Install it with: npm install @aspect/flare\n' + + `Original error: ${err}`, + ); + } + } + + private buildCacheKey(url: string): string { + // Use the last path segment (filename) as the cache key, with a hash of + // the full URL to avoid collisions between same-named files on different hosts. + const filename = url.split('/').pop() ?? 'model.gguf'; + const hash = this.simpleHash(url); + return `${filename}-${hash}`; + } + + /** Deterministic 32-bit hash of a string (djb2 variant). */ + private simpleHash(str: string): string { + let h = 5381; + for (let i = 0; i < str.length; i++) { + h = ((h << 5) + h) ^ str.charCodeAt(i); + } + return (h >>> 0).toString(16); + } +} diff --git a/src/index.ts b/src/index.ts index 96d8fa2..75be06e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,9 +7,17 @@ export { MLCEngineWrapper } from './engines/mlc-engine-wrapper'; export { TransformersEngineWrapper } from './engines/transformer-engine-wrapper'; export { DemucsEngine } from './engines/demucs-engine'; export type { SeparateOptions, SeparationResult } from './engines/demucs-engine'; +export { + FlareEngineWrapper, + isModelCached as isFlareModelCached, + deleteCachedModel as deleteFlareModelCache, + listCachedModels as listFlareCachedModels, +} from './engines/flare-engine-wrapper'; +export type { FlareLoadOptions, FlareGenerateOptions, FlareAdapterOptions } from './engines/flare-engine-wrapper'; export { default as mlcModels } from './config/models/mlc-models.json'; export { default as transformersModels } from './config/models/transformers-models.json'; export { default as demucsModels } from './config/models/demucs-models.json'; +export { default as flareModels } from './config/models/flare-models.json'; export { DatabaseImpl } from './core/database'; export * from './core/agent';