diff --git a/CHANGELOG.md b/CHANGELOG.md index a7776a4d..0b2332e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ options as `loadModelSource(...)`, while preserving the existing `loadMultimodalProjector(...)` path/string API. +* Improved the runnable chat app's Manage Models cache UX so model and mmproj + asset cache states are shown separately, missing multimodal projectors can be + re-cached without re-fetching already cached model assets, and runtime media + capability mismatches surface as user-readable warnings. Custom signed or + tokenized Hugging Face URLs now require confirmation before they are saved. + ## 0.8.12 * Updated the default LiteRT-LM native runtime pin to diff --git a/example/chat_app/lib/screens/manage_models_screen.dart b/example/chat_app/lib/screens/manage_models_screen.dart index bc295669..1d363d55 100644 --- a/example/chat_app/lib/screens/manage_models_screen.dart +++ b/example/chat_app/lib/screens/manage_models_screen.dart @@ -54,6 +54,7 @@ class _ManageModelsScreenState extends State final Map> _downloadUiStateByFile = {}; + final Map _cacheStateByFile = {}; final Map _lastDownloadedBytes = {}; final Map _lastDownloadSampleAt = {}; final Map _smoothedDownloadRateBytesPerSec = {}; @@ -97,11 +98,19 @@ class _ManageModelsScreenState extends State await _loadCustomModels(); _modelsDir = await _modelService.getModelsDirectory(); _downloadedFiles = await _modelService.getDownloadedModels(_models); + await _refreshCacheStates(_models); if (mounted) { setState(() {}); } } + Future _refreshCacheStates(Iterable models) async { + for (final model in models) { + _cacheStateByFile[model.filename] = await _modelService + .getModelCacheState(model); + } + } + Future _loadCustomModels() async { final prefs = await SharedPreferences.getInstance(); final entries = prefs.getStringList(_customModelsPrefsKey) ?? const []; @@ -233,6 +242,7 @@ class _ManageModelsScreenState extends State }); _downloadedFiles = await _modelService.getDownloadedModels(_models); + await _refreshCacheStates([model]); if (mounted) { setState(() {}); } @@ -255,6 +265,81 @@ class _ManageModelsScreenState extends State return null; } + List _credentialLikeCustomUrlLabels({ + required String modelUrl, + required String? mmprojUrl, + }) { + return [ + if (_hasCredentialLikePersistentUrlParts(modelUrl)) 'GGUF URL', + if (mmprojUrl != null && + mmprojUrl.isNotEmpty && + _hasCredentialLikePersistentUrlParts(mmprojUrl)) + 'MMProj URL', + ]; + } + + bool _hasCredentialLikePersistentUrlParts(String value) { + final uri = Uri.tryParse(value.trim()); + if (uri == null) { + return false; + } + if (uri.userInfo.isNotEmpty || uri.fragment.isNotEmpty) { + return true; + } + + const benignQueryKeys = {'download'}; + return uri.queryParameters.keys.any((key) { + final lower = key.toLowerCase(); + if (benignQueryKeys.contains(lower)) { + return false; + } + return lower.contains('token') || + lower.contains('sig') || + lower.contains('signature') || + lower.contains('expires') || + lower.contains('credential') || + lower.contains('key') || + lower.contains('secret') || + lower.contains('auth') || + lower.contains('session') || + lower.startsWith('x-amz'); + }); + } + + Future _confirmSavingCredentialLikeCustomUrls( + BuildContext context, + List labels, + ) async { + final labelText = labels.length == 1 + ? labels.single + : '${labels.take(labels.length - 1).join(', ')} and ${labels.last}'; + final confirmed = await showDialog( + context: context, + builder: (context) { + return AlertDialog( + title: const Text('Save credentialed URL?'), + content: Text( + '$labelText includes user info, a fragment, or credential-like ' + 'query parameters. Custom model URLs are saved in local ' + 'preferences. Prefer a public ?download=true URL or runtime ' + 'headers for private access.', + ), + actions: [ + TextButton( + onPressed: () => Navigator.of(context).pop(false), + child: const Text('Review URL'), + ), + FilledButton( + onPressed: () => Navigator.of(context).pop(true), + child: const Text('Save anyway'), + ), + ], + ); + }, + ); + return confirmed ?? false; + } + Future _showAddHuggingFaceDialog() async { final nameController = TextEditingController(); final modelUrlController = TextEditingController(); @@ -376,6 +461,22 @@ class _ManageModelsScreenState extends State return; } + final credentialUrlLabels = _credentialLikeCustomUrlLabels( + modelUrl: url, + mmprojUrl: mmprojUrl.isEmpty ? null : mmprojUrl, + ); + if (credentialUrlLabels.isNotEmpty) { + if (!dialogContext.mounted) return; + final confirmed = + await _confirmSavingCredentialLikeCustomUrls( + dialogContext, + credentialUrlLabels, + ); + if (!confirmed) { + return; + } + } + await _addCustomModelEntry(customModel); if (!dialogContext.mounted) return; Navigator.of(dialogContext).pop(); @@ -739,8 +840,13 @@ class _ManageModelsScreenState extends State clearTask: true, ); _clearDownloadTracking(model.filename); + final downloadedFiles = await _modelService.getDownloadedModels(_models); + await _refreshCacheStates(_models); + if (!mounted) { + return; + } setState(() { - _downloadedFiles.add(model.filename); + _downloadedFiles = downloadedFiles; }); ScaffoldMessenger.of(context).showSnackBar( SnackBar(content: Text('${model.name} downloaded successfully.')), @@ -764,6 +870,15 @@ class _ManageModelsScreenState extends State _clearDownloadTracking(model.filename); } + final downloadedFiles = await _modelService.getDownloadedModels(_models); + await _refreshCacheStates(_models); + if (!mounted) { + return; + } + setState(() { + _downloadedFiles = downloadedFiles; + }); + ScaffoldMessenger.of(context).showSnackBar( SnackBar( content: Text( @@ -837,8 +952,13 @@ class _ManageModelsScreenState extends State }); if (provider.error == null) { + final capabilityWarning = _runtimeCapabilityWarning(model, provider); ScaffoldMessenger.of(context).showSnackBar( - SnackBar(content: Text('${model.name} loaded successfully.')), + SnackBar( + content: Text( + capabilityWarning ?? '${model.name} loaded successfully.', + ), + ), ); widget.onModelActivated?.call(); } @@ -862,9 +982,12 @@ class _ManageModelsScreenState extends State ); _clearDownloadTracking(model.filename); await _disposeDownloadController(model.filename); + final downloadedFiles = await _modelService.getDownloadedModels(_models); + await _refreshCacheStates(_models); + if (!mounted) return; setState(() { - _downloadedFiles.remove(model.filename); + _downloadedFiles = downloadedFiles; }); } @@ -932,6 +1055,8 @@ class _ManageModelsScreenState extends State } _downloadControllers.clear(); _downloadedFiles = await _modelService.getDownloadedModels(_models); + _cacheStateByFile.clear(); + await _refreshCacheStates(_models); await _saveCustomModels(); @@ -1060,6 +1185,24 @@ class _ManageModelsScreenState extends State : null; } + String? _runtimeCapabilityWarning( + DownloadableModel model, + ChatProvider provider, + ) { + final missing = [ + if (model.supportsVision && !provider.supportsVision) 'vision', + if (model.supportsAudio && !provider.supportsAudio) 'audio', + ]; + if (missing.isEmpty) { + return null; + } + + final capabilityLabel = missing.length == 1 + ? missing.single + : '${missing.take(missing.length - 1).join(', ')} and ${missing.last}'; + return 'Loaded, but the active runtime/projector did not report $capabilityLabel support. Media controls stay disabled for unsupported inputs.'; + } + Future _setSelectedModelMmprojMode( ChatProvider provider, DownloadableModel model, { @@ -1339,6 +1482,7 @@ class _ManageModelsScreenState extends State isDownloaded: _downloadedFiles.contains( model.filename, ), + cacheState: _cacheStateByFile[model.filename], isDownloading: downloadState.isDownloading, progress: downloadState.progress, downloadStatusLabel: detail == null diff --git a/example/chat_app/lib/services/model_service_base.dart b/example/chat_app/lib/services/model_service_base.dart index 03a00862..6f83271d 100644 --- a/example/chat_app/lib/services/model_service_base.dart +++ b/example/chat_app/lib/services/model_service_base.dart @@ -154,6 +154,147 @@ abstract class WebCachePrefetchModelService { Future supportsWebCachePrefetch(); } +class ModelAssetCacheState { + final ModelAssetRole role; + final String label; + final bool isAvailable; + + const ModelAssetCacheState({ + required this.role, + required this.label, + required this.isAvailable, + }); +} + +class ModelProfileCacheState { + final ModelAssetCacheState model; + final ModelAssetCacheState? multimodalProjector; + + const ModelProfileCacheState({required this.model, this.multimodalProjector}); + + bool get isReady => model.isAvailable && _projectorReady; + + bool get hasPartialAssets => + !isReady && + (model.isAvailable || (multimodalProjector?.isAvailable ?? false)); + + List get availableAssetLabels => [ + if (model.isAvailable) 'model', + if (multimodalProjector?.isAvailable ?? false) 'mmproj', + ]; + + List get missingAssetLabels => [ + if (!model.isAvailable) 'model', + if (multimodalProjector != null && !multimodalProjector!.isAvailable) + 'mmproj', + ]; + + bool get _projectorReady => + multimodalProjector == null || multimodalProjector!.isAvailable; +} + +/// Tracks persisted cache markers by asset cache key. +class ModelAssetCacheMarkers { + final Set _cacheKeys; + + ModelAssetCacheMarkers(Iterable cacheKeys) + : _cacheKeys = cacheKeys.toSet(); + + Set toSet() => _cacheKeys.toSet(); + + bool containsAsset(ModelAssetSource source) { + return source is RemoteModelAssetSource && + _cacheKeys.contains(source.cacheKey); + } + + bool containsMarker(String value) => _cacheKeys.contains(value); + + void markAssetCached(RemoteModelAssetSource source) { + _cacheKeys.add(source.cacheKey); + } + + void markAssetsCached(Iterable sources) { + for (final source in sources) { + markAssetCached(source); + } + } + + void removeMarker(String value) { + _cacheKeys.remove(value); + } + + void removeAssets(Iterable sources) { + for (final source in sources) { + _cacheKeys.remove(source.cacheKey); + } + } + + bool isProfileCached(DownloadableModel model, {required bool web}) { + final sources = _remoteSourcesFor(model, web: web); + if (sources.length != _assetSourcesFor(model, web: web).length) { + return false; + } + return sources.every(containsAsset); + } + + bool migrateLegacyProfileMarker( + DownloadableModel model, { + required bool web, + }) { + final sources = _remoteSourcesFor(model, web: web); + if (!containsMarker(model.filename) || + sources.length != _assetSourcesFor(model, web: web).length) { + return false; + } + removeMarker(model.filename); + markAssetsCached(sources); + return true; + } + + ModelProfileCacheState modelCacheState( + DownloadableModel model, { + required bool web, + }) { + final modelSource = model.modelSourceFor(web: web); + final mmprojSource = model.multimodalProjectorSourceFor(web: web); + return ModelProfileCacheState( + model: _assetCacheState(ModelAssetRole.model, modelSource), + multimodalProjector: mmprojSource == null + ? null + : _assetCacheState(ModelAssetRole.multimodalProjector, mmprojSource), + ); + } + + ModelAssetCacheState _assetCacheState( + ModelAssetRole role, + ModelAssetSource source, + ) { + return ModelAssetCacheState( + role: role, + label: source.displayName, + isAvailable: containsAsset(source), + ); + } + + List _assetSourcesFor( + DownloadableModel model, { + required bool web, + }) { + final projector = model.multimodalProjectorSourceFor(web: web); + return [model.modelSourceFor(web: web), ?projector]; + } + + List _remoteSourcesFor( + DownloadableModel model, { + required bool web, + }) { + return _assetSourcesFor( + model, + web: web, + ).whereType().toList(growable: false); + } +} + abstract class ModelService { factory ModelService() => createModelService(); @@ -161,6 +302,8 @@ abstract class ModelService { Future> getDownloadedModels(List models); + Future getModelCacheState(DownloadableModel model); + Future downloadModel({ required DownloadableModel model, required String modelsDir, diff --git a/example/chat_app/lib/services/model_service_io.dart b/example/chat_app/lib/services/model_service_io.dart index 77354a77..205d81c4 100644 --- a/example/chat_app/lib/services/model_service_io.dart +++ b/example/chat_app/lib/services/model_service_io.dart @@ -48,21 +48,8 @@ class ModelServiceIO implements ModelService { final Set downloaded = {}; for (final model in models) { - final hasModel = await _isAssetAvailable( - dirPath, - model.modelSource, - role: ModelAssetRole.model, - ); - final mmprojSource = model.multimodalProjectorSource; - final hasMmproj = - mmprojSource == null || - await _isAssetAvailable( - dirPath, - mmprojSource, - role: ModelAssetRole.multimodalProjector, - ); - - if (hasModel && hasMmproj) { + final cacheState = await _modelCacheState(model, dirPath); + if (cacheState.isReady) { downloaded.add(model.filename); } } @@ -70,6 +57,14 @@ class ModelServiceIO implements ModelService { return downloaded; } + @override + Future getModelCacheState( + DownloadableModel model, + ) async { + final dirPath = await getModelsDirectory(); + return _modelCacheState(model, dirPath); + } + @override Future downloadModel({ required DownloadableModel model, @@ -87,35 +82,55 @@ class ModelServiceIO implements ModelService { model.multimodalProjectorSource is RemoteModelAssetSource ? model.multimodalProjectorSource as RemoteModelAssetSource : null; - final stageCount = [ - modelRemoteSource, - mmprojRemoteSource, - ].whereType().length; - final modelStageIndex = modelRemoteSource == null ? 0 : 1; - final mmprojStageIndex = mmprojRemoteSource == null - ? 0 - : modelRemoteSource == null - ? 1 - : 2; - final providedTotalBytes = - modelRemoteSource == null && mmprojRemoteSource != null - ? mmprojRemoteSource.sizeBytes - : (model.sizeBytes > 0 ? model.sizeBytes : null); final progressDispatcher = _ProgressDispatcher( onProgress: onProgress, onProgressDetail: onProgressDetail, ); - final aggregate = ModelDownloadProgressTracker( - includeMmproj: mmprojRemoteSource != null, - providedTotalBytes: providedTotalBytes, - ); try { await _validateLocalSource(model.modelSource); - if (modelRemoteSource != null) { - final modelSavePath = _assetPath(modelsDir, modelRemoteSource); + final modelNeedsDownload = + modelRemoteSource != null && + !await _isAssetAvailable( + modelsDir, + modelRemoteSource, + role: ModelAssetRole.model, + ); + + final mmprojSource = model.multimodalProjectorSource; + await _validateLocalSource(mmprojSource); + final mmprojNeedsDownload = + mmprojRemoteSource != null && + !await _isAssetAvailable( + modelsDir, + mmprojRemoteSource, + role: ModelAssetRole.multimodalProjector, + ); + + final stageCount = [ + if (modelNeedsDownload) modelRemoteSource, + if (mmprojNeedsDownload) mmprojRemoteSource, + ].whereType().length; + final modelStageIndex = modelNeedsDownload ? 1 : 0; + final mmprojStageIndex = mmprojNeedsDownload + ? (modelNeedsDownload ? 2 : 1) + : 0; + final aggregate = ModelDownloadProgressTracker( + includeMmproj: mmprojNeedsDownload, + providedTotalBytes: _providedDownloadTotalBytes( + model: model, + modelSource: modelRemoteSource, + modelNeedsDownload: modelNeedsDownload, + mmprojSource: mmprojRemoteSource, + mmprojNeedsDownload: mmprojNeedsDownload, + ), + ); + + if (modelNeedsDownload) { + final source = modelRemoteSource; + final modelSavePath = _assetPath(modelsDir, source); await _downloadFileWithResume( - url: modelRemoteSource.url, + url: source.url, savePath: modelSavePath, cancelToken: cancelToken, onProgress: (downloadedBytes, totalBytes, resumed) { @@ -134,12 +149,11 @@ class ModelServiceIO implements ModelService { ); } - final mmprojSource = model.multimodalProjectorSource; - await _validateLocalSource(mmprojSource); - if (mmprojRemoteSource != null) { - final mmprojSavePath = _assetPath(modelsDir, mmprojRemoteSource); + if (mmprojNeedsDownload) { + final source = mmprojRemoteSource; + final mmprojSavePath = _assetPath(modelsDir, source); await _downloadFileWithResume( - url: mmprojRemoteSource.url, + url: source.url, savePath: mmprojSavePath, cancelToken: cancelToken, onProgress: (downloadedBytes, totalBytes, resumed) { @@ -171,6 +185,56 @@ class ModelServiceIO implements ModelService { } } + Future _modelCacheState( + DownloadableModel model, + String modelsDir, + ) async { + final modelSource = model.modelSource; + final mmprojSource = model.multimodalProjectorSource; + return ModelProfileCacheState( + model: ModelAssetCacheState( + role: ModelAssetRole.model, + label: modelSource.displayName, + isAvailable: await _isAssetAvailable( + modelsDir, + modelSource, + role: ModelAssetRole.model, + ), + ), + multimodalProjector: mmprojSource == null + ? null + : ModelAssetCacheState( + role: ModelAssetRole.multimodalProjector, + label: mmprojSource.displayName, + isAvailable: await _isAssetAvailable( + modelsDir, + mmprojSource, + role: ModelAssetRole.multimodalProjector, + ), + ), + ); + } + + int? _providedDownloadTotalBytes({ + required DownloadableModel model, + required RemoteModelAssetSource? modelSource, + required bool modelNeedsDownload, + required RemoteModelAssetSource? mmprojSource, + required bool mmprojNeedsDownload, + }) { + if (modelNeedsDownload && mmprojNeedsDownload && model.sizeBytes > 0) { + return model.sizeBytes; + } + if (modelNeedsDownload) { + return modelSource?.sizeBytes ?? + (model.sizeBytes > 0 ? model.sizeBytes : null); + } + if (mmprojNeedsDownload) { + return mmprojSource?.sizeBytes; + } + return null; + } + Future _isAssetAvailable( String modelsDir, ModelAssetSource source, { diff --git a/example/chat_app/lib/services/model_service_web.dart b/example/chat_app/lib/services/model_service_web.dart index 1206ef7b..1e31c007 100644 --- a/example/chat_app/lib/services/model_service_web.dart +++ b/example/chat_app/lib/services/model_service_web.dart @@ -21,42 +21,36 @@ class ModelServiceWeb implements ModelService, WebCachePrefetchModelService { ) async { final prefs = await SharedPreferences.getInstance(); final downloaded = prefs.getStringList(_downloadedModelsKey) ?? const []; - final downloadedSet = downloaded.toSet(); + final markers = ModelAssetCacheMarkers(downloaded); final cachedModels = {}; var migratedLegacyMarkers = false; for (final model in models) { - if (_isProfileCached(model, downloadedSet)) { + if (markers.isProfileCached(model, web: true)) { cachedModels.add(model.filename); continue; } - final sources = _remoteSourcesFor(model); - if (downloadedSet.contains(model.filename) && - sources.length == _assetSourcesFor(model).length) { - downloadedSet.remove(model.filename); - for (final source in sources) { - downloadedSet.add(source.cacheKey); - } + if (markers.migrateLegacyProfileMarker(model, web: true)) { cachedModels.add(model.filename); migratedLegacyMarkers = true; } } if (migratedLegacyMarkers) { - await prefs.setStringList(_downloadedModelsKey, downloadedSet.toList()); + await prefs.setStringList(_downloadedModelsKey, markers.toSet().toList()); } return cachedModels; } - bool _isProfileCached(DownloadableModel model, Set downloaded) { - final sources = _remoteSourcesFor(model); - if (sources.length != _assetSourcesFor(model).length) { - return false; - } - - return sources.every((source) => downloaded.contains(source.cacheKey)); + @override + Future getModelCacheState( + DownloadableModel model, + ) async { + final prefs = await SharedPreferences.getInstance(); + final downloaded = prefs.getStringList(_downloadedModelsKey) ?? const []; + return ModelAssetCacheMarkers(downloaded).modelCacheState(model, web: true); } List _assetSourcesFor(DownloadableModel model) { @@ -89,20 +83,46 @@ class ModelServiceWeb implements ModelService, WebCachePrefetchModelService { return; } final remoteModelSource = modelSource as RemoteModelAssetSource; - final mmprojUrl = - (model.multimodalProjectorSourceFor(web: true) - as RemoteModelAssetSource?) - ?.url; - final stageCount = mmprojUrl == null ? 1 : 2; + final remoteMmprojSource = + model.multimodalProjectorSourceFor(web: true) + as RemoteModelAssetSource?; + final prefs = await SharedPreferences.getInstance(); + final downloaded = + prefs.getStringList(_downloadedModelsKey)?.toSet() ?? {}; + final markers = ModelAssetCacheMarkers(downloaded); + if (markers.migrateLegacyProfileMarker(model, web: true)) { + await prefs.setStringList(_downloadedModelsKey, markers.toSet().toList()); + } + final pendingAssets = <_PendingWebCacheAsset>[ + if (!markers.containsAsset(remoteModelSource)) + _PendingWebCacheAsset( + source: remoteModelSource, + stage: ModelDownloadStage.model, + ), + if (remoteMmprojSource != null && + !markers.containsAsset(remoteMmprojSource)) + _PendingWebCacheAsset( + source: remoteMmprojSource, + stage: ModelDownloadStage.multimodalProjector, + ), + ]; + final stageCount = pendingAssets.length; final aggregate = ModelDownloadProgressTracker( - includeMmproj: mmprojUrl != null, - providedTotalBytes: model.sizeBytesFor(web: true) > 0 - ? model.sizeBytesFor(web: true) - : null, + includeMmproj: pendingAssets.any( + (asset) => asset.stage == ModelDownloadStage.multimodalProjector, + ), + providedTotalBytes: _providedDownloadTotalBytes( + model: model, + pendingAssets: pendingAssets, + ), ); - if (_remoteSourcesFor( - model, - ).any((source) => _hasPersistentCacheSensitiveUrlParts(source.url))) { + if (pendingAssets.isEmpty) { + onSuccess(model.filename); + return; + } + if (pendingAssets.any( + (asset) => _hasPersistentCacheSensitiveUrlParts(asset.source.url), + )) { onError( UnsupportedError( 'Browser cache prefetch skipped for credentialed remote URL; load the model directly to avoid storing sensitive URL parts.', @@ -150,44 +170,32 @@ class ModelServiceWeb implements ModelService, WebCachePrefetchModelService { // A prefetch failure (including an old bridge that lacks // prefetchModelToCache) now surfaces as a real error instead of silently // marking the model cached. - await _prefetchStage( - bridge, - remoteModelSource.url, - stage: ModelDownloadStage.model, - stageIndex: 1, - stageCount: stageCount, - aggregate: aggregate, - updateStage: aggregate.updateModel, - onProgress: onProgress, - onProgressDetail: onProgressDetail, - ); - - if (mmprojUrl != null) { + for (var i = 0; i < pendingAssets.length; i++) { + final asset = pendingAssets[i]; await _prefetchStage( bridge, - mmprojUrl, - stage: ModelDownloadStage.multimodalProjector, - stageIndex: 2, + asset.source.url, + stage: asset.stage, + stageIndex: i + 1, stageCount: stageCount, aggregate: aggregate, - updateStage: aggregate.updateMmproj, + updateStage: asset.stage == ModelDownloadStage.model + ? aggregate.updateModel + : aggregate.updateMmproj, onProgress: onProgress, onProgressDetail: onProgressDetail, ); + markers.markAssetCached(asset.source); + await prefs.setStringList( + _downloadedModelsKey, + markers.toSet().toList(), + ); } final finalDetail = aggregate.finalProgress(stageCount: stageCount); onProgress(finalDetail.overallProgress); onProgressDetail?.call(finalDetail); - final prefs = await SharedPreferences.getInstance(); - final downloaded = - prefs.getStringList(_downloadedModelsKey)?.toSet() ?? {}; - for (final source in _remoteSourcesFor(model)) { - downloaded.add(source.cacheKey); - } - await prefs.setStringList(_downloadedModelsKey, downloaded.toList()); - onSuccess(model.filename); } catch (error) { if (_looksCancelled(error) || cancelToken.isCancelled) { @@ -249,6 +257,19 @@ class ModelServiceWeb implements ModelService, WebCachePrefetchModelService { } } + int? _providedDownloadTotalBytes({ + required DownloadableModel model, + required List<_PendingWebCacheAsset> pendingAssets, + }) { + if (pendingAssets.length > 1 && model.sizeBytesFor(web: true) > 0) { + return model.sizeBytesFor(web: true); + } + if (pendingAssets.length == 1) { + return pendingAssets.single.source.sizeBytes; + } + return null; + } + bool _hasPersistentCacheSensitiveUrlParts(String value) { final uri = Uri.tryParse(value); if (uri == null) { @@ -491,10 +512,9 @@ class ModelServiceWeb implements ModelService, WebCachePrefetchModelService { final prefs = await SharedPreferences.getInstance(); final downloaded = prefs.getStringList(_downloadedModelsKey)?.toSet() ?? {}; - for (final source in _remoteSourcesFor(model)) { - downloaded.remove(source.cacheKey); - } - await prefs.setStringList(_downloadedModelsKey, downloaded.toList()); + final markers = ModelAssetCacheMarkers(downloaded) + ..removeAssets(_remoteSourcesFor(model)); + await prefs.setStringList(_downloadedModelsKey, markers.toSet().toList()); final bridge = _tryCreateBridge(); if (bridge == null) { @@ -533,6 +553,13 @@ class _BridgeProgressSnapshot { const _BridgeProgressSnapshot({required this.loaded, required this.total}); } +class _PendingWebCacheAsset { + final RemoteModelAssetSource source; + final ModelDownloadStage stage; + + const _PendingWebCacheAsset({required this.source, required this.stage}); +} + @JS('LlamaWebGpuBridge') extension type _WebModelCacheBridge._(JSObject _) implements JSObject { external factory _WebModelCacheBridge([_WebModelCacheBridgeConfig? config]); diff --git a/example/chat_app/lib/widgets/model_card.dart b/example/chat_app/lib/widgets/model_card.dart index c7c2f39f..bf359511 100644 --- a/example/chat_app/lib/widgets/model_card.dart +++ b/example/chat_app/lib/widgets/model_card.dart @@ -2,10 +2,12 @@ import 'package:flutter/foundation.dart'; import 'package:flutter/material.dart'; import 'package:google_fonts/google_fonts.dart'; import '../models/downloadable_model.dart'; +import '../services/model_service_base.dart'; class ModelCard extends StatelessWidget { final DownloadableModel model; final bool isDownloaded; + final ModelProfileCacheState? cacheState; final bool isDownloading; final double progress; final String? downloadStatusLabel; @@ -25,6 +27,7 @@ class ModelCard extends StatelessWidget { super.key, required this.model, required this.isDownloaded, + this.cacheState, required this.isDownloading, required this.progress, this.downloadStatusLabel, @@ -50,6 +53,10 @@ class ModelCard extends StatelessWidget { final effectiveModelSizeBytes = model.sizeBytesFor(web: isWeb); final showWebLargeModelWarning = isWeb && effectiveModelSizeBytes >= webLargeModelWarningThresholdBytes; + final partialCacheMessage = _partialCacheMessage(cacheState, isWeb: isWeb); + final hasPartialCache = partialCacheMessage != null; + final hasAnyCachedAsset = + isDownloaded || (cacheState?.availableAssetLabels.isNotEmpty ?? false); final showMobileDownloadGuidance = isDownloading && !isWeb && @@ -107,6 +114,34 @@ class ModelCard extends StatelessWidget { ), ], ), + if (cacheState != null) ...[ + const SizedBox(height: 10), + Wrap( + spacing: 8, + runSpacing: 8, + children: [ + _buildCapabilityChip( + context, + icon: Icons.inventory_2_outlined, + label: cacheState!.model.isAvailable + ? 'Model cached' + : 'Model missing', + supported: cacheState!.model.isAvailable, + ), + if (cacheState!.multimodalProjector != null) + _buildCapabilityChip( + context, + icon: Icons.visibility_outlined, + label: + cacheState!.multimodalProjector!.isAvailable + ? 'mmproj cached' + : 'mmproj missing', + supported: + cacheState!.multimodalProjector!.isAvailable, + ), + ], + ), + ], const SizedBox(height: 10), Wrap( spacing: 8, @@ -152,7 +187,7 @@ class ModelCard extends StatelessWidget { ], ), ), - if (isDownloaded || (progress > 0 && !isDownloaded)) + if (hasAnyCachedAsset || (progress > 0 && !isDownloaded)) IconButton( icon: Icon( Icons.delete_outline_rounded, @@ -161,7 +196,11 @@ class ModelCard extends StatelessWidget { onPressed: onDelete, tooltip: progress > 0 && !isDownloaded ? 'Cancel & Discard' - : 'Delete Model', + : hasPartialCache + ? 'Delete cached assets' + : model.multimodalProjectorSourceFor(web: isWeb) == null + ? 'Delete Model' + : 'Delete model and mmproj', ), ], ), @@ -174,6 +213,29 @@ class ModelCard extends StatelessWidget { height: 1.4, ), ), + if (partialCacheMessage != null) ...[ + const SizedBox(height: 10), + Container( + width: double.infinity, + padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 8), + decoration: BoxDecoration( + color: colorScheme.secondaryContainer.withValues(alpha: 0.45), + borderRadius: BorderRadius.circular(10), + border: Border.all( + color: colorScheme.outlineVariant.withValues(alpha: 0.45), + ), + ), + child: Text( + partialCacheMessage, + style: GoogleFonts.outfit( + fontSize: 12, + height: 1.3, + color: colorScheme.onSecondaryContainer, + fontWeight: FontWeight.w500, + ), + ), + ), + ], if (showWebLargeModelWarning) ...[ const SizedBox(height: 10), Container( @@ -448,8 +510,16 @@ class ModelCard extends StatelessWidget { ), label: Text( isWeb - ? (progress > 0 ? 'Resume Cache' : 'Cache Model') - : (progress > 0 ? 'Resume Download' : 'Download'), + ? (progress > 0 + ? 'Resume Cache' + : hasPartialCache + ? 'Cache Missing Assets' + : 'Cache Model') + : (progress > 0 + ? 'Resume Download' + : hasPartialCache + ? 'Download Missing Assets' + : 'Download'), ), style: OutlinedButton.styleFrom( padding: const EdgeInsets.symmetric(vertical: 16), @@ -469,6 +539,36 @@ class ModelCard extends StatelessWidget { return model.filenameFor(web: true).toLowerCase().endsWith('.litertlm'); } + String? _partialCacheMessage( + ModelProfileCacheState? state, { + required bool isWeb, + }) { + if (state == null || !state.hasPartialAssets) { + return null; + } + final available = _joinAssetLabels(state.availableAssetLabels); + final missing = _joinAssetLabels(state.missingAssetLabels); + final action = isWeb ? 'Cache' : 'Download'; + return '${_capitalize(available)} cached; $missing missing. $action will fetch only missing assets.'; + } + + String _joinAssetLabels(List labels) { + if (labels.isEmpty) { + return 'asset'; + } + if (labels.length == 1) { + return labels.single; + } + return '${labels.take(labels.length - 1).join(', ')} and ${labels.last}'; + } + + String _capitalize(String value) { + if (value.isEmpty) { + return value; + } + return value[0].toUpperCase() + value.substring(1); + } + Widget _buildMetaChip( BuildContext context, { required IconData icon, diff --git a/example/chat_app/test/manage_models_screen_download_test.dart b/example/chat_app/test/manage_models_screen_download_test.dart index d8eb0a5f..908062f2 100644 --- a/example/chat_app/test/manage_models_screen_download_test.dart +++ b/example/chat_app/test/manage_models_screen_download_test.dart @@ -1,4 +1,5 @@ import 'dart:async'; +import 'dart:convert'; import 'package:dio/dio.dart'; import 'package:flutter/foundation.dart'; @@ -162,6 +163,163 @@ void main() { expect(modelService.lastCancelToken?.isCancelled, isTrue); }); + + testWidgets( + 'selection warns when runtime lacks advertised vision support', + (tester) async { + SharedPreferences.setMockInitialValues({}); + final model = _remoteVisionModel(); + final modelService = _HoldingModelService( + downloadedFiles: {model.filename}, + ); + final provider = ChatProvider( + chatService: MockChatService(engine: _NoVisionEngine()), + settingsService: MockSettingsService(), + ); + addTearDown(provider.dispose); + + await _pumpScreen( + tester, + modelService: modelService, + models: [model], + provider: provider, + ); + + expect(find.text('Use this model'), findsOneWidget); + + await tester.ensureVisible(find.text('Use this model')); + await tester.pump(); + await tester.tap(find.text('Use this model')); + await tester.pump(); + await tester.pump(const Duration(milliseconds: 300)); + + expect( + find.textContaining( + 'active runtime/projector did not report vision support', + ), + findsOneWidget, + ); + }, + ); + + testWidgets('delete refreshes other profiles that share cached assets', ( + tester, + ) async { + SharedPreferences.setMockInitialValues({}); + final sharedMmproj = const RemoteModelAssetSource( + url: 'https://example.com/shared-mmproj.gguf', + filename: 'shared-mmproj.gguf', + ); + final first = DownloadableModel.fromSources( + id: 'first-vlm', + name: 'First VLM', + description: 'First profile with shared projector.', + modelSource: const RemoteModelAssetSource( + url: 'https://example.com/first.gguf', + filename: 'first.gguf', + ), + multimodalProjectorSource: sharedMmproj, + supportsVision: true, + ); + final second = DownloadableModel.fromSources( + id: 'second-vlm', + name: 'Second VLM', + description: 'Second profile with shared projector.', + modelSource: const RemoteModelAssetSource( + url: 'https://example.com/second.gguf', + filename: 'second.gguf', + ), + multimodalProjectorSource: sharedMmproj, + supportsVision: true, + ); + final modelService = _HoldingModelService( + cachedAssetKeys: { + (first.modelSource as RemoteModelAssetSource).cacheKey, + (second.modelSource as RemoteModelAssetSource).cacheKey, + sharedMmproj.cacheKey, + }, + ); + + await _pumpScreen( + tester, + modelService: modelService, + models: [first, second], + ); + + expect(find.text('First VLM'), findsOneWidget); + expect(find.text('Second VLM'), findsOneWidget); + expect( + find.text( + 'Model cached; mmproj missing. Download will fetch only missing assets.', + ), + findsNothing, + ); + + await tester.tap(find.byTooltip('Delete model and mmproj').first); + await tester.pumpAndSettle(); + + expect(modelService.deleteCalls, 1); + expect( + find.text( + 'Model cached; mmproj missing. Download will fetch only missing assets.', + ), + findsOneWidget, + ); + }); + + testWidgets('signed custom URLs require confirmation before saving', ( + tester, + ) async { + SharedPreferences.setMockInitialValues({}); + final modelService = _HoldingModelService(); + + await _pumpScreen(tester, modelService: modelService, models: []); + + await tester.tap(find.text('Add GGUF (HF)')); + await tester.pumpAndSettle(); + await tester.enterText( + find.widgetWithText(TextField, 'GGUF URL (Hugging Face)'), + 'https://huggingface.co/owner/repo/resolve/main/model.gguf?token=secret', + ); + + await tester.tap(find.text('Add model')); + await tester.pumpAndSettle(); + + expect(find.text('Save credentialed URL?'), findsOneWidget); + final warningDialog = find.ancestor( + of: find.text('Save credentialed URL?'), + matching: find.byType(AlertDialog), + ); + expect( + find.descendant( + of: warningDialog, + matching: find.textContaining('token=secret'), + ), + findsNothing, + ); + + await tester.tap(find.text('Review URL')); + await tester.pumpAndSettle(); + + var prefs = await SharedPreferences.getInstance(); + expect(prefs.getStringList('custom_hf_models_v1'), isNull); + expect(find.text('Add Hugging Face GGUF'), findsOneWidget); + + await tester.tap(find.text('Add model')); + await tester.pumpAndSettle(); + await tester.tap(find.text('Save anyway')); + await tester.pumpAndSettle(); + + prefs = await SharedPreferences.getInstance(); + final entries = prefs.getStringList('custom_hf_models_v1'); + expect(entries, hasLength(1)); + final saved = jsonDecode(entries!.single) as Map; + expect( + saved['url'], + 'https://huggingface.co/owner/repo/resolve/main/model.gguf?token=secret', + ); + expect(find.text('Added model.gguf'), findsOneWidget); + }); }); } @@ -171,16 +329,21 @@ Future _pumpScreen( WidgetTester tester, { required _HoldingModelService modelService, required List models, + ChatProvider? provider, }) async { - final provider = ChatProvider( - chatService: MockChatService(), - settingsService: MockSettingsService(), - ); - addTearDown(provider.dispose); + final effectiveProvider = + provider ?? + ChatProvider( + chatService: MockChatService(), + settingsService: MockSettingsService(), + ); + if (provider == null) { + addTearDown(effectiveProvider.dispose); + } await tester.pumpWidget( ChangeNotifierProvider.value( - value: provider, + value: effectiveProvider, child: MaterialApp( home: Scaffold( body: ManageModelsScreen( @@ -206,11 +369,33 @@ DownloadableModel _remoteModel() { ); } +DownloadableModel _remoteVisionModel() { + return const DownloadableModel( + name: 'Tiny Vision Model', + description: 'Small fake VLM for screen tests.', + url: 'https://example.com/tiny-vlm.gguf', + filename: 'tiny-vlm.gguf', + mmprojUrl: 'https://example.com/tiny-mmproj.gguf', + mmprojFilename: 'tiny-mmproj.gguf', + sizeBytes: 20, + supportsVision: true, + ); +} + class _HoldingModelService implements ModelService { + _HoldingModelService({ + Set? downloadedFiles, + Set? cachedAssetKeys, + }) : downloadedFiles = downloadedFiles ?? {}, + cachedAssetKeys = cachedAssetKeys?.toSet(); + final Completer downloadStarted = Completer(); final Completer downloadCancelled = Completer(); + final Set downloadedFiles; + final Set? cachedAssetKeys; int downloadCalls = 0; + int deleteCalls = 0; CancelToken? lastCancelToken; @override @@ -220,7 +405,28 @@ class _HoldingModelService implements ModelService { Future> getDownloadedModels( List models, ) async { - return {}; + return models.where(_isProfileReady).map((model) => model.filename).toSet(); + } + + @override + Future getModelCacheState( + DownloadableModel model, + ) async { + final mmprojSource = model.multimodalProjectorSource; + return ModelProfileCacheState( + model: ModelAssetCacheState( + role: ModelAssetRole.model, + label: model.modelSource.displayName, + isAvailable: _isAssetAvailable(model, model.modelSource), + ), + multimodalProjector: mmprojSource == null + ? null + : ModelAssetCacheState( + role: ModelAssetRole.multimodalProjector, + label: mmprojSource.displayName, + isAvailable: _isAssetAvailable(model, mmprojSource), + ), + ); } @override @@ -268,5 +474,55 @@ class _HoldingModelService implements ModelService { } @override - Future deleteModel(String modelsDir, DownloadableModel model) async {} + Future deleteModel(String modelsDir, DownloadableModel model) async { + deleteCalls += 1; + final keys = cachedAssetKeys; + if (keys == null) { + downloadedFiles.remove(model.filename); + return; + } + + for (final source in _remoteSourcesFor(model)) { + keys.remove(source.cacheKey); + } + } + + bool _isProfileReady(DownloadableModel model) { + final keys = cachedAssetKeys; + if (keys == null) { + return downloadedFiles.contains(model.filename); + } + final sources = _assetSourcesFor(model); + return sources.every( + (source) => + source is RemoteModelAssetSource && keys.contains(source.cacheKey), + ); + } + + bool _isAssetAvailable(DownloadableModel model, ModelAssetSource source) { + final keys = cachedAssetKeys; + if (keys == null) { + return downloadedFiles.contains(model.filename); + } + return source is RemoteModelAssetSource && keys.contains(source.cacheKey); + } + + List _assetSourcesFor(DownloadableModel model) { + final mmprojSource = model.multimodalProjectorSource; + return [model.modelSource, ?mmprojSource]; + } + + List _remoteSourcesFor(DownloadableModel model) { + return _assetSourcesFor( + model, + ).whereType().toList(growable: false); + } +} + +class _NoVisionEngine extends MockLlamaEngine { + @override + Future get supportsVision async => false; + + @override + Future get supportsAudio async => false; } diff --git a/example/chat_app/test/model_asset_source_test.dart b/example/chat_app/test/model_asset_source_test.dart index 21a4c983..d6245202 100644 --- a/example/chat_app/test/model_asset_source_test.dart +++ b/example/chat_app/test/model_asset_source_test.dart @@ -1,5 +1,6 @@ import 'package:flutter_test/flutter_test.dart'; import 'package:llamadart_chat_example/models/downloadable_model.dart'; +import 'package:llamadart_chat_example/services/model_service_base.dart'; void main() { group('Model asset sources', () { @@ -83,6 +84,83 @@ void main() { expect(profile.mmprojFilename, 'mmproj.gguf'); }); + test( + 'cache markers surface model availability after mmproj prefetch failure', + () { + final profile = DownloadableModel( + name: 'Partial VLM', + description: 'Remote model cached before projector failure', + url: 'https://example.com/model.gguf', + filename: 'model.gguf', + mmprojUrl: 'https://example.com/mmproj.gguf', + mmprojFilename: 'mmproj.gguf', + sizeBytes: 2048, + supportsVision: true, + ); + final modelSource = profile.modelSource as RemoteModelAssetSource; + final mmprojSource = + profile.multimodalProjectorSource as RemoteModelAssetSource; + final markers = ModelAssetCacheMarkers([]); + + markers.markAssetCached(modelSource); + + final state = markers.modelCacheState(profile, web: true); + expect(markers.isProfileCached(profile, web: true), isFalse); + expect(state.isReady, isFalse); + expect(state.hasPartialAssets, isTrue); + expect(state.model.isAvailable, isTrue); + expect(state.multimodalProjector?.isAvailable, isFalse); + expect(markers.toSet(), contains(modelSource.cacheKey)); + expect(markers.toSet(), isNot(contains(mmprojSource.cacheKey))); + }, + ); + + test('legacy profile marker migrates to per-asset remote markers', () { + final profile = DownloadableModel( + name: 'Legacy Cached VLM', + description: 'Remote model and projector', + url: 'https://example.com/model.gguf', + filename: 'model.gguf', + mmprojUrl: 'https://example.com/mmproj.gguf', + mmprojFilename: 'mmproj.gguf', + sizeBytes: 2048, + supportsVision: true, + ); + final modelSource = profile.modelSource as RemoteModelAssetSource; + final mmprojSource = + profile.multimodalProjectorSource as RemoteModelAssetSource; + final markers = ModelAssetCacheMarkers({profile.filename}); + + expect(markers.migrateLegacyProfileMarker(profile, web: true), isTrue); + + expect(markers.containsMarker(profile.filename), isFalse); + expect(markers.containsAsset(modelSource), isTrue); + expect(markers.containsAsset(mmprojSource), isTrue); + expect(markers.isProfileCached(profile, web: true), isTrue); + }); + + test('legacy profile marker is not migrated for local assets', () { + final profile = DownloadableModel.fromSources( + name: 'Mixed Legacy VLM', + description: 'Remote model with local projector', + modelSource: const RemoteModelAssetSource( + url: 'https://example.com/model.gguf', + filename: 'model.gguf', + ), + multimodalProjectorSource: const LocalModelAssetSource( + '/models/mmproj.gguf', + ), + sizeBytes: 2048, + supportsVision: true, + ); + final markers = ModelAssetCacheMarkers({profile.filename}); + + expect(markers.migrateLegacyProfileMarker(profile, web: true), isFalse); + + expect(markers.containsMarker(profile.filename), isTrue); + expect(markers.isProfileCached(profile, web: true), isFalse); + }); + test('platform-specific web source can differ from native source', () { final profile = DownloadableModel.fromSources( id: 'litert-gemma', diff --git a/example/chat_app/test/model_card_test.dart b/example/chat_app/test/model_card_test.dart index e7b59f53..8168e364 100644 --- a/example/chat_app/test/model_card_test.dart +++ b/example/chat_app/test/model_card_test.dart @@ -1,6 +1,7 @@ import 'package:flutter/material.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:llamadart_chat_example/models/downloadable_model.dart'; +import 'package:llamadart_chat_example/services/model_service_base.dart'; import 'package:llamadart_chat_example/widgets/model_card.dart'; void main() { @@ -68,6 +69,56 @@ void main() { expect(find.textContaining('very large LiteRT-LM'), findsNothing); }); + + testWidgets('partial multimodal cache shows missing mmproj state', ( + tester, + ) async { + var downloadCalls = 0; + var deleteCalls = 0; + + await _pumpCard( + tester, + model: _vlmModel(), + isWeb: false, + isDownloaded: false, + cacheState: const ModelProfileCacheState( + model: ModelAssetCacheState( + role: ModelAssetRole.model, + label: 'model.gguf', + isAvailable: true, + ), + multimodalProjector: ModelAssetCacheState( + role: ModelAssetRole.multimodalProjector, + label: 'mmproj.gguf', + isAvailable: false, + ), + ), + onSelect: () {}, + onDownload: () => downloadCalls += 1, + onDelete: () => deleteCalls += 1, + ); + + expect(find.text('Model cached'), findsOneWidget); + expect(find.text('mmproj missing'), findsOneWidget); + expect( + find.text( + 'Model cached; mmproj missing. Download will fetch only missing assets.', + ), + findsOneWidget, + ); + expect(find.text('Download Missing Assets'), findsOneWidget); + expect(find.byTooltip('Delete cached assets'), findsOneWidget); + + await tester.tap(find.text('Download Missing Assets')); + await tester.pump(); + + expect(downloadCalls, 1); + + await tester.tap(find.byTooltip('Delete cached assets')); + await tester.pump(); + + expect(deleteCalls, 1); + }); } Future _pumpCard( @@ -75,8 +126,10 @@ Future _pumpCard( required DownloadableModel model, required bool isWeb, required bool isDownloaded, + ModelProfileCacheState? cacheState, required VoidCallback onSelect, required VoidCallback onDownload, + VoidCallback? onDelete, }) async { await tester.pumpWidget( MaterialApp( @@ -85,6 +138,7 @@ Future _pumpCard( child: ModelCard( model: model, isDownloaded: isDownloaded, + cacheState: cacheState, isDownloading: false, progress: 0, isWeb: isWeb, @@ -95,7 +149,7 @@ Future _pumpCard( onContextSizeChanged: (_) {}, onSelect: onSelect, onDownload: onDownload, - onDelete: () {}, + onDelete: onDelete ?? () {}, ), ), ), @@ -103,6 +157,19 @@ Future _pumpCard( ); } +DownloadableModel _vlmModel() { + return const DownloadableModel( + name: 'VLM Test Model', + description: 'Fake multimodal model for widget tests.', + url: 'https://example.com/model.gguf', + filename: 'model.gguf', + mmprojUrl: 'https://example.com/mmproj.gguf', + mmprojFilename: 'mmproj.gguf', + sizeBytes: 10, + supportsVision: true, + ); +} + DownloadableModel _litertLmModel() { return const DownloadableModel( name: 'LiteRT-LM Test Model', diff --git a/example/chat_app/test/model_download_controller_adapter_test.dart b/example/chat_app/test/model_download_controller_adapter_test.dart index 570a0062..c0cc3924 100644 --- a/example/chat_app/test/model_download_controller_adapter_test.dart +++ b/example/chat_app/test/model_download_controller_adapter_test.dart @@ -198,6 +198,28 @@ class _FakeModelService implements ModelService { .toSet(); } + @override + Future getModelCacheState( + DownloadableModel model, + ) async { + final isAvailable = downloadedFiles.contains(model.filename); + final mmprojSource = model.multimodalProjectorSource; + return ModelProfileCacheState( + model: ModelAssetCacheState( + role: ModelAssetRole.model, + label: model.modelSource.displayName, + isAvailable: isAvailable, + ), + multimodalProjector: mmprojSource == null + ? null + : ModelAssetCacheState( + role: ModelAssetRole.multimodalProjector, + label: mmprojSource.displayName, + isAvailable: isAvailable, + ), + ); + } + @override Future downloadModel({ required DownloadableModel model, diff --git a/example/chat_app/test/model_service_test.dart b/example/chat_app/test/model_service_test.dart index b06b396f..a0219df9 100644 --- a/example/chat_app/test/model_service_test.dart +++ b/example/chat_app/test/model_service_test.dart @@ -14,6 +14,7 @@ void main() { late ModelService service; late List testData; late List mmprojData; + late Map getRequestCountByPath; const int testDataSize = 1024 * 1024 * 5; // 5 MB const int mmprojDataSize = 1024 * 1024 * 2; // 2 MB @@ -22,6 +23,7 @@ void main() { // Generate random test data testData = List.generate(testDataSize, (i) => i % 256); mmprojData = List.generate(mmprojDataSize, (i) => (i * 7) % 256); + getRequestCountByPath = {}; tempDir = await Directory.systemTemp.createTemp('model_service_test'); service = TestModelService(tempDir); @@ -40,6 +42,7 @@ void main() { request.response.statusCode = HttpStatus.ok; await request.response.close(); } else if (request.method == 'GET') { + getRequestCountByPath[path] = (getRequestCountByPath[path] ?? 0) + 1; final rangeHeader = request.headers.value('range'); int start = 0; int end = payloadSize - 1; @@ -164,6 +167,63 @@ void main() { expect(updates.last.overallProgress, closeTo(1.0, 0.0001)); }); + test( + 'Multimodal download skips cached model when mmproj is missing', + () async { + final model = DownloadableModel( + name: 'Partially cached VLM', + description: 'Test', + url: '$baseUrl/model.gguf', + filename: 'cached-vlm-model.gguf', + mmprojUrl: '$baseUrl/mmproj.gguf', + mmprojFilename: 'cached-vlm-mmproj.gguf', + sizeBytes: testDataSize + mmprojDataSize, + supportsVision: true, + ); + await File(p.join(tempDir.path, model.filename)).writeAsBytes(testData); + + final before = await service.getModelCacheState(model); + expect(before.model.isAvailable, isTrue); + expect(before.multimodalProjector?.isAvailable, isFalse); + expect(before.hasPartialAssets, isTrue); + + final updates = []; + + await service.downloadModel( + model: model, + modelsDir: tempDir.path, + cancelToken: CancelToken(), + onProgress: (_) {}, + onProgressDetail: updates.add, + onSuccess: (_) {}, + onError: (e) => fail('Download failed: $e'), + ); + + expect(getRequestCountByPath['/model.gguf'] ?? 0, 0); + expect(getRequestCountByPath['/mmproj.gguf'] ?? 0, 1); + expect( + updates.where((u) => u.stage == ModelDownloadStage.model), + isEmpty, + ); + expect( + updates.every( + (u) => + u.stage == ModelDownloadStage.multimodalProjector && + u.stageIndex == 1 && + u.stageCount == 1, + ), + isTrue, + ); + + final after = await service.getModelCacheState(model); + expect(after.isReady, isTrue); + expect( + await service.getDownloadedModels([model]), + contains(model.filename), + ); + }, + ); + test( 'Local model with remote mmproj reports a single projector stage', () async { diff --git a/example/chat_app/test/unit_test.dart b/example/chat_app/test/unit_test.dart index c6190a39..b460d453 100644 --- a/example/chat_app/test/unit_test.dart +++ b/example/chat_app/test/unit_test.dart @@ -1287,6 +1287,27 @@ class _RecordingModelService return {}; } + @override + Future getModelCacheState( + DownloadableModel model, + ) async { + final mmprojSource = model.multimodalProjectorSource; + return app_model_service.ModelProfileCacheState( + model: app_model_service.ModelAssetCacheState( + role: ModelAssetRole.model, + label: model.modelSource.displayName, + isAvailable: false, + ), + multimodalProjector: mmprojSource == null + ? null + : app_model_service.ModelAssetCacheState( + role: ModelAssetRole.multimodalProjector, + label: mmprojSource.displayName, + isAvailable: false, + ), + ); + } + @override Future downloadModel({ required DownloadableModel model, diff --git a/website/docs/changelog/recent-releases.md b/website/docs/changelog/recent-releases.md index da9d1c93..08f945e3 100644 --- a/website/docs/changelog/recent-releases.md +++ b/website/docs/changelog/recent-releases.md @@ -17,6 +17,12 @@ For canonical full release notes, use: projector files can use the same `ModelSource` resolver and native download/cache options as `loadModelSource(...)`. +- Improved the runnable chat app's Manage Models cache UX so model and mmproj + asset cache states are shown separately, missing multimodal projectors can be + re-cached without re-fetching already cached model assets, and runtime media + capability mismatches surface as user-readable warnings. Custom signed or + tokenized Hugging Face URLs now require confirmation before they are saved. + ## 0.8.12 - Updated the default LiteRT-LM native runtime pin to