diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b2332e3..dac4c695 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Unreleased +* Added llama.cpp ngram-simple speculative decoding via + `SpeculativeDecodingConfig.ngramSimple(...)`, including Dart routing, native + wrapper bindings, docs, and local benchmark matrix coverage. + * Added `LlamaStructuredOutput` and `LlamaEngine.createStructuredJson(...)` helpers for strict JSON-object / JSON-schema generation with final-output validation and typed decoding. diff --git a/README.md b/README.md index b9861686..dda97614 100644 --- a/README.md +++ b/README.md @@ -288,6 +288,30 @@ Higher `draftTokenMax` values can be faster on some models/devices, but they should be benchmarked with the target model because excess draft depth can add verification overhead. +For GGUF models without an MTP or separate draft model, llama.cpp ngram-simple +speculative decoding uses recent token history as the drafter: + +```dart +params: const GenerationParams( + maxTokens: 128, + speculativeDecodingConfig: SpeculativeDecodingConfig.ngramSimple( + draftTokenMax: 2, + ngramSize: 12, + ), +), +``` + +Reserve `ModelParams.speculativeRollbackTokenMax` at least as large as +`draftTokenMax` before using ngram-simple. llamadart currently supports +ngram-simple `draftTokenMax <= 2`; deeper drafts can diverge from +non-speculative decoding on some model/backend combinations. Ngram-simple is +workload-dependent and can be slower than baseline decoding on prompts with +little repetition, so validate it with your model and prompt shape. For local +measurements, set +`LLAMADART_MTP_BENCHMARK_NGRAM=true` and +`LLAMADART_MTP_BENCHMARK_NGRAM_ONLY=true` when running +`tool/testing/llama_cpp_mtp_benchmark.dart`. 
+ For target/draft model pairs, pass the separate drafter GGUF with `draftModelPath`: diff --git a/doc/testing_matrix.md b/doc/testing_matrix.md index fecbecba..ad6db3bc 100644 --- a/doc/testing_matrix.md +++ b/doc/testing_matrix.md @@ -159,6 +159,12 @@ dart run tool/testing/native_inference_benchmark.dart \ --mode all \ --runs 3 \ --max-tokens 128 + +LLAMADART_MTP_BENCHMARK_NGRAM=true \ +LLAMADART_MTP_BENCHMARK_NGRAM_ONLY=true \ +LLAMADART_MTP_BENCHMARK_NGRAM_SIZE=1 \ + dart run tool/testing/llama_cpp_mtp_benchmark.dart \ + models/Qwen3.5-0.8B-Q4_K_M.gguf - 128 3 1,2,4 1 ``` Use `--dry-run` first when a scenario starts servers, builds Flutter web, or diff --git a/lib/src/backends/litert_lm/litert_lm_service.dart b/lib/src/backends/litert_lm/litert_lm_service.dart index ddf430d7..035f8b14 100644 --- a/lib/src/backends/litert_lm/litert_lm_service.dart +++ b/lib/src/backends/litert_lm/litert_lm_service.dart @@ -1013,6 +1013,10 @@ class LiteRtLmService { if (config == null) { return; } + if (config.strategy != SpeculativeDecodingStrategy.backendDefault && + config.strategy != SpeculativeDecodingStrategy.mtp) { + unsupported.add('speculativeDecodingConfig.strategy'); + } if (config.draftTokenMax != null) { unsupported.add('speculativeDecodingConfig.draftTokenMax'); } @@ -1025,6 +1029,9 @@ class LiteRtLmService { if (config.draftModelPath != null) { unsupported.add('speculativeDecodingConfig.draftModelPath'); } + if (config.ngramSize != null) { + unsupported.add('speculativeDecodingConfig.ngramSize'); + } } int _defaultSamplerSeed() { diff --git a/lib/src/backends/llama_cpp/bindings.dart b/lib/src/backends/llama_cpp/bindings.dart index edef1540..573dbcf4 100644 --- a/lib/src/backends/llama_cpp/bindings.dart +++ b/lib/src/backends/llama_cpp/bindings.dart @@ -7841,6 +7841,70 @@ external void llama_dart_mtp_accept( int accepted_count, ); +@ffi.Native Function(ffi.Int32, ffi.Int32)>() +external ffi.Pointer llama_dart_ngram_simple_init( + int ngram_size, + int 
draft_token_max, +); + +@ffi.Native)>() +external void llama_dart_ngram_free(ffi.Pointer ngram); + +@ffi.Native< + ffi.Bool Function( + ffi.Pointer, + llama_seq_id, + ffi.Pointer, + ffi.Int32, + ) +>() +external bool llama_dart_ngram_begin( + ffi.Pointer ngram, + int seq_id, + ffi.Pointer prompt, + int prompt_count, +); + +@ffi.Native, llama_batch)>() +external bool llama_dart_ngram_process_batch( + ffi.Pointer ngram, + llama_batch batch, +); + +@ffi.Native< + ffi.Int32 Function( + ffi.Pointer, + llama_seq_id, + llama_pos, + llama_token, + ffi.Pointer, + ffi.Int32, + ffi.Int32, + ffi.Pointer, + ffi.Int32, + ) +>() +external int llama_dart_ngram_draft( + ffi.Pointer ngram, + int seq_id, + int n_past, + int id_last, + ffi.Pointer prompt, + int prompt_count, + int draft_token_max, + ffi.Pointer out_tokens, + int out_capacity, +); + +@ffi.Native< + ffi.Void Function(ffi.Pointer, llama_seq_id, ffi.Uint16) +>() +external void llama_dart_ngram_accept( + ffi.Pointer ngram, + int seq_id, + int accepted_count, +); + @ffi.Native< ffi.Int32 Function( ffi.Pointer, @@ -10279,6 +10343,8 @@ final class mtmd_helper_video_init_params extends ffi.Struct { final class llama_dart_mtp extends ffi.Opaque {} +final class llama_dart_ngram extends ffi.Opaque {} + const int LLAMA_DEFAULT_SEED = 4294967295; const int LLAMA_TOKEN_NULL = -1; diff --git a/lib/src/backends/llama_cpp/llama_cpp_service.dart b/lib/src/backends/llama_cpp/llama_cpp_service.dart index ff0e0ee9..b560d39c 100644 --- a/lib/src/backends/llama_cpp/llama_cpp_service.dart +++ b/lib/src/backends/llama_cpp/llama_cpp_service.dart @@ -7,6 +7,7 @@ import 'dart:math' as math; import 'package:ffi/ffi.dart'; import 'package:path/path.dart' as path; +import '../../core/exceptions.dart'; import '../../core/llama_logger.dart'; import '../../core/models/chat/content_part.dart'; import '../../core/models/config/gpu_backend.dart'; @@ -240,6 +241,53 @@ typedef _LlamaDartMtpAcceptNative = Void Function(Pointer, llama_seq_id, Uint16); typedef 
_LlamaDartMtpAcceptDart = void Function(Pointer, int, int); +typedef _LlamaDartNgramSimpleInitNative = + Pointer Function(Int32, Int32); +typedef _LlamaDartNgramSimpleInitDart = + Pointer Function(int, int); +typedef _LlamaDartNgramFreeNative = Void Function(Pointer); +typedef _LlamaDartNgramFreeDart = void Function(Pointer); +typedef _LlamaDartNgramBeginNative = + Bool Function( + Pointer, + llama_seq_id, + Pointer, + Int32, + ); +typedef _LlamaDartNgramBeginDart = + bool Function(Pointer, int, Pointer, int); +typedef _LlamaDartNgramProcessBatchNative = + Bool Function(Pointer, llama_batch); +typedef _LlamaDartNgramProcessBatchDart = + bool Function(Pointer, llama_batch); +typedef _LlamaDartNgramDraftNative = + Int32 Function( + Pointer, + llama_seq_id, + llama_pos, + llama_token, + Pointer, + Int32, + Int32, + Pointer, + Int32, + ); +typedef _LlamaDartNgramDraftDart = + int Function( + Pointer, + int, + int, + int, + Pointer, + int, + int, + Pointer, + int, + ); +typedef _LlamaDartNgramAcceptNative = + Void Function(Pointer, llama_seq_id, Uint16); +typedef _LlamaDartNgramAcceptDart = + void Function(Pointer, int, int); typedef _LlamaDartSamplerSampleAndAcceptNNative = Int32 Function( Pointer, @@ -351,6 +399,67 @@ external void _llamadartWrapperMtpAccept( int acceptedCount, ); +@Native<_LlamaDartNgramSimpleInitNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_ngram_simple_init', +) +external Pointer _llamadartWrapperNgramSimpleInit( + int ngramSize, + int draftTokenMax, +); + +@Native<_LlamaDartNgramFreeNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_ngram_free', +) +external void _llamadartWrapperNgramFree(Pointer ngram); + +@Native<_LlamaDartNgramBeginNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_ngram_begin', +) +external bool _llamadartWrapperNgramBegin( + Pointer ngram, + int seqId, + Pointer prompt, + int promptCount, +); + +@Native<_LlamaDartNgramProcessBatchNative>( + assetId: 
_llamadartWrapperAssetId, + symbol: 'llama_dart_ngram_process_batch', +) +external bool _llamadartWrapperNgramProcessBatch( + Pointer ngram, + llama_batch batch, +); + +@Native<_LlamaDartNgramDraftNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_ngram_draft', +) +external int _llamadartWrapperNgramDraft( + Pointer ngram, + int seqId, + int nPast, + int idLast, + Pointer prompt, + int promptCount, + int draftTokenMax, + Pointer outTokens, + int outCapacity, +); + +@Native<_LlamaDartNgramAcceptNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_ngram_accept', +) +external void _llamadartWrapperNgramAccept( + Pointer ngram, + int seqId, + int acceptedCount, +); + @Native<_LlamaDartSamplerSampleAndAcceptNNative>( assetId: _llamadartWrapperAssetId, symbol: 'llama_dart_sampler_sample_and_accept_n', @@ -384,6 +493,16 @@ class _LlamaCppMtpConfig { final String? draftModelPath; } +class _LlamaCppNgramSimpleConfig { + const _LlamaCppNgramSimpleConfig({ + required this.draftTokenMax, + required this.ngramSize, + }); + + final int draftTokenMax; + final int ngramSize; +} + /// Service responsible for managing Llama.cpp models and contexts. /// /// This service handles the direct interaction with the native Llama.cpp library, @@ -461,6 +580,8 @@ class LlamaCppService { _MtmdApi? _mtmdFallbackApi; bool _mtpApiLookupAttempted = false; _MtpApi? _mtpApi; + bool _ngramApiLookupAttempted = false; + _NgramApi? 
_ngramApi; final List _startupDiagnostics = []; // --- Internal State --- @@ -2603,6 +2724,54 @@ class LlamaCppService { 'runtime bundle (missing llama_dart_mtp_* wrapper symbols).'; } + _NgramApi _resolveNgramApi() { + final cached = _ngramApi; + if (cached != null) { + return cached; + } + + if (_ngramApiLookupAttempted) { + throw UnsupportedError(_ngramUnavailableMessage()); + } + _ngramApiLookupAttempted = true; + + if (!Platform.isWindows) { + try { + final direct = _NgramApi.direct(); + _ngramApi = direct; + return direct; + } catch (_) {} + } + + if (Platform.isWindows) { + try { + final asset = _NgramApi.windowsAsset(); + _ngramApi = asset; + return asset; + } catch (_) {} + } + + for (final candidate in _llamadartWrapperLibraryCandidates()) { + try { + final library = DynamicLibrary.open(candidate); + final api = _NgramApi.tryLoad(library); + if (api != null) { + _ngramApi = api; + return api; + } + } catch (_) { + continue; + } + } + + throw UnsupportedError(_ngramUnavailableMessage()); + } + + String _ngramUnavailableMessage() { + return 'llama.cpp ngram-simple speculative decoding is unavailable in this ' + 'native runtime bundle (missing llama_dart_ngram_* wrapper symbols).'; + } + List _llamadartWrapperLibraryCandidates() { final candidates = [..._llamadartAssetUriCandidates()]; final fileNameCandidates = _llamadartLibraryCandidateFileNames(); @@ -3286,7 +3455,7 @@ class LlamaCppService { _contexts.remove(handle)?.dispose(); } - _LlamaCppMtpConfig? _resolveLlamaCppMtpConfig( + Object? 
_resolveLlamaCppSpeculativeConfig( GenerationParams params, { required bool hasMediaParts, }) { @@ -3297,13 +3466,13 @@ class LlamaCppService { if (hasMediaParts) { throw UnsupportedError( - 'llama.cpp MTP speculative decoding currently supports text-only ' + 'llama.cpp speculative decoding currently supports text-only ' 'generation in llamadart.', ); } if (params.grammar != null) { throw UnsupportedError( - 'llama.cpp MTP speculative decoding does not yet support grammar ' + 'llama.cpp speculative decoding does not yet support grammar ' 'sampling in llamadart.', ); } @@ -3311,49 +3480,95 @@ class LlamaCppService { switch (speculativeConfig.strategy) { case SpeculativeDecodingStrategy.backendDefault: case SpeculativeDecodingStrategy.mtp: - break; - } + final draftTokenMax = speculativeConfig.draftTokenMax ?? 1; + final draftTokenMin = speculativeConfig.draftTokenMin ?? 0; + final minProbability = speculativeConfig.minProbability ?? 0.0; + final draftModelPath = speculativeConfig.draftModelPath; + final ngramSize = speculativeConfig.ngramSize; + + if (draftTokenMax <= 0) { + throw RangeError.value( + draftTokenMax, + 'draftTokenMax', + 'must be greater than zero for llama.cpp MTP', + ); + } + if (draftTokenMin < 0 || draftTokenMin > draftTokenMax) { + throw RangeError.value( + draftTokenMin, + 'draftTokenMin', + 'must be between zero and draftTokenMax for llama.cpp MTP', + ); + } + if (minProbability < 0.0 || minProbability > 1.0) { + throw RangeError.value( + minProbability, + 'minProbability', + 'must be between 0.0 and 1.0 for llama.cpp MTP', + ); + } + if (draftModelPath != null && draftModelPath.trim().isEmpty) { + throw ArgumentError.value( + draftModelPath, + 'draftModelPath', + 'must be null or a non-empty path for llama.cpp MTP', + ); + } + if (ngramSize != null) { + throw UnsupportedError( + 'llama.cpp MTP speculative decoding does not use ngramSize. 
' + 'Use SpeculativeDecodingConfig.ngramSimple for n-gram ' + 'self-speculative decoding.', + ); + } - final draftTokenMax = speculativeConfig.draftTokenMax ?? 1; - final draftTokenMin = speculativeConfig.draftTokenMin ?? 0; - final minProbability = speculativeConfig.minProbability ?? 0.0; - final draftModelPath = speculativeConfig.draftModelPath; + return _LlamaCppMtpConfig( + draftTokenMax: draftTokenMax, + draftTokenMin: draftTokenMin, + minProbability: minProbability, + draftModelPath: draftModelPath, + ); + case SpeculativeDecodingStrategy.ngramSimple: + if (speculativeConfig.draftTokenMin != null || + speculativeConfig.minProbability != null || + speculativeConfig.draftModelPath != null) { + throw UnsupportedError( + 'llama.cpp ngram-simple speculative decoding uses token history ' + 'and does not support draftTokenMin, minProbability, or ' + 'draftModelPath.', + ); + } - if (draftTokenMax <= 0) { - throw RangeError.value( - draftTokenMax, - 'draftTokenMax', - 'must be greater than zero for llama.cpp MTP', - ); - } - if (draftTokenMin < 0 || draftTokenMin > draftTokenMax) { - throw RangeError.value( - draftTokenMin, - 'draftTokenMin', - 'must be between zero and draftTokenMax for llama.cpp MTP', - ); - } - if (minProbability < 0.0 || minProbability > 1.0) { - throw RangeError.value( - minProbability, - 'minProbability', - 'must be between 0.0 and 1.0 for llama.cpp MTP', - ); - } - if (draftModelPath != null && draftModelPath.trim().isEmpty) { - throw ArgumentError.value( - draftModelPath, - 'draftModelPath', - 'must be null or a non-empty path for llama.cpp MTP', - ); - } + final draftTokenMax = speculativeConfig.draftTokenMax ?? 48; + final ngramSize = speculativeConfig.ngramSize ?? 
12; + if (draftTokenMax <= 0) { + throw RangeError.value( + draftTokenMax, + 'draftTokenMax', + 'must be greater than zero for llama.cpp ngram-simple', + ); + } + if (ngramSize <= 0) { + throw RangeError.value( + ngramSize, + 'ngramSize', + 'must be greater than zero for llama.cpp ngram-simple', + ); + } + if (draftTokenMax > 2) { + throw LlamaUnsupportedException( + 'llama.cpp ngram-simple speculative decoding with ' + 'draftTokenMax > 2 is not enabled in llamadart yet. Deeper ' + 'ngram-simple drafts can diverge from non-speculative decoding on ' + 'some model/backend combinations. Use draftTokenMax <= 2.', + ); + } - return _LlamaCppMtpConfig( - draftTokenMax: draftTokenMax, - draftTokenMin: draftTokenMin, - minProbability: minProbability, - draftModelPath: draftModelPath, - ); + return _LlamaCppNgramSimpleConfig( + draftTokenMax: draftTokenMax, + ngramSize: ngramSize, + ); + } } /// Generates text based on the given [prompt] and [params]. @@ -3383,6 +3598,8 @@ class LlamaCppService { Pointer sampler = nullptr; Pointer mtpSession = nullptr; _MtpApi? mtpApi; + Pointer ngramSession = nullptr; + _NgramApi? ngramApi; try { final modelHandle = _contextToModel[contextHandle]!; @@ -3392,7 +3609,7 @@ class LlamaCppService { final hasMediaParts = parts?.any((p) => p is LlamaImageContent || p is LlamaAudioContent) ?? 
false; - final mtpConfig = _resolveLlamaCppMtpConfig( + final speculativeConfig = _resolveLlamaCppSpeculativeConfig( params, hasMediaParts: hasMediaParts, ); @@ -3401,7 +3618,9 @@ class LlamaCppService { contextHandle, ctx, clearMemory: - hasMediaParts || mtpConfig != null || !params.reusePromptPrefix, + hasMediaParts || + speculativeConfig != null || + !params.reusePromptPrefix, ); ctx.resetLastPerf(); llama_perf_context_reset(ctx.pointer); @@ -3416,9 +3635,9 @@ class LlamaCppService { tokensPtr = malloc(nCtx); pieceBuf = malloc(256); - if (mtpConfig != null) { + if (speculativeConfig is _LlamaCppMtpConfig) { mtpApi = _resolveMtpApi(); - final draftModelPath = mtpConfig.draftModelPath; + final draftModelPath = speculativeConfig.draftModelPath; final draftModel = draftModelPath == null ? null : _loadMtpDraftModel(modelHandle, draftModelPath); @@ -3427,9 +3646,9 @@ class LlamaCppService { draftModel: draftModel?.pointer, targetContext: ctx.pointer, contextParams: modelParams, - draftTokenMax: mtpConfig.draftTokenMax, - draftTokenMin: mtpConfig.draftTokenMin, - minProbability: mtpConfig.minProbability, + draftTokenMax: speculativeConfig.draftTokenMax, + draftTokenMin: speculativeConfig.draftTokenMin, + minProbability: speculativeConfig.minProbability, backendSampling: true, ); if (mtpSession == nullptr) { @@ -3445,6 +3664,19 @@ class LlamaCppService { ); } llama_perf_context_reset(ctx.pointer); + } else if (speculativeConfig is _LlamaCppNgramSimpleConfig) { + ngramApi = _resolveNgramApi(); + ngramSession = ngramApi.initSimple( + speculativeConfig.ngramSize, + speculativeConfig.draftTokenMax, + ); + if (ngramSession == nullptr) { + throw UnsupportedError( + 'llama.cpp ngram-simple speculative decoding is not available for ' + 'this native runtime bundle. 
Use a libllamadart build that exports ' + 'llama_dart_ngram_* wrapper symbols.', + ); + } } if (params.grammar != null) { @@ -3469,7 +3701,9 @@ class LlamaCppService { nCtx, modelParams, allowTextPromptReuse: - mtpConfig == null && !hasMediaParts && params.reusePromptPrefix, + speculativeConfig == null && + !hasMediaParts && + params.reusePromptPrefix, mtpSession: mtpSession, mtpApi: mtpApi, ); @@ -3483,6 +3717,12 @@ class LlamaCppService { !mtpApi!.begin(mtpSession, 0, tokensPtr, initialTokens)) { throw Exception("Failed to initialize llama.cpp MTP prompt state"); } + if (ngramSession != nullptr && + !ngramApi!.begin(ngramSession, 0, tokensPtr, initialTokens)) { + throw Exception( + "Failed to initialize llama.cpp ngram-simple prompt state", + ); + } // 4. Initialize and Run Sampler Loop sampler = _initializeSampler( @@ -3504,14 +3744,14 @@ class LlamaCppService { params.preservedTokens, ); - if (mtpSession != nullptr && mtpConfig != null) { + if (mtpSession != nullptr && speculativeConfig is _LlamaCppMtpConfig) { yield* _runMtpInferenceLoop( ctx, batch, vocab, sampler, params, - mtpConfig, + speculativeConfig, initialTokens, nCtx, cancelTokenAddress, @@ -3522,6 +3762,25 @@ class LlamaCppService { mtpApi!, tokensPtr, ); + } else if (ngramSession != nullptr && + speculativeConfig is _LlamaCppNgramSimpleConfig) { + yield* _runNgramInferenceLoop( + ctx, + batch, + vocab, + sampler, + params, + speculativeConfig, + initialTokens, + nCtx, + cancelTokenAddress, + pieceBuf, + preservedTokenIds, + effectiveStopSequences, + ngramSession, + ngramApi!, + tokensPtr, + ); } else { yield* _runInferenceLoop( ctx, @@ -3540,6 +3799,7 @@ class LlamaCppService { } } finally { if (mtpSession != nullptr) mtpApi?.free(mtpSession); + if (ngramSession != nullptr) ngramApi?.free(ngramSession); if (sampler != nullptr) llama_sampler_free(sampler); final remaining = (_generatingContexts[contextHandle] ?? 
1) - 1; if (remaining <= 0) { @@ -4631,7 +4891,6 @@ class LlamaCppService { var speculativeDraftTokens = 0; var speculativeAcceptedDraftTokens = 0; var shouldStop = false; - try { while (!shouldStop && generatedTokens < params.maxTokens) { if (cancelToken.value == 1) break; @@ -4709,7 +4968,7 @@ class LlamaCppService { currentPos, draftLimit, draftPtr, - draftCapacity, + draftLimit, ); draftTick.stop(); draftMicros += draftTick.elapsedMicroseconds; @@ -4880,6 +5139,382 @@ class LlamaCppService { } } + Stream> _runNgramInferenceLoop( + _LlamaContextWrapper ctx, + llama_batch batch, + Pointer vocab, + Pointer sampler, + GenerationParams params, + _LlamaCppNgramSimpleConfig ngramConfig, + int startPos, + int nCtx, + int cancelTokenAddress, + Pointer pieceBuf, + Set preservedTokenIds, + List stopSequences, + Pointer ngramSession, + _NgramApi ngramApi, + Pointer tokensPtr, + ) async* { + final cancelToken = Pointer.fromAddress(cancelTokenAddress); + final draftCapacity = ngramConfig.draftTokenMax; + final draftPtr = malloc(draftCapacity); + final idxPtr = malloc(draftCapacity + 1); + final acceptedPtr = malloc(draftCapacity + 1); + + int currentPos = startPos; + int? 
pendingSampledToken; + final accumulatedBytes = []; + final evalStopwatch = Stopwatch()..start(); + var sampleMicros = 0; + var evalMicros = 0; + var draftMicros = 0; + var verifyMicros = 0; + var generatedTokens = 0; + var speculativeDraftTokens = 0; + var speculativeAcceptedDraftTokens = 0; + var shouldStop = false; + + try { + while (!shouldStop && generatedTokens < params.maxTokens) { + if (cancelToken.value == 1) break; + if (currentPos >= nCtx) break; + + int selectedToken; + if (pendingSampledToken != null) { + selectedToken = pendingSampledToken; + pendingSampledToken = null; + } else { + final sampleTick = Stopwatch()..start(); + selectedToken = llama_sampler_sample(sampler, ctx.pointer, -1); + sampleTick.stop(); + sampleMicros += sampleTick.elapsedMicroseconds; + if (llama_vocab_is_eog(vocab, selectedToken)) break; + + final pieceTick = Stopwatch()..start(); + final n = llama_token_to_piece( + vocab, + selectedToken, + pieceBuf.cast(), + 256, + 0, + preservedTokenIds.contains(selectedToken), + ); + pieceTick.stop(); + sampleMicros += pieceTick.elapsedMicroseconds; + generatedTokens++; + + if (n > 0) { + final bytes = pieceBuf.asTypedList(n).toList(); + yield bytes; + if (stopSequences.isNotEmpty) { + accumulatedBytes.addAll(bytes); + if (accumulatedBytes.length > 64) { + accumulatedBytes.removeRange(0, accumulatedBytes.length - 64); + } + final text = utf8.decode(accumulatedBytes, allowMalformed: true); + if (stopSequences.any((s) => text.endsWith(s))) { + shouldStop = true; + } + } + } + + if (shouldStop) { + break; + } + } + + final remainingToGenerate = params.maxTokens - generatedTokens; + final batchCapacity = math.max(1, llama_n_batch(ctx.pointer)); + final rollbackCapacity = llama_n_rs_seq(ctx.pointer); + final contextDraftCapacity = rollbackCapacity > 0 + ? math.min(nCtx - currentPos - 2, rollbackCapacity) + : nCtx - currentPos - 2; + final draftLimit = remainingToGenerate <= 1 + ? 
0 + : math.min( + ngramConfig.draftTokenMax, + math.min( + math.min(remainingToGenerate - 1, contextDraftCapacity), + batchCapacity - 1, + ), + ); + + var draftCount = 0; + if (draftLimit > 0) { + final draftTick = Stopwatch()..start(); + draftCount = ngramApi.draft( + ngramSession, + 0, + currentPos, + selectedToken, + tokensPtr, + currentPos, + draftLimit, + draftPtr, + draftLimit, + ); + draftTick.stop(); + draftMicros += draftTick.elapsedMicroseconds; + if (draftCount < 0) { + throw Exception("llama.cpp ngram-simple draft failed"); + } + speculativeDraftTokens += draftCount; + } + + if (draftCount <= 0) { + batch.n_tokens = 1; + batch.token[0] = selectedToken; + batch.pos[0] = currentPos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 1; + + final evalTick = Stopwatch()..start(); + final decodeStatus = llama_decode(ctx.pointer, batch); + if (decodeStatus == 0 && + !ngramApi.processBatch(ngramSession, batch)) { + throw Exception("ngram-simple decode processing failed"); + } + evalTick.stop(); + evalMicros += evalTick.elapsedMicroseconds; + if (decodeStatus != 0) break; + + tokensPtr[currentPos] = selectedToken; + currentPos++; + continue; + } + + final batchTokens = draftCount + 1; + Pointer seqCheckpoint = nullptr; + var seqCheckpointSize = 0; + try { + llama_synchronize(ctx.pointer); + seqCheckpointSize = llama_state_seq_get_size_ext( + ctx.pointer, + 0, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY, + ); + if (seqCheckpointSize > 0) { + seqCheckpoint = malloc(seqCheckpointSize); + final written = llama_state_seq_get_data_ext( + ctx.pointer, + seqCheckpoint, + seqCheckpointSize, + 0, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY, + ); + if (written != seqCheckpointSize) { + throw Exception( + 'ngram-simple checkpoint capture failed ' + '(expected $seqCheckpointSize bytes, got $written)', + ); + } + } + + batch.n_tokens = batchTokens; + batch.token[0] = selectedToken; + batch.pos[0] = currentPos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + 
batch.logits[0] = 1; + for (int i = 0; i < draftCount; i++) { + final batchIndex = i + 1; + batch.token[batchIndex] = draftPtr[i]; + batch.pos[batchIndex] = currentPos + batchIndex; + batch.n_seq_id[batchIndex] = 1; + batch.seq_id[batchIndex][0] = 0; + batch.logits[batchIndex] = 1; + } + + final evalTick = Stopwatch()..start(); + final decodeStatus = llama_decode(ctx.pointer, batch); + if (decodeStatus == 0 && + !ngramApi.processBatch(ngramSession, batch)) { + throw Exception("ngram-simple decode processing failed"); + } + if (decodeStatus == 0) { + llama_synchronize(ctx.pointer); + } + evalTick.stop(); + evalMicros += evalTick.elapsedMicroseconds; + if (decodeStatus != 0) break; + + for (int i = 0; i < batchTokens; i++) { + idxPtr[i] = i; + } + + final verifyTick = Stopwatch()..start(); + final acceptedCount = ngramApi.sampleAndAcceptN( + sampler, + ctx.pointer, + idxPtr, + batchTokens, + draftPtr, + draftCount, + acceptedPtr, + batchTokens, + ); + verifyTick.stop(); + verifyMicros += verifyTick.elapsedMicroseconds; + if (acceptedCount <= 0) { + throw Exception("llama.cpp ngram-simple draft verification failed"); + } + + final acceptedDraftCount = acceptedCount - 1; + final rejectedTailCount = batchTokens - acceptedCount; + if (rejectedTailCount > 0) { + if (seqCheckpoint == nullptr) { + throw UnsupportedError( + 'llama.cpp ngram-simple checkpoint rollback is unavailable ' + 'for this context.', + ); + } + + final restored = llama_state_seq_set_data_ext( + ctx.pointer, + seqCheckpoint, + seqCheckpointSize, + 0, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY, + ); + if (restored != seqCheckpointSize) { + throw Exception( + 'ngram-simple checkpoint restore failed ' + '(expected $seqCheckpointSize bytes, got $restored)', + ); + } + + final targetMemory = llama_get_memory(ctx.pointer); + if (targetMemory == nullptr || + !llama_memory_seq_rm(targetMemory, 0, currentPos, -1)) { + throw UnsupportedError( + 'llama.cpp ngram-simple checkpoint rollback failed for this ' + 'context.', 
+ ); + } + + final replayTokens = 1 + acceptedDraftCount; + batch.n_tokens = replayTokens; + batch.token[0] = selectedToken; + batch.pos[0] = currentPos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 0; + for (int i = 0; i < acceptedDraftCount; i++) { + final batchIndex = i + 1; + batch.token[batchIndex] = acceptedPtr[i]; + batch.pos[batchIndex] = currentPos + batchIndex; + batch.n_seq_id[batchIndex] = 1; + batch.seq_id[batchIndex][0] = 0; + batch.logits[batchIndex] = 0; + } + + final replayTick = Stopwatch()..start(); + final replayStatus = llama_decode(ctx.pointer, batch); + if (replayStatus == 0 && + !ngramApi.processBatch(ngramSession, batch)) { + throw Exception("ngram-simple replay processing failed"); + } + if (replayStatus == 0) { + llama_synchronize(ctx.pointer); + } + replayTick.stop(); + evalMicros += replayTick.elapsedMicroseconds; + if (replayStatus != 0) { + throw Exception("ngram-simple replay decode failed"); + } + } + speculativeAcceptedDraftTokens += acceptedDraftCount; + ngramApi.accept(ngramSession, 0, acceptedDraftCount); + + final keepUntil = currentPos + 1 + acceptedDraftCount; + final targetMemory = llama_get_memory(ctx.pointer); + if (targetMemory == nullptr || + !llama_memory_seq_rm(targetMemory, 0, keepUntil, -1)) { + throw UnsupportedError( + 'llama.cpp ngram-simple target rollback failed for this context. 
' + 'Set ModelParams.speculativeRollbackTokenMax >= draftTokenMax if ' + 'this model/backend uses bounded rollback snapshots.', + ); + } + + tokensPtr[currentPos] = selectedToken; + for (int i = 0; i < acceptedDraftCount; i++) { + tokensPtr[currentPos + 1 + i] = acceptedPtr[i]; + } + currentPos = keepUntil; + + for (int i = 0; i < acceptedCount; i++) { + final token = acceptedPtr[i]; + if (llama_vocab_is_eog(vocab, token)) { + shouldStop = true; + break; + } + + final pieceTick = Stopwatch()..start(); + final n = llama_token_to_piece( + vocab, + token, + pieceBuf.cast(), + 256, + 0, + preservedTokenIds.contains(token), + ); + pieceTick.stop(); + sampleMicros += pieceTick.elapsedMicroseconds; + generatedTokens++; + + if (n > 0) { + final bytes = pieceBuf.asTypedList(n).toList(); + yield bytes; + if (stopSequences.isNotEmpty) { + accumulatedBytes.addAll(bytes); + if (accumulatedBytes.length > 64) { + accumulatedBytes.removeRange(0, accumulatedBytes.length - 64); + } + final text = utf8.decode( + accumulatedBytes, + allowMalformed: true, + ); + if (stopSequences.any((s) => text.endsWith(s))) { + shouldStop = true; + } + } + } + + if (shouldStop || generatedTokens >= params.maxTokens) { + break; + } + } + + if (!shouldStop && generatedTokens < params.maxTokens) { + pendingSampledToken = acceptedPtr[acceptedCount - 1]; + } + } finally { + if (seqCheckpoint != nullptr) { + malloc.free(seqCheckpoint); + } + } + } + } finally { + malloc.free(draftPtr); + malloc.free(idxPtr); + malloc.free(acceptedPtr); + evalStopwatch.stop(); + ctx.lastPerfEvalMs = evalMicros / 1000.0; + ctx.lastPerfSampleMs = sampleMicros / 1000.0; + ctx.lastPerfDecodeMs = evalMicros / 1000.0; + ctx.lastPerfSpeculativeDraftMs = draftMicros / 1000.0; + ctx.lastPerfSpeculativeVerifyMs = verifyMicros / 1000.0; + ctx.lastPerfEvalTokens = generatedTokens; + ctx.lastPerfSampleCount = generatedTokens; + ctx.lastPerfSpeculativeDraftTokens = speculativeDraftTokens; + ctx.lastPerfSpeculativeAcceptedDraftTokens = + 
speculativeAcceptedDraftTokens; + } + } + _LazyGrammarConfig? _buildLazyGrammarConfig(GenerationParams params) { final triggerPatterns = []; final triggerTokens = []; @@ -6383,6 +7018,132 @@ class _MtpApi { } } +class _NgramApi { + final _LlamaDartNgramSimpleInitDart initSimple; + final _LlamaDartNgramFreeDart free; + final _LlamaDartNgramBeginDart begin; + final _LlamaDartNgramProcessBatchDart processBatch; + final _LlamaDartNgramDraftDart draft; + final _LlamaDartNgramAcceptDart accept; + final _LlamaDartSamplerSampleAndAcceptNDart sampleAndAcceptN; + + const _NgramApi({ + required this.initSimple, + required this.free, + required this.begin, + required this.processBatch, + required this.draft, + required this.accept, + required this.sampleAndAcceptN, + }); + + factory _NgramApi.direct() { + final api = _NgramApi( + initSimple: llama_dart_ngram_simple_init, + free: llama_dart_ngram_free, + begin: llama_dart_ngram_begin, + processBatch: llama_dart_ngram_process_batch, + draft: llama_dart_ngram_draft, + accept: llama_dart_ngram_accept, + sampleAndAcceptN: llama_dart_sampler_sample_and_accept_n, + ); + api.probe(); + return api; + } + + factory _NgramApi.windowsAsset() { + _llamadartWrapperNgramFree(nullptr.cast()); + return _NgramApi( + initSimple: _llamadartWrapperNgramSimpleInit, + free: _llamadartWrapperNgramFree, + begin: _llamadartWrapperNgramBegin, + processBatch: _llamadartWrapperNgramProcessBatch, + draft: _llamadartWrapperNgramDraft, + accept: _llamadartWrapperNgramAccept, + sampleAndAcceptN: _llamadartWrapperSamplerSampleAndAcceptN, + ); + } + + void probe() { + final nullNgram = nullptr.cast(); + final nullSampler = nullptr.cast(); + final nullContext = nullptr.cast(); + final nullTokenArray = nullptr.cast(); + final batch = llama_batch_init(1, 0, 1); + Pointer probeSession = nullptr; + try { + probeSession = initSimple(12, 1); + if (probeSession != nullptr) { + free(probeSession); + probeSession = nullptr; + } + free(nullNgram); + begin(nullNgram, 0, 
nullTokenArray, 0); + processBatch(nullNgram, batch); + draft(nullNgram, 0, 0, 0, nullTokenArray, 0, 1, nullTokenArray, 0); + accept(nullNgram, 0, 0); + sampleAndAcceptN( + nullSampler, + nullContext, + nullTokenArray, + 0, + nullTokenArray, + 0, + nullTokenArray, + 0, + ); + } finally { + if (probeSession != nullptr) { + free(probeSession); + } + llama_batch_free(batch); + } + } + + static _NgramApi? tryLoad(DynamicLibrary library) { + try { + return _NgramApi( + initSimple: library + .lookupFunction< + _LlamaDartNgramSimpleInitNative, + _LlamaDartNgramSimpleInitDart + >('llama_dart_ngram_simple_init'), + free: library + .lookupFunction<_LlamaDartNgramFreeNative, _LlamaDartNgramFreeDart>( + 'llama_dart_ngram_free', + ), + begin: library + .lookupFunction< + _LlamaDartNgramBeginNative, + _LlamaDartNgramBeginDart + >('llama_dart_ngram_begin'), + processBatch: library + .lookupFunction< + _LlamaDartNgramProcessBatchNative, + _LlamaDartNgramProcessBatchDart + >('llama_dart_ngram_process_batch'), + draft: library + .lookupFunction< + _LlamaDartNgramDraftNative, + _LlamaDartNgramDraftDart + >('llama_dart_ngram_draft'), + accept: library + .lookupFunction< + _LlamaDartNgramAcceptNative, + _LlamaDartNgramAcceptDart + >('llama_dart_ngram_accept'), + sampleAndAcceptN: library + .lookupFunction< + _LlamaDartSamplerSampleAndAcceptNNative, + _LlamaDartSamplerSampleAndAcceptNDart + >('llama_dart_sampler_sample_and_accept_n'), + ); + } catch (_) { + return null; + } + } +} + class _MtmdApi { final _MtmdDefaultMarkerDart defaultMarker; final _MtmdContextParamsDefaultDart contextParamsDefault; diff --git a/lib/src/core/models/inference/generation_params.dart b/lib/src/core/models/inference/generation_params.dart index c08b7590..ad2a15d6 100644 --- a/lib/src/core/models/inference/generation_params.dart +++ b/lib/src/core/models/inference/generation_params.dart @@ -42,6 +42,12 @@ enum SpeculativeDecodingStrategy { /// llama.cpp maps this to its `draft-mtp` speculative path. 
LiteRT-LM native /// currently maps this to its runtime speculative decoding switch. mtp, + + /// Self-speculative n-gram pattern matching. + /// + /// llama.cpp maps this to its `ngram-simple` path. It uses token history + /// rather than a separate draft model. + ngramSimple, } /// Backend-neutral speculative decoding configuration. @@ -56,6 +62,10 @@ class SpeculativeDecodingConfig { /// Maximum number of draft tokens to propose per speculative step. /// /// `null` lets the backend choose its default. + /// + /// For llama.cpp ngram-simple, pass `draftTokenMax: 2` or lower. Deeper + /// ngram-simple drafts are not enabled yet because they can diverge from + /// non-speculative decoding on some model/backend combinations. final int? draftTokenMax; /// Minimum number of draft tokens required for speculative verification. @@ -75,6 +85,12 @@ class SpeculativeDecodingConfig { /// Leave null for models that carry their own MTP layers. final String? draftModelPath; + /// Lookup n-gram size for n-gram self-speculative decoding. + /// + /// `null` lets the backend choose its default. This is currently only used by + /// [SpeculativeDecodingStrategy.ngramSimple]. + final int? ngramSize; + /// Creates a backend-neutral speculative decoding configuration. const SpeculativeDecodingConfig({ this.strategy = SpeculativeDecodingStrategy.backendDefault, @@ -82,8 +98,10 @@ class SpeculativeDecodingConfig { this.draftTokenMin, this.minProbability, this.draftModelPath, + this.ngramSize, }) : assert(draftTokenMax == null || draftTokenMax >= 0), assert(draftTokenMin == null || draftTokenMin >= 0), + assert(ngramSize == null || ngramSize > 0), assert( minProbability == null || (minProbability >= 0.0 && minProbability <= 1.0), @@ -95,7 +113,8 @@ class SpeculativeDecodingConfig { draftTokenMax = null, draftTokenMin = null, minProbability = null, - draftModelPath = null; + draftModelPath = null, + ngramSize = null; /// Enables multi-token prediction speculative decoding. 
const SpeculativeDecodingConfig.mtp({ @@ -104,12 +123,29 @@ class SpeculativeDecodingConfig { this.minProbability, this.draftModelPath, }) : strategy = SpeculativeDecodingStrategy.mtp, + ngramSize = null, assert(draftTokenMax == null || draftTokenMax >= 0), assert(draftTokenMin == null || draftTokenMin >= 0), assert( minProbability == null || (minProbability >= 0.0 && minProbability <= 1.0), ); + + /// Enables llama.cpp ngram-simple speculative decoding. + /// + /// Ngram-simple uses previous tokens as its draft source. In llama.cpp, + /// pass `draftTokenMax: 2` or lower. Leaving `draftTokenMax` null allows the + /// backend default, which can be above 2 and may be rejected by a backend + /// that requires deterministic parity with non-speculative decoding. + const SpeculativeDecodingConfig.ngramSimple({ + this.draftTokenMax, + this.ngramSize, + }) : strategy = SpeculativeDecodingStrategy.ngramSimple, + draftTokenMin = null, + minProbability = null, + draftModelPath = null, + assert(draftTokenMax == null || draftTokenMax >= 0), + assert(ngramSize == null || ngramSize > 0); } /// Parameters controlling the token sampling and generation process. 
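Review note on the `ngramSimple` constructor above: the limits in its doc comment could be checked before a request reaches the backend. This is a minimal sketch under stated assumptions, not the enforcement code in this change. `checkNgramSimpleLimits` and its parameter names are hypothetical. The `draftTokenMax <= 2` ceiling and the positive `ngramSize` rule come from the doc comments and asserts in this diff, and the rollback rule mirrors the README guidance about `ModelParams.speculativeRollbackTokenMax`.

```dart
/// Hypothetical helper, not part of llamadart: lists the problems found in an
/// ngram-simple request, or returns an empty list when none are found.
List<String> checkNgramSimpleLimits({
  int? draftTokenMax,
  int? ngramSize,
  required int rollbackTokenMax,
}) {
  // Ceiling taken from the doc comments in this change.
  const supportedDraftTokenMax = 2;
  final problems = <String>[];

  if (draftTokenMax == null) {
    // Null defers to the backend default, which may be above the ceiling.
    problems.add('draftTokenMax is null; set it explicitly to 2 or lower');
  } else if (draftTokenMax < 0 || draftTokenMax > supportedDraftTokenMax) {
    problems.add('draftTokenMax must be between 0 and $supportedDraftTokenMax');
  } else if (rollbackTokenMax < draftTokenMax) {
    // Assumed rule: rollback capacity has to cover every drafted token.
    problems.add('rollbackTokenMax must be at least draftTokenMax');
  }

  if (ngramSize != null && ngramSize <= 0) {
    problems.add('ngramSize must be positive when set');
  }
  return problems;
}

void main() {
  // Within the documented limits.
  print(checkNgramSimpleLimits(
    draftTokenMax: 2,
    ngramSize: 12,
    rollbackTokenMax: 2,
  ));
  // Deeper than the documented ceiling.
  print(checkNgramSimpleLimits(
    draftTokenMax: 4,
    ngramSize: 12,
    rollbackTokenMax: 4,
  ));
}
```

Returning a list instead of throwing keeps the sketch close to how the LiteRT-LM service in this diff collects `unsupported` field names before reporting them together.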
diff --git a/test/integration/backends/llama_cpp/native_symbol_integration_test.dart b/test/integration/backends/llama_cpp/native_symbol_integration_test.dart index e3d8ccc9..da9c871a 100644 --- a/test/integration/backends/llama_cpp/native_symbol_integration_test.dart +++ b/test/integration/backends/llama_cpp/native_symbol_integration_test.dart @@ -10,6 +10,7 @@ import 'package:llamadart/src/backends/llama_cpp/bindings.dart'; import 'package:test/test.dart'; const _llamadartWrapperAssetId = 'package:llamadart/llamadart_wrapper'; +const _llamadartPrimaryAssetId = 'package:llamadart/llamadart'; const _mtpSymbols = [ 'llama_dart_mtp_init', @@ -23,6 +24,15 @@ const _mtpSymbols = [ 'llama_dart_sampler_sample_and_accept_n', ]; +const _ngramSymbols = [ + 'llama_dart_ngram_simple_init', + 'llama_dart_ngram_free', + 'llama_dart_ngram_begin', + 'llama_dart_ngram_process_batch', + 'llama_dart_ngram_draft', + 'llama_dart_ngram_accept', +]; + const _transparentPngBytes = [ 0x89, 0x50, @@ -116,6 +126,25 @@ typedef _MtmdBitmapFreeDart = void Function(ffi.Pointer); ) external void _windowsMtpFree(ffi.Pointer mtp); +File? _llamadartWrapperLibraryFileOrNull() { + if (Platform.isWindows) { + try { + return _windowsMtpWrapperLibraryFile(); + } catch (_) { + return null; + } + } + + final nativeAssetPath = + _nativeAssetFilePath(_llamadartWrapperAssetId) ?? + _nativeAssetFilePath(_llamadartPrimaryAssetId); + if (nativeAssetPath == null) { + return null; + } + final file = File(nativeAssetPath); + return file.existsSync() ? file : null; +} + File _windowsMtpWrapperLibraryFile() { final dartToolLibPath = [ Directory.current.path, @@ -125,6 +154,7 @@ File _windowsMtpWrapperLibraryFile() { final dartToolLibDir = Directory(dartToolLibPath); final candidates = [ ?_nativeAssetFilePath(_llamadartWrapperAssetId), + ?_nativeAssetFilePath(_llamadartPrimaryAssetId), ..._matchingWindowsLibraryPaths( dartToolLibDir, RegExp(r'^llamadart(?:[-_][^.\\/]+)*\.dll$'), @@ -165,7 +195,9 @@ File? 
_mtmdFallbackLibraryFile() { 'lib', ].join(Platform.pathSeparator); final directories = []; - final nativeAssetPath = _nativeAssetFilePath(_llamadartWrapperAssetId); + final nativeAssetPath = + _nativeAssetFilePath(_llamadartWrapperAssetId) ?? + _nativeAssetFilePath(_llamadartPrimaryAssetId); if (nativeAssetPath != null) { directories.add(File(nativeAssetPath).parent); } @@ -327,12 +359,12 @@ void _expectBitmapHelperDecodesTransparentPng( void main() { group('Native Symbol Availability', () { - test('Verify MTP symbols are declared in generated bindings', () { + test('Verify speculative symbols are declared in generated bindings', () { final bindingsSource = File( 'lib/src/backends/llama_cpp/bindings.dart', ).readAsStringSync(); - for (final symbol in _mtpSymbols) { + for (final symbol in [..._mtpSymbols, ..._ngramSymbols]) { expect( bindingsSource, matches(RegExp(r'external\s+[\s\S]*?\b' + RegExp.escape(symbol))), @@ -426,6 +458,78 @@ void main() { } }); + test('Verify ngram wrapper symbols are resolvable when exported', () { + final wrapper = _llamadartWrapperLibraryFileOrNull(); + if (wrapper == null) { + markTestSkipped('Unable to locate the llama.cpp wrapper library.'); + return; + } + + final missing = _ngramSymbols + .where((symbol) => !_fileContainsAscii(wrapper, symbol)) + .toList(growable: false); + if (missing.isNotEmpty) { + markTestSkipped( + 'Current native bundle does not export ngram wrapper symbols: ' + '${missing.join(', ')}.', + ); + return; + } + + final nullNgram = ffi.nullptr.cast(); + final nullTokenArray = ffi.nullptr.cast(); + final session = llama_dart_ngram_simple_init(1, 1); + expect(session.address, isNot(0)); + + try { + expect(() => llama_dart_ngram_free(nullNgram), returnsNormally); + expect( + llama_dart_ngram_begin(nullNgram, 0, nullTokenArray, 0), + isFalse, + ); + expect(llama_dart_ngram_begin(session, 1, nullTokenArray, 0), isFalse); + expect( + llama_dart_ngram_draft( + nullNgram, + 0, + 0, + 0, + nullTokenArray, + 0, + 1, + 
nullTokenArray, + 0, + ), + -1, + ); + expect( + llama_dart_ngram_draft( + session, + 1, + 0, + 0, + nullTokenArray, + 0, + 1, + nullTokenArray, + 0, + ), + -1, + ); + expect(() => llama_dart_ngram_accept(nullNgram, 0, 0), returnsNormally); + expect(() => llama_dart_ngram_accept(session, 0, 0), returnsNormally); + + final batch = llama_batch_init(1, 0, 1); + try { + expect(llama_dart_ngram_process_batch(nullNgram, batch), isFalse); + } finally { + llama_batch_free(batch); + } + } finally { + llama_dart_ngram_free(session); + } + }); + test('Verify multimodal symbols are resolvable', () { // Some bundles export mtmd via the primary llama asset while others ship // it as a dedicated mtmd shared library loaded via runtime fallback. diff --git a/test/unit/backends/litert_lm/litert_lm_service_test.dart b/test/unit/backends/litert_lm/litert_lm_service_test.dart index a2e12a2f..65000bfc 100644 --- a/test/unit/backends/litert_lm/litert_lm_service_test.dart +++ b/test/unit/backends/litert_lm/litert_lm_service_test.dart @@ -2157,6 +2157,30 @@ void main() { ), ), ); + + await expectLater( + service.generate( + contextHandle, + 'hello', + const GenerationParams( + speculativeDecodingConfig: SpeculativeDecodingConfig.ngramSimple( + draftTokenMax: 4, + ngramSize: 12, + ), + ), + ), + emitsError( + isA().having( + (error) => error.message.toString(), + 'message', + allOf( + contains('speculativeDecodingConfig.strategy'), + contains('speculativeDecodingConfig.draftTokenMax'), + contains('speculativeDecodingConfig.ngramSize'), + ), + ), + ), + ); } finally { service.dispose(); } diff --git a/test/unit/core/models/inference/generation_params_test.dart b/test/unit/core/models/inference/generation_params_test.dart index 2cde100c..5ef9ebb7 100644 --- a/test/unit/core/models/inference/generation_params_test.dart +++ b/test/unit/core/models/inference/generation_params_test.dart @@ -75,6 +75,25 @@ void main() { ); }); + test('SpeculativeDecodingConfig.ngramSimple stores n-gram controls', 
() { + const config = SpeculativeDecodingConfig.ngramSimple( + draftTokenMax: 6, + ngramSize: 10, + ); + const params = GenerationParams(speculativeDecodingConfig: config); + + expect(params.isSpeculativeDecodingEnabled, isTrue); + expect( + params.resolvedSpeculativeDecodingConfig?.strategy, + SpeculativeDecodingStrategy.ngramSimple, + ); + expect(params.resolvedSpeculativeDecodingConfig?.draftTokenMax, 6); + expect(params.resolvedSpeculativeDecodingConfig?.ngramSize, 10); + expect(params.resolvedSpeculativeDecodingConfig?.draftModelPath, isNull); + expect(params.resolvedSpeculativeDecodingConfig?.draftTokenMin, isNull); + expect(params.resolvedSpeculativeDecodingConfig?.minProbability, isNull); + }); + test('GenerationParams copyWith can clear speculative decoding config', () { const params = GenerationParams( speculativeDecodingConfig: SpeculativeDecodingConfig.mtp( diff --git a/test/unit/tooling/test_matrix_test.dart b/test/unit/tooling/test_matrix_test.dart index 097bfe0e..58086e71 100644 --- a/test/unit/tooling/test_matrix_test.dart +++ b/test/unit/tooling/test_matrix_test.dart @@ -28,6 +28,7 @@ void main() { expect(ids, contains('gguf-chat-features-smoke')); expect(ids, contains('litert-lm-chat-features-smoke')); expect(ids, contains('native-inference-benchmark')); + expect(ids, contains('llama-cpp-speculative-benchmark')); expect(ids, contains('web-mock-chat-smoke')); expect(ids, contains('web-real-model-smoke')); expect(ids, contains('webgpu-multimodal-regression')); @@ -72,6 +73,9 @@ void main() { expect(table, contains('| ID | Tier | Mode |')); expect(table, contains('gguf-chat-features-smoke')); + expect(table, contains('LLAMADART_MTP_BENCHMARK_NGRAM_ONLY=true')); + expect(table, contains('llama_cpp_mtp_benchmark.dart ')); + expect(table, contains('1,2 1')); expect(table, contains('run_local_e2e.dart')); expect(table, isNot(contains('static-format-analyze'))); }); diff --git a/tool/testing/llama_cpp_mtp_benchmark.dart 
b/tool/testing/llama_cpp_mtp_benchmark.dart index 747d3325..0001ef2a 100644 --- a/tool/testing/llama_cpp_mtp_benchmark.dart +++ b/tool/testing/llama_cpp_mtp_benchmark.dart @@ -12,7 +12,11 @@ Future main(List args) async { '[draft-token-max-list] [warmups]\n' 'Set LLAMADART_MTP_BENCHMARK_INSTRUCTION to override the prompt.\n' 'Set LLAMADART_MTP_BENCHMARK_BACKEND to override the backend.\n' - 'Set LLAMADART_MTP_BENCHMARK_RAW_PROMPT=true to skip chat wrapping.', + 'Set LLAMADART_MTP_BENCHMARK_RAW_PROMPT=true to skip chat wrapping.\n' + 'Set LLAMADART_MTP_BENCHMARK_NGRAM=true to include ngram-simple cases.\n' + 'Set LLAMADART_MTP_BENCHMARK_NGRAM_ONLY=true to omit MTP cases.\n' + 'Set LLAMADART_MTP_BENCHMARK_NGRAM_SIZE to override ngram-simple size.\n' + 'Set LLAMADART_MTP_BENCHMARK_PENALTY to override the repeat penalty.', ); exitCode = 64; return; @@ -35,11 +39,26 @@ Future main(List args) async { ); final rawPrompt = Platform.environment['LLAMADART_MTP_BENCHMARK_RAW_PROMPT'] == 'true'; + final ngramOnly = + Platform.environment['LLAMADART_MTP_BENCHMARK_NGRAM_ONLY'] == 'true'; + final includeNgramSimple = + ngramOnly || + Platform.environment['LLAMADART_MTP_BENCHMARK_NGRAM'] == 'true'; + final ngramSize = + int.tryParse( + Platform.environment['LLAMADART_MTP_BENCHMARK_NGRAM_SIZE'] ?? '', + ) ?? + 12; final maxDraftTokenMax = draftTokenMaxValues.fold( 1, (max, value) => value > max ? value : max, ); + final generationPenalty = + double.tryParse( + Platform.environment['LLAMADART_MTP_BENCHMARK_PENALTY'] ?? '', + ) ?? 
+ 1.1; final baselineModelParams = ModelParams( contextSize: 2048, preferredBackend: preferredBackend, @@ -65,11 +84,18 @@ Future main(List args) async { final benchmarkCases = <_BenchmarkCase>[ const _BenchmarkCase.baseline(), - for (final draftTokenMax in draftTokenMaxValues) - _BenchmarkCase.mtp( - draftModelPath: draftModelPath, - draftTokenMax: draftTokenMax, - ), + if (!ngramOnly) + for (final draftTokenMax in draftTokenMaxValues) + _BenchmarkCase.mtp( + draftModelPath: draftModelPath, + draftTokenMax: draftTokenMax, + ), + if (includeNgramSimple) + for (final draftTokenMax in draftTokenMaxValues) + _BenchmarkCase.ngramSimple( + draftTokenMax: draftTokenMax, + ngramSize: ngramSize, + ), ]; final results = <_RunResult>[]; @@ -84,6 +110,7 @@ Future main(List args) async { : baselineModelParams, prompt: prompt, maxTokens: maxTokens, + generationPenalty: generationPenalty, benchmarkCase: benchmarkCase, runIndex: i, warmup: true, @@ -102,6 +129,7 @@ Future main(List args) async { : baselineModelParams, prompt: prompt, maxTokens: maxTokens, + generationPenalty: generationPenalty, benchmarkCase: benchmarkCase, runIndex: i, warmup: false, @@ -121,6 +149,10 @@ Future main(List args) async { 'backend': backendName, 'model': modelPath, 'draftModel': draftModelPath, + 'includeNgramSimple': includeNgramSimple, + 'ngramOnly': ngramOnly, + 'ngramSize': includeNgramSimple ? 
ngramSize : null, + 'penalty': generationPenalty, 'maxTokens': maxTokens, 'measuredRuns': measuredRuns, 'warmupRuns': warmupRuns, @@ -191,6 +223,7 @@ Future<_RunResult> _runCase({ required ModelParams modelParams, required String prompt, required int maxTokens, + required double generationPenalty, required _BenchmarkCase benchmarkCase, required int runIndex, required bool warmup, @@ -210,6 +243,7 @@ Future<_RunResult> _runCase({ GenerationParams( maxTokens: maxTokens, temp: 0.0, + penalty: generationPenalty, seed: 7, reusePromptPrefix: false, speculativeDecodingConfig: benchmarkCase.speculativeDecodingConfig, @@ -370,6 +404,19 @@ class _BenchmarkCase { ); } + factory _BenchmarkCase.ngramSimple({ + required int draftTokenMax, + required int ngramSize, + }) { + return _BenchmarkCase._( + 'ngram_simple_n${ngramSize}_draft_$draftTokenMax', + SpeculativeDecodingConfig.ngramSimple( + draftTokenMax: draftTokenMax, + ngramSize: ngramSize, + ), + ); + } + final String name; final SpeculativeDecodingConfig? 
speculativeDecodingConfig; diff --git a/tool/testing/test_matrix.dart b/tool/testing/test_matrix.dart index 3e84a40e..976c390d 100644 --- a/tool/testing/test_matrix.dart +++ b/tool/testing/test_matrix.dart @@ -136,6 +136,23 @@ const List testMatrixRows = [ useWhen: 'Generation latency, streaming, batching, prompt reuse, or performance changes.', ), + TestMatrixRow( + id: 'llama-cpp-speculative-benchmark', + tier: 'targeted', + mode: 'local-only', + covers: + 'real GGUF llama.cpp speculative decoding baseline, MTP, and optional ' + 'ngram-simple throughput/acceptance metrics', + command: + 'LLAMADART_MTP_BENCHMARK_NGRAM=true ' + 'LLAMADART_MTP_BENCHMARK_NGRAM_ONLY=true ' + 'LLAMADART_MTP_BENCHMARK_NGRAM_SIZE=1 dart run ' + 'tool/testing/llama_cpp_mtp_benchmark.dart - 128 3 ' + '1,2 1', + useWhen: + 'llama.cpp speculative decoding strategy, wrapper, rollback, or ' + 'performance changes.', + ), TestMatrixRow( id: 'gguf-chat-features-smoke', tier: 'targeted', diff --git a/website/docs/changelog/recent-releases.md b/website/docs/changelog/recent-releases.md index 08f945e3..831225fa 100644 --- a/website/docs/changelog/recent-releases.md +++ b/website/docs/changelog/recent-releases.md @@ -9,6 +9,10 @@ For canonical full release notes, use: ## Unreleased +- Added llama.cpp ngram-simple speculative decoding through + `SpeculativeDecodingConfig.ngramSimple(...)`, including Dart routing, native + wrapper bindings, docs, and local benchmark matrix coverage. + - Added `LlamaStructuredOutput` and `LlamaEngine.createStructuredJson(...)` helpers for strict JSON-object / JSON-schema generation with final-output validation and typed decoding. diff --git a/website/docs/configuration/runtime-parameters.md b/website/docs/configuration/runtime-parameters.md index 7fd428ab..920aca41 100644 --- a/website/docs/configuration/runtime-parameters.md +++ b/website/docs/configuration/runtime-parameters.md @@ -109,8 +109,14 @@ Important fields: speculative decoding. 
Native LiteRT-LM honors the legacy boolean flag.
   llama.cpp supports `SpeculativeDecodingConfig.mtp(...)` for compatible MTP
   GGUF models, including separate target/draft model pairs through
-  `draftModelPath`. WebGPU and LiteRT-LM web reject speculative decoding until
-  their speculative paths are implemented.
+  `draftModelPath`. Native llama.cpp bundles that export the
+  `llama_dart_ngram_*` wrapper symbols also support
+  `SpeculativeDecodingConfig.ngramSimple(...)` for token-history n-gram
+  speculation without a draft model. llamadart currently supports
+  ngram-simple `draftTokenMax <= 2`; deeper drafts can diverge from
+  non-speculative decoding on some model/backend combinations. WebGPU and
+  LiteRT-LM web reject speculative decoding until their speculative paths are
+  implemented.
 - `seed`: deterministic replay when set.
 - `grammar`: constrained decoding with GBNF.
diff --git a/website/docs/guides/performance-tuning.md b/website/docs/guides/performance-tuning.md
index fe737fb5..d8e57bc2 100644
--- a/website/docs/guides/performance-tuning.md
+++ b/website/docs/guides/performance-tuning.md
@@ -114,7 +114,12 @@ Guidelines:
   default remains off because it is not a universal speedup. Native LiteRT-LM
   uses the legacy `speculativeDecoding` boolean, while llama.cpp MTP uses
   `SpeculativeDecodingConfig.mtp(...)` and can optionally load a separate draft
-  GGUF through `draftModelPath`.
+  GGUF through `draftModelPath`. llama.cpp ngram-simple uses
+  `SpeculativeDecodingConfig.ngramSimple(...)` for draft-model-free
+  token-history speculation when the native bundle exports the required
+  `llama_dart_ngram_*` symbols. llamadart currently supports ngram-simple
+  `draftTokenMax <= 2`; deeper drafts can diverge from non-speculative decoding
+  on some model/backend combinations.
 - `reusePromptPrefix` is enabled by default for native generation; keep it on
   for multi-turn chats and repeated prompts, and validate parity for your target
   model/workload.
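Review note on "token-history n-gram speculation": the general idea can be sketched in a few lines. This is an illustration under stated assumptions, not llama.cpp's `ngram-simple` implementation. `proposeNgramDraft` is a hypothetical name, and its `ngramSize` and `draftTokenMax` parameters only mirror the config fields in this change. The sketch takes the last `ngramSize` tokens, looks for the most recent earlier occurrence of that sequence, and proposes the tokens that followed it as the draft. Verification of the draft by the target model is not shown.

```dart
import 'dart:math' show min;

/// Illustrative only: proposes up to [draftTokenMax] tokens by matching the
/// trailing [ngramSize] tokens of [history] against earlier history.
List<int> proposeNgramDraft(
  List<int> history, {
  required int ngramSize,
  required int draftTokenMax,
}) {
  if (ngramSize <= 0 || draftTokenMax <= 0 || history.length <= ngramSize) {
    return const <int>[];
  }
  final keyStart = history.length - ngramSize;

  // Scan backwards so the most recent earlier match wins.
  for (var start = keyStart - 1; start >= 0; start--) {
    var matches = true;
    for (var i = 0; i < ngramSize; i++) {
      if (history[start + i] != history[keyStart + i]) {
        matches = false;
        break;
      }
    }
    if (!matches) continue;

    // The tokens that followed the earlier occurrence become the draft.
    final draftStart = start + ngramSize;
    final draftEnd = min(draftStart + draftTokenMax, history.length);
    return history.sublist(draftStart, draftEnd);
  }
  // No repetition found, so there is nothing to draft.
  return const <int>[];
}

void main() {
  // The trailing pair [1, 2] appeared earlier, followed by 3 and 4.
  print(proposeNgramDraft([1, 2, 3, 4, 1, 2], ngramSize: 2, draftTokenMax: 2));
  // prints [3, 4]

  // No repeated pair, so the draft is empty.
  print(proposeNgramDraft([1, 2, 3, 4, 5], ngramSize: 2, draftTokenMax: 2));
  // prints []
}
```

An empty draft means the step falls back to ordinary decoding, which is consistent with the documentation's advice to benchmark ngram-simple against the target model and prompt shape, because prompts with little repetition yield few usable drafts.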