From f7d49165a3999d7e2bcbc48e04d061f9c4e9685a Mon Sep 17 00:00:00 2001 From: Sasha Denisov Date: Fri, 23 Jan 2026 15:15:39 +0500 Subject: [PATCH 1/5] Add LiteRT-LM engine support for Android (.litertlm models) Implement Strategy pattern for inference engines with two backends: - MediaPipe (existing .task files) - LiteRT-LM (new .litertlm files with multimodal support) Key changes: - Add InferenceEngine interface with Engine/Session abstractions - Add EngineFactory for automatic engine selection based on file extension - Implement LiteRtLmEngine with visionBackend for multimodal models - Implement LiteRtLmSession with chunk buffering for MediaPipe compatibility - Add thread-safety (synchronized locks) in FlutterGemmaPlugin - Add LiteRT-LM SDK dependency (0.9.0-alpha01) - Add gemma3n LiteRT-LM model options in example app - Add unit tests for engines Tested with Gemma 3 Nano E2B multimodal (text + image) on Pixel 8. --- android/build.gradle | 3 + .../flutter_gemma/FlutterGemmaPlugin.kt | 184 ++++++++++----- .../flutter_gemma/engines/EngineConfig.kt | 38 +++ .../flutter_gemma/engines/EngineFactory.kt | 77 +++++++ .../flutter_gemma/engines/InferenceEngine.kt | 50 ++++ .../flutter_gemma/engines/InferenceSession.kt | 51 +++++ .../engines/litertlm/LiteRtLmEngine.kt | 110 +++++++++ .../engines/litertlm/LiteRtLmSession.kt | 167 ++++++++++++++ .../engines/mediapipe/MediaPipeEngine.kt | 91 ++++++++ .../engines/mediapipe/MediaPipeSession.kt | 84 +++++++ .../engines/EngineFactoryTest.kt | 132 +++++++++++ .../engines/litertlm/LiteRtLmEngineTest.kt | 196 ++++++++++++++++ .../engines/litertlm/LiteRtLmSessionTest.kt | 216 ++++++++++++++++++ example/lib/models/model.dart | 46 +++- 14 files changed, 1384 insertions(+), 61 deletions(-) create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt create mode 100644 
android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceEngine.kt create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceSession.kt create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSession.kt create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt create mode 100644 android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeSession.kt create mode 100644 android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactoryTest.kt create mode 100644 android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt create mode 100644 android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt diff --git a/android/build.gradle b/android/build.gradle index 4ff9de1b..a16b78fd 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -73,6 +73,9 @@ dependencies { implementation 'com.google.guava:guava:33.3.1-android' implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-guava:1.9.0' + // LiteRT-LM Engine for .litertlm model files + implementation 'com.google.ai.edge.litertlm:litertlm-android:0.9.0-alpha01' + implementation 'androidx.core:core-ktx:1.12.0' implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.7.0' testImplementation 'org.jetbrains.kotlin:kotlin-test' diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt index d049b09b..094e43a4 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt @@ -9,6 +9,8 @@ import io.flutter.plugin.common.EventChannel import 
io.flutter.plugin.common.MethodChannel import kotlinx.coroutines.* +import dev.flutterberlin.flutter_gemma.engines.* + /** FlutterGemmaPlugin */ class FlutterGemmaPlugin: FlutterPlugin { /// The MethodChannel that will the communication between Flutter and native Android @@ -18,13 +20,14 @@ class FlutterGemmaPlugin: FlutterPlugin { private lateinit var eventChannel: EventChannel private lateinit var bundledChannel: MethodChannel private lateinit var context: Context + private var service: PlatformServiceImpl? = null override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { context = flutterPluginBinding.applicationContext - val service = PlatformServiceImpl(context) + service = PlatformServiceImpl(context) eventChannel = EventChannel(flutterPluginBinding.binaryMessenger, "flutter_gemma_stream") - eventChannel.setStreamHandler(service) - PlatformService.setUp(flutterPluginBinding.binaryMessenger, service) + eventChannel.setStreamHandler(service!!) + PlatformService.setUp(flutterPluginBinding.binaryMessenger, service!!) // Setup bundled assets channel bundledChannel = MethodChannel(flutterPluginBinding.binaryMessenger, "flutter_gemma_bundled") @@ -61,21 +64,43 @@ class FlutterGemmaPlugin: FlutterPlugin { override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { eventChannel.setStreamHandler(null) bundledChannel.setMethodCallHandler(null) + service?.cleanup() + service = null } } private class PlatformServiceImpl( val context: Context ) : PlatformService, EventChannel.StreamHandler { - private val scope = CoroutineScope(Dispatchers.IO) + private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) private var eventSink: EventChannel.EventSink? = null - private var inferenceModel: InferenceModel? = null - private var session: InferenceModelSession? = null - + private var streamJob: kotlinx.coroutines.Job? 
= null // Track stream collection job + private val engineLock = Any() // Lock for thread-safe engine access + + // NEW: Use InferenceEngine abstraction instead of InferenceModel + private var engine: InferenceEngine? = null + private var session: InferenceSession? = null + // RAG components private var embeddingModel: EmbeddingModel? = null private var vectorStore: VectorStore? = null + fun cleanup() { + scope.cancel() + streamJob?.cancel() + streamJob = null + synchronized(engineLock) { + session?.close() + session = null + engine?.close() + engine = null + } + embeddingModel?.close() + embeddingModel = null + vectorStore?.close() + vectorStore = null + } + override fun createModel( maxTokens: Long, modelPath: String, @@ -86,20 +111,31 @@ private class PlatformServiceImpl( ) { scope.launch { try { + // Build configuration first (before touching state) val backendEnum = preferredBackend?.let { PreferredBackendEnum.values()[it.ordinal] } - val config = InferenceModelConfig( - modelPath, - maxTokens.toInt(), - loraRanks?.map { it.toInt() }, - backendEnum, - maxNumImages?.toInt() + val config = EngineConfig( + modelPath = modelPath, + maxTokens = maxTokens.toInt(), + supportedLoraRanks = loraRanks?.map { it.toInt() }, + preferredBackend = backendEnum, + maxNumImages = maxNumImages?.toInt() ) - if (config != inferenceModel?.config) { - inferenceModel?.close() - inferenceModel = InferenceModel(context, config) + + // Create and initialize new engine BEFORE clearing old state + // This ensures we don't leave state inconsistent on failure + val newEngine = EngineFactory.createFromModelPath(modelPath, context) + newEngine.initialize(config) + + // Only now clear old state and swap in new engine (thread-safe) + synchronized(engineLock) { + session?.close() + session = null + engine?.close() + engine = newEngine } + callback(Result.success(Unit)) } catch (e: Exception) { callback(Result.failure(e)) @@ -108,12 +144,16 @@ private class PlatformServiceImpl( } override fun 
closeModel(callback: (Result) -> Unit) { - try { - inferenceModel?.close() - inferenceModel = null - callback(Result.success(Unit)) - } catch (e: Exception) { - callback(Result.failure(e)) + synchronized(engineLock) { + try { + session?.close() + session = null + engine?.close() + engine = null + callback(Result.success(Unit)) + } catch (e: Exception) { + callback(Result.failure(e)) + } } } @@ -128,17 +168,22 @@ private class PlatformServiceImpl( ) { scope.launch { try { - val model = inferenceModel ?: throw IllegalStateException("Inference model is not created") - val config = InferenceSessionConfig( - temperature.toFloat(), - randomSeed.toInt(), - topK.toInt(), - topP?.toFloat(), - loraPath, - enableVisionModality - ) - session?.close() - session = model.createSession(config) + synchronized(engineLock) { + val currentEngine = engine + ?: throw IllegalStateException("Inference model is not created") + + val config = SessionConfig( + temperature = temperature.toFloat(), + randomSeed = randomSeed.toInt(), + topK = topK.toInt(), + topP = topP?.toFloat(), + loraPath = loraPath, + enableVisionModality = enableVisionModality + ) + + session?.close() + session = currentEngine.createSession(config) + } callback(Result.success(Unit)) } catch (e: Exception) { callback(Result.failure(e)) @@ -147,19 +192,23 @@ private class PlatformServiceImpl( } override fun closeSession(callback: (Result) -> Unit) { - try { - session?.close() - session = null - callback(Result.success(Unit)) - } catch (e: Exception) { - callback(Result.failure(e)) + synchronized(engineLock) { + try { + session?.close() + session = null + callback(Result.success(Unit)) + } catch (e: Exception) { + callback(Result.failure(e)) + } } } override fun sizeInTokens(prompt: String, callback: (Result) -> Unit) { scope.launch { try { - val size = session?.sizeInTokens(prompt) ?: throw IllegalStateException("Session not created") + val currentSession = session + ?: throw IllegalStateException("Session not created") + 
val size = currentSession.sizeInTokens(prompt) callback(Result.success(size.toLong())) } catch (e: Exception) { callback(Result.failure(e)) @@ -170,7 +219,9 @@ private class PlatformServiceImpl( override fun addQueryChunk(prompt: String, callback: (Result) -> Unit) { scope.launch { try { - session?.addQueryChunk(prompt) ?: throw IllegalStateException("Session not created") + val currentSession = session + ?: throw IllegalStateException("Session not created") + currentSession.addQueryChunk(prompt) callback(Result.success(Unit)) } catch (e: Exception) { callback(Result.failure(e)) @@ -181,7 +232,9 @@ private class PlatformServiceImpl( override fun addImage(imageBytes: ByteArray, callback: (Result) -> Unit) { scope.launch { try { - session?.addImage(imageBytes) ?: throw IllegalStateException("Session not created") + val currentSession = session + ?: throw IllegalStateException("Session not created") + currentSession.addImage(imageBytes) callback(Result.success(Unit)) } catch (e: Exception) { callback(Result.failure(e)) @@ -192,7 +245,9 @@ private class PlatformServiceImpl( override fun generateResponse(callback: (Result) -> Unit) { scope.launch { try { - val result = session?.generateResponse() ?: throw IllegalStateException("Session not created") + val currentSession = session + ?: throw IllegalStateException("Session not created") + val result = currentSession.generateResponse() callback(Result.success(result)) } catch (e: Exception) { callback(Result.failure(e)) @@ -203,7 +258,9 @@ private class PlatformServiceImpl( override fun generateResponseAsync(callback: (Result) -> Unit) { scope.launch { try { - session?.generateResponseAsync() ?: throw IllegalStateException("Session not created") + val currentSession = session + ?: throw IllegalStateException("Session not created") + currentSession.generateResponseAsync() callback(Result.success(Unit)) } catch (e: Exception) { callback(Result.failure(e)) @@ -214,7 +271,9 @@ private class PlatformServiceImpl( override fun 
stopGeneration(callback: (Result) -> Unit) { scope.launch { try { - session?.stopGeneration() ?: throw IllegalStateException("Session not created") + val currentSession = session + ?: throw IllegalStateException("Session not created") + currentSession.cancelGeneration() callback(Result.success(Unit)) } catch (e: Exception) { callback(Result.failure(e)) @@ -223,26 +282,31 @@ private class PlatformServiceImpl( } override fun onListen(arguments: Any?, events: EventChannel.EventSink?) { + // Cancel previous stream collection to prevent orphaned coroutines + streamJob?.cancel() eventSink = events - val model = inferenceModel ?: return - scope.launch { - launch { - model.partialResults.collect { (text, done) -> - val payload = mapOf("partialResult" to text, "done" to done) - withContext(Dispatchers.Main) { - events?.success(payload) - if (done) { - events?.endOfStream() + synchronized(engineLock) { + val currentEngine = engine ?: return + + streamJob = scope.launch { + launch { + currentEngine.partialResults.collect { (text, done) -> + val payload = mapOf("partialResult" to text, "done" to done) + withContext(Dispatchers.Main) { + events?.success(payload) + if (done) { + events?.endOfStream() + } } } } - } - launch { - model.errors.collect { error -> - withContext(Dispatchers.Main) { - events?.error("ERROR", error.message, null) + launch { + currentEngine.errors.collect { error -> + withContext(Dispatchers.Main) { + events?.error("ERROR", error.message, null) + } } } } @@ -250,6 +314,8 @@ private class PlatformServiceImpl( } override fun onCancel(arguments: Any?) 
{ + streamJob?.cancel() + streamJob = null eventSink = null } diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt new file mode 100644 index 00000000..4dd151e6 --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt @@ -0,0 +1,38 @@ +package dev.flutterberlin.flutter_gemma.engines + +import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import kotlinx.coroutines.channels.BufferOverflow +import kotlinx.coroutines.flow.MutableSharedFlow + +/** + * Engine initialization configuration. + */ +data class EngineConfig( + val modelPath: String, + val maxTokens: Int, + val supportedLoraRanks: List? = null, + val preferredBackend: PreferredBackendEnum? = null, + val maxNumImages: Int? = null, +) + +/** + * Session-level configuration (sampling parameters). + */ +data class SessionConfig( + val temperature: Float = 1.0f, + val randomSeed: Int = 0, + val topK: Int = 40, + val topP: Float? = null, + val loraPath: String? = null, + val enableVisionModality: Boolean? = null, +) + +/** + * Helper to create SharedFlow instances with consistent configuration. 
+ */ +object FlowFactory { + fun createSharedFlow(): MutableSharedFlow = MutableSharedFlow( + extraBufferCapacity = 1, + onBufferOverflow = BufferOverflow.DROP_OLDEST + ) +} diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt new file mode 100644 index 00000000..527fa17a --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt @@ -0,0 +1,77 @@ +package dev.flutterberlin.flutter_gemma.engines + +import android.content.Context +import dev.flutterberlin.flutter_gemma.engines.mediapipe.MediaPipeEngine +import dev.flutterberlin.flutter_gemma.engines.litertlm.LiteRtLmEngine + +/** + * Factory for creating inference engines. + * + * Engine selection strategy: + * - MEDIAPIPE: .task, .bin, .tflite files + * - LITERTLM: .litertlm files + */ +object EngineFactory { + + /** + * Create engine based on file extension. + * + * @param modelPath Path to model file + * @param context Android context + * @return Appropriate engine instance + * @throws IllegalArgumentException if file extension not recognized + */ + fun createFromModelPath(modelPath: String, context: Context): InferenceEngine { + return when { + modelPath.endsWith(".litertlm", ignoreCase = true) -> LiteRtLmEngine(context) + modelPath.endsWith(".task", ignoreCase = true) -> MediaPipeEngine(context) + modelPath.endsWith(".bin", ignoreCase = true) -> MediaPipeEngine(context) + modelPath.endsWith(".tflite", ignoreCase = true) -> MediaPipeEngine(context) + else -> throw IllegalArgumentException( + "Unsupported model format: ${modelPath.substringAfterLast('.')}. " + + "Supported: .litertlm (LiteRT-LM), .task/.bin/.tflite (MediaPipe)" + ) + } + } + + /** + * Create engine explicitly by type (for testing or advanced use cases). 
+ * + * @param engineType Type of engine to create + * @param context Android context + * @return Engine instance of specified type + */ + fun create(engineType: EngineType, context: Context): InferenceEngine { + return when (engineType) { + EngineType.MEDIAPIPE -> MediaPipeEngine(context) + EngineType.LITERTLM -> LiteRtLmEngine(context) + } + } + + /** + * Detect engine type from model path. + * + * @param modelPath Path to model file + * @return Engine type for the given model + * @throws IllegalArgumentException if extension not recognized + */ + fun detectEngineType(modelPath: String): EngineType { + return when { + modelPath.endsWith(".litertlm", ignoreCase = true) -> EngineType.LITERTLM + modelPath.endsWith(".task", ignoreCase = true) -> EngineType.MEDIAPIPE + modelPath.endsWith(".bin", ignoreCase = true) -> EngineType.MEDIAPIPE + modelPath.endsWith(".tflite", ignoreCase = true) -> EngineType.MEDIAPIPE + else -> throw IllegalArgumentException( + "Unsupported model format: ${modelPath.substringAfterLast('.')}" + ) + } + } +} + +/** + * Engine type enumeration. + */ +enum class EngineType { + MEDIAPIPE, // .task, .bin, .tflite + LITERTLM // .litertlm +} diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceEngine.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceEngine.kt new file mode 100644 index 00000000..28ba5524 --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceEngine.kt @@ -0,0 +1,50 @@ +package dev.flutterberlin.flutter_gemma.engines + +import kotlinx.coroutines.flow.SharedFlow + +/** + * Abstraction for inference engines (MediaPipe, LiteRT-LM, future engines). + * + * Lifecycle: + * 1. initialize(config) - Load model, setup backend + * 2. createSession(config) - Create conversation/session + * 3. 
close() - Release resources + */ +interface InferenceEngine { + /** Whether engine has been initialized successfully */ + val isInitialized: Boolean + + /** Engine capabilities (vision, audio, function calls) */ + val capabilities: EngineCapabilities + + /** Streaming outputs (partial results + errors) */ + val partialResults: SharedFlow> + val errors: SharedFlow + + /** + * Initialize engine with model file. + * MUST be called on background thread (can take 10+ seconds). + */ + suspend fun initialize(config: EngineConfig) + + /** + * Create a new inference session. + * Throws IllegalStateException if engine not initialized. + */ + fun createSession(config: SessionConfig): InferenceSession + + /** Release all resources */ + fun close() +} + +/** + * Engine capabilities descriptor. + */ +data class EngineCapabilities( + val supportsVision: Boolean = false, + val supportsAudio: Boolean = false, + val supportsFunctionCalls: Boolean = false, + val supportsStreaming: Boolean = true, + val supportsTokenCounting: Boolean = false, + val maxTokenLimit: Int = 2048, +) diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceSession.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceSession.kt new file mode 100644 index 00000000..68a56516 --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/InferenceSession.kt @@ -0,0 +1,51 @@ +package dev.flutterberlin.flutter_gemma.engines + +/** + * Abstraction for inference sessions. + * + * API Design: + * - addQueryChunk() accumulates text (supports both chunk-based and message-based APIs) + * - addImage() accumulates images for multimodal + * - generateResponse() / generateResponseAsync() triggers inference + */ +interface InferenceSession { + /** + * Add text chunk to current query. + * Multiple calls accumulate into single message. + */ + fun addQueryChunk(prompt: String) + + /** + * Add image to current query (for multimodal models). 
+ * Throws UnsupportedOperationException if engine doesn't support vision. + */ + fun addImage(imageBytes: ByteArray) + + /** + * Generate response synchronously (blocking). + * Consumes accumulated chunks/images. + */ + fun generateResponse(): String + + /** + * Generate response asynchronously (streaming). + * Consumes accumulated chunks/images. + * Results emitted via engine's partialResults SharedFlow. + */ + fun generateResponseAsync() + + /** + * Estimate token count for text. + * Returns approximate value if engine doesn't expose tokenizer. + */ + fun sizeInTokens(prompt: String): Int + + /** + * Cancel ongoing async generation. + * No-op if generation already completed. + */ + fun cancelGeneration() + + /** Release session resources */ + fun close() +} diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt new file mode 100644 index 00000000..523abbf8 --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt @@ -0,0 +1,110 @@ +package dev.flutterberlin.flutter_gemma.engines.litertlm + +import android.content.Context +import android.util.Log +import com.google.ai.edge.litertlm.Backend +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.EngineConfig as LiteRtEngineConfig +import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import dev.flutterberlin.flutter_gemma.engines.* +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.asSharedFlow +import java.io.File + +private const val TAG = "LiteRtLmEngine" + +/** + * LiteRT-LM Engine implementation for .litertlm files. 
+ * + * Key Differences from MediaPipe: + * - Uses Conversation API (not session-based) + * - No chunk accumulation at engine level (handled by LiteRtLmSession) + * - Supports audio modality + * - Faster initialization (~1-2s with cache vs ~10s cold start) + */ +class LiteRtLmEngine( + private val context: Context +) : InferenceEngine { + + private var engine: Engine? = null + + override var isInitialized: Boolean = false + private set + + override val capabilities = EngineCapabilities( + supportsVision = true, + supportsAudio = true, // LiteRT-LM supports audio + supportsFunctionCalls = true, // Native @Tool annotation support + supportsStreaming = true, + supportsTokenCounting = false, // No direct API, must estimate + maxTokenLimit = 4096, // Higher context window + ) + + private val _partialResults = FlowFactory.createSharedFlow>() + override val partialResults: SharedFlow> = _partialResults.asSharedFlow() + + private val _errors = FlowFactory.createSharedFlow() + override val errors: SharedFlow = _errors.asSharedFlow() + + override suspend fun initialize(config: EngineConfig) { + // Validate model file + val modelFile = File(config.modelPath) + if (!modelFile.exists()) { + throw IllegalArgumentException("Model not found at path: ${config.modelPath}") + } + + // Map PreferredBackendEnum to LiteRT-LM Backend + val backend = when (config.preferredBackend) { + PreferredBackendEnum.GPU, + PreferredBackendEnum.GPU_FLOAT16, + PreferredBackendEnum.GPU_MIXED, + PreferredBackendEnum.GPU_FULL -> Backend.GPU + PreferredBackendEnum.CPU -> Backend.CPU + else -> Backend.CPU // Default to CPU for safety + } + + try { + // Configure engine with cache directory for faster reloads + // visionBackend is required for multimodal models (image support) + val visionBackend = if (config.maxNumImages != null && config.maxNumImages > 0) backend else null + + val engineConfig = LiteRtEngineConfig( + modelPath = config.modelPath, + backend = backend, + visionBackend = visionBackend, + 
maxNumTokens = config.maxTokens, + cacheDir = context.cacheDir.absolutePath, // Improves reload time 10s→1-2s + ) + + Log.i(TAG, "Initializing LiteRT-LM engine with backend: $backend, maxTokens: ${config.maxTokens}") + + val newEngine = Engine(engineConfig) + newEngine.initialize() // Can take 10+ seconds on cold start, 1-2s with cache + engine = newEngine + isInitialized = true + + Log.i(TAG, "LiteRT-LM engine initialized successfully") + } catch (e: Exception) { + Log.e(TAG, "Failed to initialize LiteRT-LM engine", e) + throw RuntimeException("Failed to initialize LiteRT-LM: ${e.message}", e) + } + } + + override fun createSession(config: SessionConfig): InferenceSession { + val currentEngine = engine + ?: throw IllegalStateException("Engine not initialized. Call initialize() first.") + return LiteRtLmSession(currentEngine, config, _partialResults, _errors) + } + + override fun close() { + try { + engine?.close() + } catch (e: Exception) { + Log.w(TAG, "Error closing LiteRT-LM engine", e) + } + engine = null + isInitialized = false + Log.i(TAG, "LiteRT-LM engine closed") + } +} diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSession.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSession.kt new file mode 100644 index 00000000..73ef759f --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSession.kt @@ -0,0 +1,167 @@ +package dev.flutterberlin.flutter_gemma.engines.litertlm + +import android.util.Log +import com.google.ai.edge.litertlm.Content +import com.google.ai.edge.litertlm.Conversation +import com.google.ai.edge.litertlm.ConversationConfig +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.Message +import com.google.ai.edge.litertlm.MessageCallback +import com.google.ai.edge.litertlm.SamplerConfig +import dev.flutterberlin.flutter_gemma.engines.* +import kotlinx.coroutines.flow.MutableSharedFlow + 
+private const val TAG = "LiteRtLmSession" + +/** + * LiteRT-LM Session implementation. + * + * Key Design Decision: Chunk Buffering + * - MediaPipe: addQueryChunk() directly on session + * - LiteRT-LM: sendMessage() takes complete message + * - Solution: Buffer chunks in StringBuilder, send on generateResponse() + */ +class LiteRtLmSession( + engine: Engine, + config: SessionConfig, + private val resultFlow: MutableSharedFlow>, + private val errorFlow: MutableSharedFlow +) : InferenceSession { + + private val conversation: Conversation + + // Chunk buffering (MediaPipe compatibility) - thread-safe access + private val pendingPrompt = StringBuilder() + private val promptLock = Any() + @Volatile private var pendingImage: ByteArray? = null + + init { + // Build sampler config + val samplerConfig = SamplerConfig( + topK = config.topK, + topP = (config.topP ?: 0.95f).toDouble(), + temperature = config.temperature.toDouble(), + ) + + // Build conversation config + val conversationConfig = ConversationConfig( + samplerConfig = samplerConfig, + systemMessage = null, // System message not exposed in current API + ) + + conversation = engine.createConversation(conversationConfig) + Log.d(TAG, "Created LiteRT-LM conversation with topK=${config.topK}, temp=${config.temperature}") + } + + override fun addQueryChunk(prompt: String) { + // Accumulate chunks (LiteRT-LM uses sendMessage, not addQueryChunk) + synchronized(promptLock) { + pendingPrompt.append(prompt) + Log.v(TAG, "Accumulated chunk: ${prompt.length} chars, total: ${pendingPrompt.length}") + } + } + + override fun addImage(imageBytes: ByteArray) { + // Store image for multimodal message (thread-safe) + synchronized(promptLock) { + pendingImage = imageBytes + } + Log.d(TAG, "Added image: ${imageBytes.size} bytes") + } + + override fun generateResponse(): String { + val message = buildAndConsumeMessage() + Log.d(TAG, "Generating sync response for message: ${message.toString().length} chars") + + return try { + val 
response = conversation.sendMessage(message) + response.toString() + } catch (e: Exception) { + Log.e(TAG, "Error generating response", e) + errorFlow.tryEmit(e) + throw e + } + } + + override fun generateResponseAsync() { + val message = buildAndConsumeMessage() + Log.d(TAG, "Generating async response for message: ${message.toString().length} chars") + + try { + // Use callback-based API + conversation.sendMessageAsync(message, object : MessageCallback { + override fun onMessage(message: Message) { + val text = message.toString() + resultFlow.tryEmit(text to false) + } + + override fun onDone() { + resultFlow.tryEmit("" to true) + } + + override fun onError(throwable: Throwable) { + Log.e(TAG, "Async generation error", throwable) + errorFlow.tryEmit(throwable) + resultFlow.tryEmit("" to true) + } + }) + } catch (e: Exception) { + Log.e(TAG, "Failed to start async generation", e) + errorFlow.tryEmit(e) + resultFlow.tryEmit("" to true) + } + } + + override fun sizeInTokens(prompt: String): Int { + // LiteRT-LM doesn't expose tokenizer API + // Estimate: ~4 characters per token (GPT-style average) + val estimate = (prompt.length + 3) / 4 + Log.w(TAG, "sizeInTokens: LiteRT-LM does not support token counting. " + + "Using estimate (~4 chars/token): $estimate tokens for ${prompt.length} chars. " + + "This may be inaccurate for non-English text.") + return estimate + } + + override fun cancelGeneration() { + // LiteRT-LM 0.9.x doesn't expose cancellation API + Log.w(TAG, "cancelGeneration: Not yet supported by LiteRT-LM SDK") + } + + override fun close() { + try { + conversation.close() + Log.d(TAG, "Conversation closed") + } catch (e: Exception) { + Log.w(TAG, "Error closing conversation", e) + } + } + + /** + * Build Message from accumulated chunks/images and clear buffer. + * Thread-safe: uses synchronized access to pendingPrompt and pendingImage. + * + * Note: Message.of() is deprecated in newer SDK versions but Contents + * is not exported in 0.9.0-alpha01. 
Text comes first per API convention. + */ + private fun buildAndConsumeMessage(): Message { + val text: String + val image: ByteArray? + synchronized(promptLock) { + text = pendingPrompt.toString() + pendingPrompt.clear() + image = pendingImage + pendingImage = null + } + + return if (image != null) { + // Multimodal message: text first, then image (per API convention) + Message.of( + Content.Text(text), + Content.ImageBytes(image) + ) + } else { + // Text-only message + Message.of(text) + } + } +} diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt new file mode 100644 index 00000000..c6ee1198 --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt @@ -0,0 +1,91 @@ +package dev.flutterberlin.flutter_gemma.engines.mediapipe + +import android.content.Context +import com.google.mediapipe.tasks.genai.llminference.LlmInference +import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import dev.flutterberlin.flutter_gemma.engines.* +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.asSharedFlow +import java.io.File + +/** + * Adapter wrapping existing MediaPipe LlmInference. + * + * This adapter wraps the existing MediaPipe implementation without + * modifying the original InferenceModel logic. + */ +class MediaPipeEngine( + private val context: Context +) : InferenceEngine { + + private var llmInference: LlmInference? 
= null + + override var isInitialized: Boolean = false + private set + + override val capabilities = EngineCapabilities( + supportsVision = true, + supportsAudio = false, + supportsFunctionCalls = true, // Manual via chat templates + supportsStreaming = true, + supportsTokenCounting = true, // MediaPipe has sizeInTokens() + maxTokenLimit = 2048, + ) + + // SharedFlow instances (same pattern as existing InferenceModel) + private val _partialResults = FlowFactory.createSharedFlow>() + override val partialResults: SharedFlow> = _partialResults.asSharedFlow() + + private val _errors = FlowFactory.createSharedFlow() + override val errors: SharedFlow = _errors.asSharedFlow() + + override suspend fun initialize(config: EngineConfig) { + // Validate model file exists + if (!File(config.modelPath).exists()) { + throw IllegalArgumentException("Model not found at path: ${config.modelPath}") + } + + try { + // Build LlmInferenceOptions (same logic as existing InferenceModel.kt) + val optionsBuilder = LlmInference.LlmInferenceOptions.builder() + .setModelPath(config.modelPath) + .setMaxTokens(config.maxTokens) + .apply { + config.supportedLoraRanks?.let { setSupportedLoraRanks(it) } + config.preferredBackend?.let { + // Explicit mapping instead of ordinal-based (safer) + val backendEnum: LlmInference.Backend? 
= when (it) { + PreferredBackendEnum.CPU -> LlmInference.Backend.CPU + PreferredBackendEnum.GPU, + PreferredBackendEnum.GPU_FLOAT16, + PreferredBackendEnum.GPU_MIXED, + PreferredBackendEnum.GPU_FULL -> LlmInference.Backend.GPU + PreferredBackendEnum.UNKNOWN, + PreferredBackendEnum.TPU -> null // Not supported by MediaPipe, use default + } + backendEnum?.let { backend -> setPreferredBackend(backend) } + } + config.maxNumImages?.let { setMaxNumImages(it) } + } + + val options = optionsBuilder.build() + llmInference = LlmInference.createFromOptions(context, options) + isInitialized = true + } catch (e: Exception) { + throw RuntimeException("Failed to initialize MediaPipe LlmInference: ${e.message}", e) + } + } + + override fun createSession(config: SessionConfig): InferenceSession { + val inference = llmInference + ?: throw IllegalStateException("Engine not initialized. Call initialize() first.") + return MediaPipeSession(inference, config, _partialResults, _errors) + } + + override fun close() { + llmInference?.close() + llmInference = null + isInitialized = false + } +} diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeSession.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeSession.kt new file mode 100644 index 00000000..d7b7c046 --- /dev/null +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeSession.kt @@ -0,0 +1,84 @@ +package dev.flutterberlin.flutter_gemma.engines.mediapipe + +import android.graphics.BitmapFactory +import com.google.mediapipe.framework.image.BitmapImageBuilder +import com.google.mediapipe.tasks.genai.llminference.GraphOptions +import com.google.mediapipe.tasks.genai.llminference.LlmInference +import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession +import dev.flutterberlin.flutter_gemma.engines.* +import kotlinx.coroutines.flow.MutableSharedFlow + +/** + * Adapter wrapping MediaPipe LlmInferenceSession. 
+ * + * Direct pass-through to existing MediaPipe implementation. + * Same logic as existing InferenceModelSession.kt. + */ +class MediaPipeSession( + private val llmInference: LlmInference, + config: SessionConfig, + private val resultFlow: MutableSharedFlow>, + private val errorFlow: MutableSharedFlow +) : InferenceSession { + + private val session: LlmInferenceSession + + init { + // Same session creation logic as existing InferenceModelSession.kt + val sessionOptionsBuilder = LlmInferenceSession.LlmInferenceSessionOptions.builder() + .setTemperature(config.temperature) + .setRandomSeed(config.randomSeed) + .setTopK(config.topK) + .apply { + config.topP?.let { setTopP(it) } + config.loraPath?.let { setLoraPath(it) } + config.enableVisionModality?.let { enableVision -> + setGraphOptions( + GraphOptions.builder() + .setEnableVisionModality(enableVision) + .build() + ) + } + } + + val sessionOptions = sessionOptionsBuilder.build() + session = LlmInferenceSession.createFromOptions(llmInference, sessionOptions) + } + + override fun addQueryChunk(prompt: String) { + session.addQueryChunk(prompt) + } + + override fun addImage(imageBytes: ByteArray) { + val bitmap = BitmapFactory.decodeByteArray(imageBytes, 0, imageBytes.size) + ?: throw IllegalArgumentException("Failed to decode image bytes") + val mpImage = BitmapImageBuilder(bitmap).build() + session.addImage(mpImage) + } + + override fun generateResponse(): String { + return session.generateResponse() ?: "" + } + + override fun generateResponseAsync() { + session.generateResponseAsync { result, done -> + if (result != null) { + resultFlow.tryEmit(result to done) + } else if (done) { + resultFlow.tryEmit("" to true) + } + } + } + + override fun sizeInTokens(prompt: String): Int { + return session.sizeInTokens(prompt) + } + + override fun cancelGeneration() { + session.cancelGenerateResponseAsync() + } + + override fun close() { + session.close() + } +} diff --git 
a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactoryTest.kt b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactoryTest.kt new file mode 100644 index 00000000..b7d0b750 --- /dev/null +++ b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactoryTest.kt @@ -0,0 +1,132 @@ +package dev.flutterberlin.flutter_gemma.engines + +import android.content.Context +import org.junit.Assert.* +import org.junit.Test +import org.mockito.Mockito.mock + +/** + * Unit tests for EngineFactory. + * + * Tests file extension detection and engine type selection. + */ +class EngineFactoryTest { + + private val mockContext: Context = mock(Context::class.java) + + // =========================================== + // createFromModelPath() tests + // =========================================== + + @Test + fun `createFromModelPath with litertlm extension returns LiteRtLmEngine`() { + val engine = EngineFactory.createFromModelPath("/path/to/model.litertlm", mockContext) + assertTrue("Expected LiteRtLmEngine", engine is dev.flutterberlin.flutter_gemma.engines.litertlm.LiteRtLmEngine) + } + + @Test + fun `createFromModelPath with LITERTLM uppercase returns LiteRtLmEngine`() { + val engine = EngineFactory.createFromModelPath("/path/to/model.LITERTLM", mockContext) + assertTrue("Expected LiteRtLmEngine for uppercase", engine is dev.flutterberlin.flutter_gemma.engines.litertlm.LiteRtLmEngine) + } + + @Test + fun `createFromModelPath with task extension returns MediaPipeEngine`() { + val engine = EngineFactory.createFromModelPath("/path/to/model.task", mockContext) + assertTrue("Expected MediaPipeEngine", engine is dev.flutterberlin.flutter_gemma.engines.mediapipe.MediaPipeEngine) + } + + @Test + fun `createFromModelPath with bin extension returns MediaPipeEngine`() { + val engine = EngineFactory.createFromModelPath("/path/to/model.bin", mockContext) + assertTrue("Expected MediaPipeEngine", engine is 
dev.flutterberlin.flutter_gemma.engines.mediapipe.MediaPipeEngine) + } + + @Test + fun `createFromModelPath with tflite extension returns MediaPipeEngine`() { + val engine = EngineFactory.createFromModelPath("/path/to/model.tflite", mockContext) + assertTrue("Expected MediaPipeEngine", engine is dev.flutterberlin.flutter_gemma.engines.mediapipe.MediaPipeEngine) + } + + @Test(expected = IllegalArgumentException::class) + fun `createFromModelPath with unknown extension throws IllegalArgumentException`() { + EngineFactory.createFromModelPath("/path/to/model.unknown", mockContext) + } + + @Test(expected = IllegalArgumentException::class) + fun `createFromModelPath with no extension throws IllegalArgumentException`() { + EngineFactory.createFromModelPath("/path/to/model", mockContext) + } + + // =========================================== + // detectEngineType() tests + // =========================================== + + @Test + fun `detectEngineType returns LITERTLM for litertlm extension`() { + val type = EngineFactory.detectEngineType("/path/to/model.litertlm") + assertEquals(EngineType.LITERTLM, type) + } + + @Test + fun `detectEngineType returns MEDIAPIPE for task extension`() { + val type = EngineFactory.detectEngineType("/path/to/model.task") + assertEquals(EngineType.MEDIAPIPE, type) + } + + @Test + fun `detectEngineType returns MEDIAPIPE for bin extension`() { + val type = EngineFactory.detectEngineType("/path/to/model.bin") + assertEquals(EngineType.MEDIAPIPE, type) + } + + @Test + fun `detectEngineType returns MEDIAPIPE for tflite extension`() { + val type = EngineFactory.detectEngineType("/path/to/model.tflite") + assertEquals(EngineType.MEDIAPIPE, type) + } + + @Test + fun `detectEngineType is case insensitive`() { + assertEquals(EngineType.LITERTLM, EngineFactory.detectEngineType("/model.LiteRtLm")) + assertEquals(EngineType.MEDIAPIPE, EngineFactory.detectEngineType("/model.TASK")) + assertEquals(EngineType.MEDIAPIPE, 
EngineFactory.detectEngineType("/model.BIN")) + } + + @Test(expected = IllegalArgumentException::class) + fun `detectEngineType throws for unknown extension`() { + EngineFactory.detectEngineType("/path/to/model.gguf") + } + + // =========================================== + // create() tests + // =========================================== + + @Test + fun `create with MEDIAPIPE returns MediaPipeEngine`() { + val engine = EngineFactory.create(EngineType.MEDIAPIPE, mockContext) + assertTrue("Expected MediaPipeEngine", engine is dev.flutterberlin.flutter_gemma.engines.mediapipe.MediaPipeEngine) + } + + @Test + fun `create with LITERTLM returns LiteRtLmEngine`() { + val engine = EngineFactory.create(EngineType.LITERTLM, mockContext) + assertTrue("Expected LiteRtLmEngine", engine is dev.flutterberlin.flutter_gemma.engines.litertlm.LiteRtLmEngine) + } + + // =========================================== + // Edge cases + // =========================================== + + @Test + fun `createFromModelPath handles paths with multiple dots`() { + val engine = EngineFactory.createFromModelPath("/path/to/model.v1.2.litertlm", mockContext) + assertTrue("Expected LiteRtLmEngine", engine is dev.flutterberlin.flutter_gemma.engines.litertlm.LiteRtLmEngine) + } + + @Test + fun `createFromModelPath handles paths with spaces`() { + val engine = EngineFactory.createFromModelPath("/path/to/my model.task", mockContext) + assertTrue("Expected MediaPipeEngine", engine is dev.flutterberlin.flutter_gemma.engines.mediapipe.MediaPipeEngine) + } +} diff --git a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt new file mode 100644 index 00000000..9156f15e --- /dev/null +++ b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt @@ -0,0 +1,196 @@ +package dev.flutterberlin.flutter_gemma.engines.litertlm + +import 
android.content.Context +import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import dev.flutterberlin.flutter_gemma.engines.EngineConfig +import dev.flutterberlin.flutter_gemma.engines.SessionConfig +import kotlinx.coroutines.runBlocking +import org.junit.Assert.* +import org.junit.Before +import org.junit.Test +import org.mockito.Mockito.* +import java.io.File + +/** + * Unit tests for LiteRtLmEngine. + * + * Tests initialization, backend mapping, and lifecycle management. + */ +class LiteRtLmEngineTest { + + private lateinit var mockContext: Context + private lateinit var engine: LiteRtLmEngine + private lateinit var tempCacheDir: File + + @Before + fun setUp() { + mockContext = mock(Context::class.java) + tempCacheDir = File(System.getProperty("java.io.tmpdir"), "test_cache") + tempCacheDir.mkdirs() + + `when`(mockContext.cacheDir).thenReturn(tempCacheDir) + + engine = LiteRtLmEngine(mockContext) + } + + // =========================================== + // Initialization Tests + // =========================================== + + @Test + fun `isInitialized is false before initialize`() { + assertFalse("Engine should not be initialized", engine.isInitialized) + } + + @Test + fun `initialize with non-existent file throws IllegalArgumentException`() { + val config = EngineConfig( + modelPath = "/non/existent/path/model.litertlm", + maxTokens = 1024 + ) + + runBlocking { + try { + engine.initialize(config) + fail("Should throw for non-existent file") + } catch (e: IllegalArgumentException) { + assertTrue(e.message?.contains("not found") == true) + } + } + } + + // =========================================== + // Capabilities Tests + // =========================================== + + @Test + fun `capabilities reports vision support`() { + assertTrue("Should support vision", engine.capabilities.supportsVision) + } + + @Test + fun `capabilities reports audio support`() { + assertTrue("Should support audio", engine.capabilities.supportsAudio) + } + + @Test + fun 
`capabilities reports function calls support`() { + assertTrue("Should support function calls", engine.capabilities.supportsFunctionCalls) + } + + @Test + fun `capabilities reports streaming support`() { + assertTrue("Should support streaming", engine.capabilities.supportsStreaming) + } + + @Test + fun `capabilities reports no token counting support`() { + assertFalse("Should not support token counting", engine.capabilities.supportsTokenCounting) + } + + @Test + fun `capabilities has 4096 max token limit`() { + assertEquals(4096, engine.capabilities.maxTokenLimit) + } + + // =========================================== + // Session Creation Tests + // =========================================== + + @Test + fun `createSession before initialize throws IllegalStateException`() { + val config = SessionConfig(temperature = 0.7f) + + try { + engine.createSession(config) + fail("Should throw IllegalStateException") + } catch (e: IllegalStateException) { + assertTrue(e.message?.contains("not initialized") == true) + } + } + + // =========================================== + // Close Tests + // =========================================== + + @Test + fun `close sets isInitialized to false`() { + // Even if not initialized, close should work + engine.close() + assertFalse("Should not be initialized after close", engine.isInitialized) + } + + @Test + fun `close can be called multiple times safely`() { + engine.close() + engine.close() + engine.close() + + // Should not throw + assertFalse(engine.isInitialized) + } + + // =========================================== + // Flow Tests + // =========================================== + + @Test + fun `partialResults flow is accessible`() { + assertNotNull("partialResults should not be null", engine.partialResults) + } + + @Test + fun `errors flow is accessible`() { + assertNotNull("errors should not be null", engine.errors) + } + + // =========================================== + // Backend Mapping Tests (via EngineConfig) + // 
=========================================== + + @Test + fun `config with GPU backend is accepted`() { + val config = EngineConfig( + modelPath = "/test/model.litertlm", + maxTokens = 1024, + preferredBackend = PreferredBackendEnum.GPU + ) + + // Config creation should not throw + assertNotNull(config) + assertEquals(PreferredBackendEnum.GPU, config.preferredBackend) + } + + @Test + fun `config with CPU backend is accepted`() { + val config = EngineConfig( + modelPath = "/test/model.litertlm", + maxTokens = 1024, + preferredBackend = PreferredBackendEnum.CPU + ) + + assertEquals(PreferredBackendEnum.CPU, config.preferredBackend) + } + + @Test + fun `config with GPU_FLOAT16 backend is accepted`() { + val config = EngineConfig( + modelPath = "/test/model.litertlm", + maxTokens = 1024, + preferredBackend = PreferredBackendEnum.GPU_FLOAT16 + ) + + assertEquals(PreferredBackendEnum.GPU_FLOAT16, config.preferredBackend) + } + + @Test + fun `config with null backend defaults correctly`() { + val config = EngineConfig( + modelPath = "/test/model.litertlm", + maxTokens = 1024, + preferredBackend = null + ) + + assertNull(config.preferredBackend) + } +} diff --git a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt new file mode 100644 index 00000000..5d828dd1 --- /dev/null +++ b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt @@ -0,0 +1,216 @@ +package dev.flutterberlin.flutter_gemma.engines.litertlm + +import com.google.ai.edge.litertlm.Conversation +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.Message +import dev.flutterberlin.flutter_gemma.engines.SessionConfig +import kotlinx.coroutines.flow.MutableSharedFlow +import org.junit.Assert.* +import org.junit.Before +import org.junit.Test +import org.mockito.Mockito.* +import 
java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +/** + * Unit tests for LiteRtLmSession. + * + * Tests chunk buffering, thread safety, and message building. + */ +class LiteRtLmSessionTest { + + private lateinit var mockEngine: Engine + private lateinit var mockConversation: Conversation + private lateinit var resultFlow: MutableSharedFlow> + private lateinit var errorFlow: MutableSharedFlow + private lateinit var session: LiteRtLmSession + + @Before + fun setUp() { + mockEngine = mock(Engine::class.java) + mockConversation = mock(Conversation::class.java) + resultFlow = MutableSharedFlow() + errorFlow = MutableSharedFlow() + + `when`(mockEngine.createConversation(any())).thenReturn(mockConversation) + + val config = SessionConfig( + temperature = 0.7f, + randomSeed = 42, + topK = 40, + topP = 0.95f + ) + + session = LiteRtLmSession(mockEngine, config, resultFlow, errorFlow) + } + + // =========================================== + // Chunk Buffering Tests + // =========================================== + + @Test + fun `addQueryChunk accumulates text`() { + session.addQueryChunk("Hello, ") + session.addQueryChunk("world!") + + // Token count estimate should reflect accumulated length + // "Hello, world!" = 13 chars → ~3-4 tokens + val tokenCount = session.sizeInTokens("") + // This verifies internal state indirectly + assertTrue("Token count should be reasonable", tokenCount >= 0) + } + + @Test + fun `addQueryChunk handles empty string`() { + session.addQueryChunk("") + session.addQueryChunk("test") + session.addQueryChunk("") + + // Should not crash + val tokenCount = session.sizeInTokens("test") + assertTrue(tokenCount >= 0) + } + + @Test + fun `addQueryChunk handles unicode text`() { + session.addQueryChunk("Привет, ") + session.addQueryChunk("мир! 
🌍") + + // Should handle unicode without crashing + val tokenCount = session.sizeInTokens("Тест") + assertTrue(tokenCount >= 0) + } + + // =========================================== + // Thread Safety Tests + // =========================================== + + @Test + fun `concurrent addQueryChunk calls are thread-safe`() { + val executor = Executors.newFixedThreadPool(10) + val latch = CountDownLatch(100) + + repeat(100) { i -> + executor.submit { + try { + session.addQueryChunk("chunk$i ") + } finally { + latch.countDown() + } + } + } + + assertTrue("Should complete without deadlock", latch.await(5, TimeUnit.SECONDS)) + executor.shutdown() + } + + @Test + fun `concurrent addImage and addQueryChunk are thread-safe`() { + val executor = Executors.newFixedThreadPool(10) + val latch = CountDownLatch(100) + + repeat(50) { i -> + executor.submit { + try { + session.addQueryChunk("chunk$i ") + } finally { + latch.countDown() + } + } + executor.submit { + try { + session.addImage(byteArrayOf(i.toByte())) + } finally { + latch.countDown() + } + } + } + + assertTrue("Should complete without deadlock", latch.await(5, TimeUnit.SECONDS)) + executor.shutdown() + } + + // =========================================== + // Token Counting Tests + // =========================================== + + @Test + fun `sizeInTokens returns estimate based on character count`() { + // Formula: (length + 3) / 4 + val prompt = "Hello world" // 11 chars + val expected = (11 + 3) / 4 // = 3 + + val result = session.sizeInTokens(prompt) + + assertEquals("Token estimate should be ~chars/4", expected, result) + } + + @Test + fun `sizeInTokens handles empty string`() { + val result = session.sizeInTokens("") + assertEquals("Empty string should return ~1 token", 0, result) + } + + @Test + fun `sizeInTokens handles very long text`() { + val longText = "a".repeat(10000) + val result = session.sizeInTokens(longText) + + assertTrue("Should handle long text", result > 2000) + } + + // 
=========================================== + // Cancel Generation Tests + // =========================================== + + @Test + fun `cancelGeneration does not throw`() { + // LiteRT-LM doesn't support cancellation, but should not crash + session.cancelGeneration() + // If we get here, test passes + } + + // =========================================== + // Close Tests + // =========================================== + + @Test + fun `close releases conversation resource`() { + session.close() + + verify(mockConversation).close() + } + + @Test + fun `close can be called multiple times`() { + session.close() + session.close() + session.close() + + // Should not throw, verify close was attempted + verify(mockConversation, atLeast(1)).close() + } + + // =========================================== + // Image Handling Tests + // =========================================== + + @Test + fun `addImage stores image bytes`() { + val imageBytes = byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47) // PNG header + + session.addImage(imageBytes) + + // Should not throw - image stored for later use + } + + @Test + fun `addImage replaces previous image`() { + session.addImage(byteArrayOf(1, 2, 3)) + session.addImage(byteArrayOf(4, 5, 6)) + + // Only last image should be used (implementation detail) + // No assertion needed - just verify no crash + } +} diff --git a/example/lib/models/model.dart b/example/lib/models/model.dart index 2322f853..9d45448a 100644 --- a/example/lib/models/model.dart +++ b/example/lib/models/model.dart @@ -105,6 +105,48 @@ enum Model implements InferenceModelInterface { supportsFunctionCalls: false, ), + // === LiteRT-LM ENGINE MODELS (for testing parity with MediaPipe) === + + // Gemma 3 Nano E2B LiteRT-LM (same model, different engine) + gemma3n_2B_litertlm( + baseUrl: + 'https://huggingface.co/google/gemma-3n-E2B-it-litert-lm/resolve/main/gemma-3n-E2B-it-int4.litertlm', + filename: 'gemma-3n-E2B-it-int4.litertlm', + displayName: 'Gemma 3 Nano E2B IT 
(LiteRT-LM)', + size: '3.1GB', + licenseUrl: 'https://huggingface.co/google/gemma-3n-E2B-it-litert-lm', + needsAuth: true, + preferredBackend: PreferredBackend.gpu, + modelType: ModelType.gemmaIt, + temperature: 1.0, + topK: 64, + topP: 0.95, + supportImage: true, + maxTokens: 4096, + maxNumImages: 1, + supportsFunctionCalls: true, + ), + + // Gemma 3 Nano E4B LiteRT-LM (same model, different engine) + gemma3n_4B_litertlm( + baseUrl: + 'https://huggingface.co/google/gemma-3n-E4B-it-litert-lm/resolve/main/gemma-3n-E4B-it-int4.litertlm', + filename: 'gemma-3n-E4B-it-int4.litertlm', + displayName: 'Gemma 3 Nano E4B IT (LiteRT-LM)', + size: '6.5GB', + licenseUrl: 'https://huggingface.co/google/gemma-3n-E4B-it-litert-lm', + needsAuth: true, + preferredBackend: PreferredBackend.gpu, + modelType: ModelType.gemmaIt, + temperature: 1.0, + topK: 64, + topP: 0.95, + supportImage: true, + maxTokens: 4096, + maxNumImages: 1, + supportsFunctionCalls: true, + ), + // Local Gemma models (for testing) gemma3LocalAsset( // model file should be pre-downloaded and placed in the assets folder @@ -318,8 +360,8 @@ enum Model implements InferenceModelInterface { // FunctionGemma 270M IT (Local asset) functionGemma_270M_local( - baseUrl: 'assets/models/functiongemma-flutter-1.litertlm', - filename: 'functiongemma-flutter-1.litertlm', + baseUrl: 'assets/models/functiongemma-270M-it.litertlm', + filename: 'functiongemma-270M-it.litertlm', displayName: 'FunctionGemma 270M IT (Local)', size: '284MB', licenseUrl: '', From df5dfc5c4f74ee541b1598cd60604111d6c28e7e Mon Sep 17 00:00:00 2001 From: Sasha Denisov Date: Fri, 23 Jan 2026 23:48:30 +0500 Subject: [PATCH 2/5] Refactor PreferredBackend enum: add NPU, remove unsupported values - Remove non-existent SDK values: unknown, gpuFloat16, gpuMixed, gpuFull, tpu - Add NPU backend support for LiteRT-LM (Google Tensor, Qualcomm) - Simplify backend mapping across all engines - Use Pigeon-generated PreferredBackend directly instead of 
PreferredBackendEnum - Update tests for NPU backend - Fix Copilot review issues: typo in test comment, error message for missing extension --- .../flutter_gemma/FlutterGemmaPlugin.kt | 12 +++-------- .../flutter_gemma/InferenceModel.kt | 20 ++++++++++--------- .../flutter_gemma/PigeonInterface.g.kt | 10 +++------- .../flutter_gemma/engines/EngineConfig.kt | 4 ++-- .../flutter_gemma/engines/EngineFactory.kt | 15 ++++++++++---- .../engines/litertlm/LiteRtLmEngine.kt | 14 ++++++------- .../engines/mediapipe/MediaPipeEngine.kt | 14 +++++-------- .../engines/litertlm/LiteRtLmEngineTest.kt | 16 +++++++-------- .../engines/litertlm/LiteRtLmSessionTest.kt | 2 +- ios/Classes/FlutterGemmaPlugin.swift | 3 ++- ios/Classes/PigeonInterface.g.swift | 10 +++------- lib/pigeon.g.dart | 6 +----- pigeon.dart | 6 +----- 13 files changed, 57 insertions(+), 75 deletions(-) diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt index 094e43a4..4d1f599b 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/FlutterGemmaPlugin.kt @@ -112,14 +112,11 @@ private class PlatformServiceImpl( scope.launch { try { // Build configuration first (before touching state) - val backendEnum = preferredBackend?.let { - PreferredBackendEnum.values()[it.ordinal] - } val config = EngineConfig( modelPath = modelPath, maxTokens = maxTokens.toInt(), supportedLoraRanks = loraRanks?.map { it.toInt() }, - preferredBackend = backendEnum, + preferredBackend = preferredBackend, maxNumImages = maxNumImages?.toInt() ) @@ -332,11 +329,8 @@ private class PlatformServiceImpl( embeddingModel?.close() // Convert PreferredBackend to useGPU boolean - val useGPU = when (preferredBackend) { - PreferredBackend.GPU, PreferredBackend.GPU_FLOAT16, - PreferredBackend.GPU_MIXED, PreferredBackend.GPU_FULL -> true - else -> 
false - } + // Note: NPU not supported for embeddings, fallback to CPU + val useGPU = preferredBackend == PreferredBackend.GPU embeddingModel = EmbeddingModel(context, modelPath, tokenizerPath, useGPU) embeddingModel!!.initialize() diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/InferenceModel.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/InferenceModel.kt index 1014c7f8..5e6ac24b 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/InferenceModel.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/InferenceModel.kt @@ -13,10 +13,7 @@ import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.* import kotlinx.coroutines.flow.asSharedFlow -// Enum generated via Pigeon -enum class PreferredBackendEnum(val value: Int) { - UNKNOWN(0), CPU(1), GPU(2), GPU_FLOAT16(3), GPU_MIXED(4), GPU_FULL(5), TPU(6) -} +// Note: PreferredBackend is generated by Pigeon in PigeonInterface.g.kt // Configuration data classes @@ -24,7 +21,7 @@ data class InferenceModelConfig( val modelPath: String, val maxTokens: Int, val supportedLoraRanks: List?, - val preferredBackend: PreferredBackendEnum?, + val preferredBackend: PreferredBackend?, val maxNumImages: Int?, ) @@ -70,10 +67,15 @@ class InferenceModel( .setMaxTokens(config.maxTokens) .apply { config.supportedLoraRanks?.let { setSupportedLoraRanks(it) } - config.preferredBackend?.let { - val backendEnum = LlmInference.Backend.values().getOrNull(it.ordinal) - ?: throw IllegalArgumentException("Invalid preferredBackend value: ${it.ordinal}") - setPreferredBackend(backendEnum) + config.preferredBackend?.let { backend -> + // Map PreferredBackend to MediaPipe Backend + // Note: NPU is not supported by MediaPipe, fallback to default + val mpBackend = when (backend) { + PreferredBackend.CPU -> LlmInference.Backend.CPU + PreferredBackend.GPU -> LlmInference.Backend.GPU + PreferredBackend.NPU -> null // MediaPipe doesn't support NPU + } + mpBackend?.let { 
setPreferredBackend(it) } } config.maxNumImages?.let { setMaxNumImages(it) } } diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt index 43c76636..acceabe4 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt @@ -47,13 +47,9 @@ class FlutterError ( ) : Throwable() enum class PreferredBackend(val raw: Int) { - UNKNOWN(0), - CPU(1), - GPU(2), - GPU_FLOAT16(3), - GPU_MIXED(4), - GPU_FULL(5), - TPU(6); + CPU(0), + GPU(1), + NPU(2); companion object { fun ofRaw(raw: Int): PreferredBackend? { diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt index 4dd151e6..814a3935 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineConfig.kt @@ -1,6 +1,6 @@ package dev.flutterberlin.flutter_gemma.engines -import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import dev.flutterberlin.flutter_gemma.PreferredBackend import kotlinx.coroutines.channels.BufferOverflow import kotlinx.coroutines.flow.MutableSharedFlow @@ -11,7 +11,7 @@ data class EngineConfig( val modelPath: String, val maxTokens: Int, val supportedLoraRanks: List? = null, - val preferredBackend: PreferredBackendEnum? = null, + val preferredBackend: PreferredBackend? = null, val maxNumImages: Int? 
= null, ) diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt index 527fa17a..ae4d1de9 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/EngineFactory.kt @@ -27,10 +27,17 @@ object EngineFactory { modelPath.endsWith(".task", ignoreCase = true) -> MediaPipeEngine(context) modelPath.endsWith(".bin", ignoreCase = true) -> MediaPipeEngine(context) modelPath.endsWith(".tflite", ignoreCase = true) -> MediaPipeEngine(context) - else -> throw IllegalArgumentException( - "Unsupported model format: ${modelPath.substringAfterLast('.')}. " + - "Supported: .litertlm (LiteRT-LM), .task/.bin/.tflite (MediaPipe)" - ) + else -> { + val extension = if (modelPath.contains('.')) { + modelPath.substringAfterLast('.') + } else { + "" + } + throw IllegalArgumentException( + "Unsupported model format: .$extension. 
" + + "Supported: .litertlm (LiteRT-LM), .task/.bin/.tflite (MediaPipe)" + ) + } } } diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt index 523abbf8..2e562cb4 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngine.kt @@ -5,7 +5,7 @@ import android.util.Log import com.google.ai.edge.litertlm.Backend import com.google.ai.edge.litertlm.Engine import com.google.ai.edge.litertlm.EngineConfig as LiteRtEngineConfig -import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import dev.flutterberlin.flutter_gemma.PreferredBackend import dev.flutterberlin.flutter_gemma.engines.* import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.SharedFlow @@ -54,14 +54,12 @@ class LiteRtLmEngine( throw IllegalArgumentException("Model not found at path: ${config.modelPath}") } - // Map PreferredBackendEnum to LiteRT-LM Backend + // Map PreferredBackend to LiteRT-LM Backend val backend = when (config.preferredBackend) { - PreferredBackendEnum.GPU, - PreferredBackendEnum.GPU_FLOAT16, - PreferredBackendEnum.GPU_MIXED, - PreferredBackendEnum.GPU_FULL -> Backend.GPU - PreferredBackendEnum.CPU -> Backend.CPU - else -> Backend.CPU // Default to CPU for safety + PreferredBackend.GPU -> Backend.GPU + PreferredBackend.NPU -> Backend.NPU // LiteRT-LM supports NPU (Google Tensor, Qualcomm) + PreferredBackend.CPU, + null -> Backend.CPU } try { diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt index c6ee1198..c47ec3f1 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt +++ 
b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/engines/mediapipe/MediaPipeEngine.kt @@ -2,7 +2,7 @@ package dev.flutterberlin.flutter_gemma.engines.mediapipe import android.content.Context import com.google.mediapipe.tasks.genai.llminference.LlmInference -import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import dev.flutterberlin.flutter_gemma.PreferredBackend import dev.flutterberlin.flutter_gemma.engines.* import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.SharedFlow @@ -54,15 +54,11 @@ class MediaPipeEngine( .apply { config.supportedLoraRanks?.let { setSupportedLoraRanks(it) } config.preferredBackend?.let { - // Explicit mapping instead of ordinal-based (safer) + // Map to MediaPipe Backend (NPU not supported) val backendEnum: LlmInference.Backend? = when (it) { - PreferredBackendEnum.CPU -> LlmInference.Backend.CPU - PreferredBackendEnum.GPU, - PreferredBackendEnum.GPU_FLOAT16, - PreferredBackendEnum.GPU_MIXED, - PreferredBackendEnum.GPU_FULL -> LlmInference.Backend.GPU - PreferredBackendEnum.UNKNOWN, - PreferredBackendEnum.TPU -> null // Not supported by MediaPipe, use default + PreferredBackend.CPU -> LlmInference.Backend.CPU + PreferredBackend.GPU -> LlmInference.Backend.GPU + PreferredBackend.NPU -> null // MediaPipe doesn't support NPU } backendEnum?.let { backend -> setPreferredBackend(backend) } } diff --git a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt index 9156f15e..93d4236b 100644 --- a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt +++ b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmEngineTest.kt @@ -1,7 +1,7 @@ package dev.flutterberlin.flutter_gemma.engines.litertlm import android.content.Context -import dev.flutterberlin.flutter_gemma.PreferredBackendEnum +import 
dev.flutterberlin.flutter_gemma.PreferredBackend import dev.flutterberlin.flutter_gemma.engines.EngineConfig import dev.flutterberlin.flutter_gemma.engines.SessionConfig import kotlinx.coroutines.runBlocking @@ -153,12 +153,12 @@ class LiteRtLmEngineTest { val config = EngineConfig( modelPath = "/test/model.litertlm", maxTokens = 1024, - preferredBackend = PreferredBackendEnum.GPU + preferredBackend = PreferredBackend.GPU ) // Config creation should not throw assertNotNull(config) - assertEquals(PreferredBackendEnum.GPU, config.preferredBackend) + assertEquals(PreferredBackend.GPU, config.preferredBackend) } @Test @@ -166,21 +166,21 @@ class LiteRtLmEngineTest { val config = EngineConfig( modelPath = "/test/model.litertlm", maxTokens = 1024, - preferredBackend = PreferredBackendEnum.CPU + preferredBackend = PreferredBackend.CPU ) - assertEquals(PreferredBackendEnum.CPU, config.preferredBackend) + assertEquals(PreferredBackend.CPU, config.preferredBackend) } @Test - fun `config with GPU_FLOAT16 backend is accepted`() { + fun `config with NPU backend is accepted`() { val config = EngineConfig( modelPath = "/test/model.litertlm", maxTokens = 1024, - preferredBackend = PreferredBackendEnum.GPU_FLOAT16 + preferredBackend = PreferredBackend.NPU ) - assertEquals(PreferredBackendEnum.GPU_FLOAT16, config.preferredBackend) + assertEquals(PreferredBackend.NPU, config.preferredBackend) } @Test diff --git a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt index 5d828dd1..02ccc96a 100644 --- a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt +++ b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt @@ -149,7 +149,7 @@ class LiteRtLmSessionTest { @Test fun `sizeInTokens handles empty string`() { val result = session.sizeInTokens("") - 
assertEquals("Empty string should return ~1 token", 0, result) + assertEquals("Empty string should return 0 tokens", 0, result) } @Test diff --git a/ios/Classes/FlutterGemmaPlugin.swift b/ios/Classes/FlutterGemmaPlugin.swift index 4575c7ed..7aeba928 100644 --- a/ios/Classes/FlutterGemmaPlugin.swift +++ b/ios/Classes/FlutterGemmaPlugin.swift @@ -289,7 +289,8 @@ class PlatformServiceImpl : NSObject, PlatformService, FlutterStreamHandler { print("[PLUGIN] Preferred backend: \(String(describing: preferredBackend))") // Convert PreferredBackend to useGPU boolean - let useGPU = preferredBackend == .gpu || preferredBackend == .gpuFloat16 || preferredBackend == .gpuMixed || preferredBackend == .gpuFull + // Note: NPU not supported for embeddings on iOS + let useGPU = preferredBackend == .gpu DispatchQueue.global(qos: .userInitiated).async { do { diff --git a/ios/Classes/PigeonInterface.g.swift b/ios/Classes/PigeonInterface.g.swift index 7f748783..e3d0084e 100644 --- a/ios/Classes/PigeonInterface.g.swift +++ b/ios/Classes/PigeonInterface.g.swift @@ -65,13 +65,9 @@ private func nilOrValue(_ value: Any?) -> T? { } enum PreferredBackend: Int { - case unknown = 0 - case cpu = 1 - case gpu = 2 - case gpuFloat16 = 3 - case gpuMixed = 4 - case gpuFull = 5 - case tpu = 6 + case cpu = 0 + case gpu = 1 + case npu = 2 } /// Generated class from Pigeon that represents data sent in messages. 
diff --git a/lib/pigeon.g.dart b/lib/pigeon.g.dart index cb750c38..3218ce8a 100644 --- a/lib/pigeon.g.dart +++ b/lib/pigeon.g.dart @@ -16,13 +16,9 @@ PlatformException _createConnectionError(String channelName) { } enum PreferredBackend { - unknown, cpu, gpu, - gpuFloat16, - gpuMixed, - gpuFull, - tpu, + npu, } class RetrievalResult { diff --git a/pigeon.dart b/pigeon.dart index 59e6b754..92c9c989 100644 --- a/pigeon.dart +++ b/pigeon.dart @@ -2,13 +2,9 @@ import 'package:pigeon/pigeon.dart'; // Command to generate pigeon files: dart run pigeon --input pigeon.dart enum PreferredBackend { - unknown, cpu, gpu, - gpuFloat16, - gpuMixed, - gpuFull, - tpu, + npu, // Supported by LiteRT-LM only (Google Tensor, Qualcomm NPU) } @ConfigurePigeon(PigeonOptions( From 627c3084d880a83b1e390404a18e60d6ffb2ad6c Mon Sep 17 00:00:00 2001 From: Sasha Denisov Date: Sat, 24 Jan 2026 11:07:25 +0500 Subject: [PATCH 3/5] Add PreferredBackend documentation with platform support matrix - Document backend support per platform (Android, iOS, Web, Desktop) - Clarify that CPU is not supported on Web (MediaPipe limitation) - Clarify that NPU is Android-only (.litertlm models) - Add docstrings to PreferredBackend enum in pigeon.dart - Update proto comments for desktop backend options --- CLAUDE.md | 103 +++++++++++++++++- README.md | 15 ++- .../flutter_gemma/PigeonInterface.g.kt | 10 ++ ios/Classes/PigeonInterface.g.swift | 8 ++ lib/pigeon.g.dart | 8 ++ litertlm-server/src/main/proto/litertlm.proto | 5 +- pigeon.dart | 10 +- 7 files changed, 153 insertions(+), 6 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 9248450e..7c5ec067 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -99,10 +99,20 @@ final spec = InferenceModelSpec( **Runtime accepts configuration each time:** - `maxTokens` - Context size (default: 1024) -- `preferredBackend` - CPU/GPU preference +- `preferredBackend` - Hardware backend (see PreferredBackend below) - `supportImage` - Multimodal support - `maxNumImages` - Image limits 
+**PreferredBackend enum:** +| Value | Android | iOS | Web | Desktop | +|-------|---------|-----|-----|---------| +| `cpu` | ✅ | ✅ | ❌ | ✅ | +| `gpu` | ✅ | ✅ | ✅ (required) | ✅ | +| `npu` | ✅ (.litertlm) | ❌ | ❌ | ❌ | + +> - **NPU**: Qualcomm, MediaTek, Google Tensor. Up to 25x faster than CPU. +> - **Web**: GPU only (MediaPipe limitation). + **Usage:** ```dart // Step 1: Install with identity @@ -515,6 +525,75 @@ use_frameworks! :linkage => :static ``` +#### Android LiteRT-LM Engine (v0.12.x+) + +Android now supports **dual inference engines** - MediaPipe and LiteRT-LM - with automatic selection based on file extension. + +**Engine Selection:** +| File Extension | Engine | Android | Desktop | Web | +|----------------|--------|---------|---------|-----| +| `.task`, `.bin`, `.tflite` | MediaPipe | Yes | No | Yes | +| `.litertlm` | LiteRT-LM | Yes | Yes | No | + +**Architecture:** +``` +android/src/main/kotlin/dev/flutterberlin/flutter_gemma/ +├── FlutterGemmaPlugin.kt # Plugin entry point +├── PlatformService.g.kt # Pigeon-generated interface +└── engines/ # Engine abstraction layer + ├── InferenceEngine.kt # Strategy interface + ├── InferenceSession.kt # Session interface + ├── EngineConfig.kt # Configuration data classes + ├── EngineFactory.kt # Factory for engine creation + ├── FlowFactory.kt # SharedFlow factory + ├── mediapipe/ + │ ├── MediaPipeEngine.kt # MediaPipe adapter (wraps LlmInference) + │ └── MediaPipeSession.kt # MediaPipe session adapter + └── litertlm/ + ├── LiteRtLmEngine.kt # LiteRT-LM implementation + └── LiteRtLmSession.kt # LiteRT-LM session with chunk buffering +``` + +**Key Design Decisions:** + +1. **Strategy Pattern**: `InferenceEngine` interface allows interchangeable engine implementations +2. **Adapter Pattern**: `MediaPipeEngine` wraps existing MediaPipe code without modifications +3. 
**Chunk Buffering**: LiteRT-LM uses `sendMessage()`, not `addQueryChunk()`, so `LiteRtLmSession` buffers chunks in `StringBuilder` and sends complete message on `generateResponse()` + +**LiteRT-LM Limitations:** + +⚠️ **Token Counting**: LiteRT-LM SDK does not expose tokenizer API. The implementation uses an estimate of ~4 characters per token with a warning log: +```kotlin +Log.w(TAG, "sizeInTokens: LiteRT-LM does not support token counting. " + + "Using estimate (~4 chars/token): $estimate tokens for ${prompt.length} chars. " + + "This may be inaccurate for non-English text.") +``` + +⚠️ **Cancellation**: `cancelGeneration()` is not yet supported by LiteRT-LM SDK 0.9.x + +**LiteRT-LM Behavioral Differences:** + +1. **Chunk Buffering**: Unlike MediaPipe, which processes `addQueryChunk()` directly, LiteRT-LM buffers chunks in `StringBuilder` and sends complete message on `generateResponse()`. +2. **Thread-Safe Accumulation**: Uses `synchronized(promptLock)` for safe concurrent chunk additions. +3. **Cache Support**: Engine configured with `cacheDir` for faster reloads (~10s cold → ~1-2s cached).
+ +**Dependency (build.gradle):** +```gradle +implementation 'com.google.ai.edge.litertlm:litertlm-android:0.9.0-alpha01' +``` + +**Usage (Dart - no changes required):** +```dart +// Engine is automatically selected based on file extension +await FlutterGemma.installModel(modelType: ModelType.gemmaIt) + .fromNetwork('https://example.com/model.litertlm') // → LiteRtLmEngine + .install(); + +await FlutterGemma.installModel(modelType: ModelType.gemmaIt) + .fromNetwork('https://example.com/model.task') // → MediaPipeEngine + .install(); +``` + #### Web Configuration ```html @@ -1080,7 +1159,27 @@ flutter_gemma/ └── CLAUDE.md # This file ``` -## Recent Updates (2026-01-01) +## Recent Updates (2026-01-18) + +### ✅ Android LiteRT-LM Engine (v0.12.x+) +- **Dual Engine Support** - MediaPipe and LiteRT-LM on Android +- **Automatic Selection** - Engine chosen by file extension (`.litertlm` → LiteRT-LM, `.task/.bin` → MediaPipe) +- **Strategy Pattern** - `InferenceEngine` interface with interchangeable implementations +- **Adapter Pattern** - `MediaPipeEngine` wraps existing code without modifications +- **Chunk Buffering** - LiteRT-LM session buffers `addQueryChunk()` calls for `sendMessage()` API +- **Token Estimation** - ~4 chars/token with warning log (LiteRT-LM lacks tokenizer API) +- **Zero Flutter API Changes** - Transparent to Dart layer + +**Key Files:** +- `android/.../engines/InferenceEngine.kt` - Strategy interface +- `android/.../engines/EngineFactory.kt` - Factory for engine creation +- `android/.../engines/mediapipe/` - MediaPipe adapter +- `android/.../engines/litertlm/` - LiteRT-LM implementation + +**Dependency:** +```gradle +implementation 'com.google.ai.edge.litertlm:litertlm-android:0.9.0-alpha01' +``` ### ✅ Desktop Platform Support (v0.12.0+) - **macOS, Windows, Linux** support via LiteRT-LM JVM diff --git a/README.md b/README.md index d87ac13f..33bc39b9 100644 --- a/README.md +++ b/README.md @@ -1086,6 +1086,18 @@ final inferenceModel = await 
FlutterGemmaPlugin.instance.createModel( ); ``` +**PreferredBackend Options:** + +| Backend | Android | iOS | Web | Desktop | +|---------|---------|-----|-----|---------| +| `cpu` | ✅ | ✅ | ❌ | ✅ | +| `gpu` | ✅ | ✅ | ✅ (required) | ✅ | +| `npu` | ✅ (.litertlm) | ❌ | ❌ | ❌ | + +- **NPU**: Qualcomm AI Engine, MediaTek NeuroPilot, Google Tensor. Up to 25x faster than CPU. +- **Web**: GPU only (MediaPipe limitation). CPU models will fail to initialize. +- **Desktop**: GPU uses Metal (macOS), DirectX 12 (Windows), Vulkan (Linux). + 6.**Using Sessions for Single Inferences:** If you need to generate individual responses without maintaining a conversation history, use sessions. Sessions allow precise control over inference and must be properly closed to avoid memory leaks. @@ -2007,8 +2019,7 @@ final supported = await FlutterGemma.isStreamingSupported(); ``` #### Backend Support -- **GPU only:** Web platform requires GPU backend (MediaPipe limitation) -- **CPU models:** ❌ Will fail to initialize on web +- **GPU only:** See [PreferredBackend Options](#preferredbackend-options) table above #### CORS Configuration - **Required for custom servers:** Enable CORS headers on your model hosting server diff --git a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt index acceabe4..63f3b6f8 100644 --- a/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt +++ b/android/src/main/kotlin/dev/flutterberlin/flutter_gemma/PigeonInterface.g.kt @@ -46,6 +46,16 @@ class FlutterError ( val details: Any? = null ) : Throwable() +/** + * Hardware backend for model inference. 
+ * + * Platform support: + * - [cpu]: All platforms + * - [gpu]: All platforms (Metal on macOS, DirectX on Windows, Vulkan on Linux, OpenCL on Android) + * - [npu]: Android only with LiteRT-LM (.litertlm models) - Qualcomm, MediaTek, Google Tensor + * + * If selected backend is unavailable, engine falls back to GPU, then CPU. + */ enum class PreferredBackend(val raw: Int) { CPU(0), GPU(1), diff --git a/ios/Classes/PigeonInterface.g.swift b/ios/Classes/PigeonInterface.g.swift index e3d0084e..0988520d 100644 --- a/ios/Classes/PigeonInterface.g.swift +++ b/ios/Classes/PigeonInterface.g.swift @@ -64,6 +64,14 @@ private func nilOrValue(_ value: Any?) -> T? { return value as! T? } +/// Hardware backend for model inference. +/// +/// Platform support: +/// - [cpu]: All platforms +/// - [gpu]: All platforms (Metal on macOS, DirectX on Windows, Vulkan on Linux, OpenCL on Android) +/// - [npu]: Android only with LiteRT-LM (.litertlm models) - Qualcomm, MediaTek, Google Tensor +/// +/// If selected backend is unavailable, engine falls back to GPU, then CPU. enum PreferredBackend: Int { case cpu = 0 case gpu = 1 diff --git a/lib/pigeon.g.dart b/lib/pigeon.g.dart index 3218ce8a..2c35ca3b 100644 --- a/lib/pigeon.g.dart +++ b/lib/pigeon.g.dart @@ -15,6 +15,14 @@ PlatformException _createConnectionError(String channelName) { ); } +/// Hardware backend for model inference. +/// +/// Platform support: +/// - [cpu]: All platforms +/// - [gpu]: All platforms (Metal on macOS, DirectX on Windows, Vulkan on Linux, OpenCL on Android) +/// - [npu]: Android only with LiteRT-LM (.litertlm models) - Qualcomm, MediaTek, Google Tensor +/// +/// If selected backend is unavailable, engine falls back to GPU, then CPU. 
enum PreferredBackend { cpu, gpu, diff --git a/litertlm-server/src/main/proto/litertlm.proto b/litertlm-server/src/main/proto/litertlm.proto index b7b98470..3dc9b0ee 100644 --- a/litertlm-server/src/main/proto/litertlm.proto +++ b/litertlm-server/src/main/proto/litertlm.proto @@ -30,7 +30,10 @@ service LiteRtLmService { message InitializeRequest { string model_path = 1; - string backend = 2; // "cpu", "gpu" + // Backend: "cpu" or "gpu" + // GPU uses Metal (macOS), DirectX 12 (Windows), Vulkan (Linux) + // Note: "npu" is not supported on desktop (Android only) + string backend = 2; int32 max_tokens = 3; bool enable_vision = 4; int32 max_num_images = 5; diff --git a/pigeon.dart b/pigeon.dart index 92c9c989..0b5a727b 100644 --- a/pigeon.dart +++ b/pigeon.dart @@ -1,10 +1,18 @@ import 'package:pigeon/pigeon.dart'; // Command to generate pigeon files: dart run pigeon --input pigeon.dart +/// Hardware backend for model inference. +/// +/// Platform support: +/// - [cpu]: All platforms +/// - [gpu]: All platforms (Metal on macOS, DirectX on Windows, Vulkan on Linux, OpenCL on Android) +/// - [npu]: Android only with LiteRT-LM (.litertlm models) - Qualcomm, MediaTek, Google Tensor +/// +/// If selected backend is unavailable, engine falls back to GPU, then CPU. 
enum PreferredBackend { cpu, gpu, - npu, // Supported by LiteRT-LM only (Google Tensor, Qualcomm NPU) + npu, // Android only: Qualcomm AI Engine, MediaTek NeuroPilot, Google Tensor } @ConfigurePigeon(PigeonOptions( From 043271961c75b8a922d67a4e2c3f027545e152f9 Mon Sep 17 00:00:00 2001 From: Sasha Denisov Date: Sat, 24 Jan 2026 11:08:18 +0500 Subject: [PATCH 4/5] Fix misleading test: sizeInTokens counts prompt, not accumulated chunks --- .../engines/litertlm/LiteRtLmSessionTest.kt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt index 02ccc96a..5b2f76a5 100644 --- a/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt +++ b/android/src/test/kotlin/dev/flutterberlin/flutter_gemma/engines/litertlm/LiteRtLmSessionTest.kt @@ -50,15 +50,15 @@ class LiteRtLmSessionTest { // =========================================== @Test - fun `addQueryChunk accumulates text`() { + fun `addQueryChunk accumulates text without throwing`() { + // Note: LiteRT-LM buffers chunks until generateResponse() is called + // We can only verify that accumulation doesn't throw session.addQueryChunk("Hello, ") session.addQueryChunk("world!") - // Token count estimate should reflect accumulated length - // "Hello, world!" 
= 13 chars → ~3-4 tokens - val tokenCount = session.sizeInTokens("") - // This verifies internal state indirectly - assertTrue("Token count should be reasonable", tokenCount >= 0) + // Verify sizeInTokens works independently (estimates tokens for given prompt) + val tokenCount = session.sizeInTokens("Hello, world!") + assertEquals("13 chars → (13+3)/4 = 4 tokens", 4, tokenCount) } @Test From 016e389727e8fcbd4a65347b1b167771e4bc85fd Mon Sep 17 00:00:00 2001 From: Sasha Denisov Date: Sat, 24 Jan 2026 12:57:52 +0500 Subject: [PATCH 5/5] Fix CLAUDE.md: FlowFactory is in EngineConfig.kt, not separate file --- CLAUDE.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 7c5ec067..ed5f2013 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -543,9 +543,8 @@ android/src/main/kotlin/dev/flutterberlin/flutter_gemma/ └── engines/ # Engine abstraction layer ├── InferenceEngine.kt # Strategy interface ├── InferenceSession.kt # Session interface - ├── EngineConfig.kt # Configuration data classes + ├── EngineConfig.kt # Configuration + SessionConfig + FlowFactory ├── EngineFactory.kt # Factory for engine creation - ├── FlowFactory.kt # SharedFlow factory ├── mediapipe/ │ ├── MediaPipeEngine.kt # MediaPipe adapter (wraps LlmInference) │ └── MediaPipeSession.kt # MediaPipe session adapter