diff --git a/packages/firebase_ai/firebase_ai/example/lib/main.dart b/packages/firebase_ai/firebase_ai/example/lib/main.dart index a7fd0363aba7..584edfa8079f 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/main.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/main.dart @@ -80,7 +80,7 @@ class _GenerativeAISampleState extends State { imageFormat: ImagenFormat.jpeg(compressionQuality: 75), ); return instance.imagenModel( - model: 'imagen-3.0-generate-002', + model: 'imagen-3.0-capability-001', generationConfig: generationConfig, safetySettings: ImagenSafetySettings( ImagenSafetyFilterLevel.blockLowAndAbove, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/audio_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/audio_page.dart index 8708bcb01c15..be04d6a2db30 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/audio_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/audio_page.dart @@ -137,7 +137,11 @@ class _AudioPageState extends State { itemBuilder: (context, idx) { return MessageWidget( text: _messages[idx].text, - image: _messages[idx].image, + image: Image.memory( + _messages[idx].imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: _messages[idx].fromUser ?? false, ); }, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart index cb221329d28f..3b7cfd334bc5 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart @@ -119,7 +119,11 @@ class _BidiPageState extends State { itemBuilder: (context, idx) { return MessageWidget( text: _messages[idx].text, - image: _messages[idx].image, + image: Image.memory( + _messages[idx].imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: _messages[idx].fromUser ?? false, ); }, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/chat_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/chat_page.dart index df0afea88482..eb8e6128f2fc 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/chat_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/chat_page.dart @@ -70,7 +70,11 @@ class _ChatPageState extends State { itemBuilder: (context, idx) { return MessageWidget( text: _messages[idx].text, - image: _messages[idx].image, + image: Image.memory( + _messages[idx].imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: _messages[idx].fromUser ?? false, ); }, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart index 5dff25a2efe1..5409c264450b 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart @@ -65,7 +65,11 @@ class _ImagePromptPageState extends State { var content = _generatedContent[idx]; return MessageWidget( text: content.text, - image: content.image, + image: Image.memory( + content.imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: content.fromUser ?? false, ); }, @@ -137,14 +141,14 @@ class _ImagePromptPageState extends State { ]; _generatedContent.add( MessageData( - image: Image.asset('assets/images/cat.jpg'), + imageBytes: catBytes.buffer.asUint8List(), text: message, fromUser: true, ), ); _generatedContent.add( MessageData( - image: Image.asset('assets/images/scones.jpg'), + imageBytes: sconeBytes.buffer.asUint8List(), fromUser: true, ), ); diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart index c957f207278e..ed016fb03d86 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart @@ -12,11 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -import 'package:flutter/material.dart'; +import 'dart:typed_data'; + +import 'package:image_picker/image_picker.dart'; import 'package:firebase_ai/firebase_ai.dart'; + +import 'package:flutter/material.dart'; //import 'package:firebase_storage/firebase_storage.dart'; import '../widgets/message_widget.dart'; +// Define a constant for the history limit +const int _MAX_HISTORY = 4; + class ImagenPage extends StatefulWidget { const ImagenPage({ super.key, @@ -38,6 +45,10 @@ class _ImagenPageState extends State { final List _generatedContent = []; bool _loading = false; + // For image picking + ImagenInlineImage? _sourceImage; + ImagenInlineImage? _maskImageForEditing; + void _scrollDown() { WidgetsBinding.instance.addPostFrameCallback( (_) => _scrollController.animateTo( @@ -68,7 +79,11 @@ class _ImagenPageState extends State { itemBuilder: (context, idx) { return MessageWidget( text: _generatedContent[idx].text, - image: _generatedContent[idx].image, + image: Image.memory( + _generatedContent[idx].imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: _generatedContent[idx].fromUser ?? false, ); }, @@ -80,45 +95,89 @@ class _ImagenPageState extends State { vertical: 25, horizontal: 15, ), - child: Row( + child: Column( children: [ - Expanded( - child: TextField( - autofocus: true, - focusNode: _textFieldFocus, - controller: _textController, - ), - ), - const SizedBox.square( - dimension: 15, - ), - if (!_loading) - IconButton( - onPressed: () async { - await _testImagen(_textController.text); - }, - icon: Icon( - Icons.image_search, - color: Theme.of(context).colorScheme.primary, + // Generate Image Row + Row( + children: [ + Expanded( + child: TextField( + autofocus: true, + focusNode: _textFieldFocus, + decoration: const InputDecoration( + hintText: 'Enter a prompt...', + ), + controller: _textController, + ), + ), + const SizedBox.square(dimension: 15), + IconButton( + onPressed: () async { + await _pickSourceImage(); + }, + icon: Icon( + Icons.add_a_photo, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Pick Source Image', + ), + IconButton( + onPressed: () async { + await _pickMaskImage(); + }, + icon: Icon( + Icons.add_to_photos, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Pick mask', + ), + IconButton( + onPressed: () async { + await _editWithStyle(); + }, + icon: Icon( + Icons.edit, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Edit with Style', ), - tooltip: 'Imagen raw data', - ) - else - const CircularProgressIndicator(), - // NOTE: Keep this API private until future release. - // if (!_loading) - // IconButton( - // onPressed: () async { - // await _testImagenGCS(_textController.text); - // }, - // icon: Icon( - // Icons.imagesearch_roller, - // color: Theme.of(context).colorScheme.primary, - // ), - // tooltip: 'Imagen GCS', - // ) - // else - // const CircularProgressIndicator(), + IconButton( + onPressed: () async { + await _outpaintImageHappyPath(); + }, + icon: Icon( + Icons.masks, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Outpaint', + ), + IconButton( + onPressed: () async { + await _inpaintImageHappyPath(); + }, + icon: Icon( + Icons.plus_one, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Inpaint', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _generateImageFromPrompt( + _textController.text, + ); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Generate Image', + ) + else + const CircularProgressIndicator(), + ], + ), ], ), ), @@ -128,23 +187,216 @@ class _ImagenPageState extends State { ); } - Future _testImagen(String prompt) async { + Future _pickImage() async { + final ImagePicker picker = ImagePicker(); + try { + final XFile? imageFile = + await picker.pickImage(source: ImageSource.gallery); + if (imageFile != null) { + // Attempt to get mimeType, default if null. + // Note: imageFile.mimeType might be null on some platforms or for some files. + final String mimeType = imageFile.mimeType ?? 'image/jpeg'; + final Uint8List imageBytes = await imageFile.readAsBytes(); + return ImagenInlineImage( + bytesBase64Encoded: imageBytes, mimeType: mimeType); + } + } catch (e) { + _showError('Error picking image: $e'); + } + return null; + } + + Future _pickSourceImage() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _sourceImage = pickedImage; + }); + } + } + + Future _pickMaskImage() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _maskImageForEditing = pickedImage; + }); + } + } + + Future _inpaintImageHappyPath() async { + if (_sourceImage == null) { + _showError('Please pick a source image for inpaint insertion.'); + return; + } setState(() { _loading = true; }); + final String prompt = _textController.text; + final promptMessage = MessageData( + imageBytes: _sourceImage!.bytesBase64Encoded, + text: 'Try to inpaint image with prompt: $prompt', + fromUser: true, + ); + + MessageData? resultMessage; + + try { + final response = await widget.model.inpaintImage( + _sourceImage!, + prompt, + ImagenBackgroundMask(), + config: ImagenEditingConfig(editMode: ImagenEditMode.inpaintInsertion), + ); + if (response.images.isNotEmpty) { + final inpaintImage = response.images[0]; + resultMessage = MessageData( + imageBytes: inpaintImage.bytesBase64Encoded, + text: 'Inpaint image result with prompt: $prompt', + fromUser: false, + ); + } else { + _showError('No image was returned from inpaint.'); + } + } catch (e) { + _showError('Error inpaint image: $e'); + } + + setState(() { + _generatedContent.add(promptMessage); + if (resultMessage != null) { + _generatedContent.add(resultMessage); + } + // Apply history limit here + while (_generatedContent.length > _MAX_HISTORY) { + _generatedContent.removeAt(0); + } + _loading = false; + _scrollDown(); + }); + } + + Future _outpaintImageHappyPath() async { + if (_sourceImage == null) { + _showError('Please pick a source image for outpainting.'); + return; + } + setState(() { + _loading = true; + }); + + final promptMessage = MessageData( + imageBytes: _sourceImage!.bytesBase64Encoded, + text: 'Outpaint the picture to 1400*1400', + fromUser: true, + ); + + MessageData? resultMessage; + try { + final response = await widget.model.outpaintImage( + _sourceImage!, + ImagenDimensions(width: 1400, height: 1400), + ); + if (response.images.isNotEmpty) { + final editedImage = response.images[0]; + resultMessage = MessageData( + imageBytes: editedImage.bytesBase64Encoded, + text: 'Edited image Outpaint 1400*1400', + fromUser: false, + ); + } else { + _showError('No image was returned from editing.'); + } + } catch (e) { + _showError('Error editing image: $e'); + } + + setState(() { + _generatedContent.add(promptMessage); + if (resultMessage != null) { + _generatedContent.add(resultMessage); + } + // Apply history limit here + while (_generatedContent.length > _MAX_HISTORY) { + _generatedContent.removeAt(0); + } + _loading = false; + _scrollDown(); + }); + } + + Future _editWithStyle() async { + if (_sourceImage == null) { + _showError('Please pick a source image for style editing.'); + return; + } + setState(() { + _loading = true; + }); + + final String prompt = _textController.text; + final promptMessage = MessageData( + imageBytes: _sourceImage!.bytesBase64Encoded, + text: prompt, + fromUser: true, + ); + MessageData? resultMessage; + try { + final response = await widget.model.editImage( + [ + ImagenStyleReference( + image: _sourceImage!, + description: 'van goh style', + ), + ], + prompt, + config: ImagenEditingConfig(editSteps: 50), + ); + if (response.images.isNotEmpty) { + final editedImage = response.images[0]; + + resultMessage = MessageData( + imageBytes: editedImage.bytesBase64Encoded, + text: 'Edited image with style: $prompt', + fromUser: false, + ); + } else { + _showError('No image was returned from style editing.'); + } + } catch (e) { + _showError('Error performing style edit: $e'); + } + + setState(() { + _generatedContent.add(promptMessage); + if (resultMessage != null) { + _generatedContent.add(resultMessage); + } + // Apply history limit here + while (_generatedContent.length > _MAX_HISTORY) { + _generatedContent.removeAt(0); + } + _loading = false; + _scrollDown(); + }); + } + + Future _generateImageFromPrompt(String prompt) async { + setState(() { + _loading = true; + }); + MessageData? resultMessage; try { var response = await widget.model.generateImages(prompt); if (response.images.isNotEmpty) { var imagenImage = response.images[0]; - _generatedContent.add( - MessageData( - image: Image.memory(imagenImage.bytesBase64Encoded), - text: prompt, - fromUser: false, - ), + resultMessage = MessageData( + imageBytes: imagenImage.bytesBase64Encoded, + text: prompt, + fromUser: false, ); } else { // Handle the case where no images were generated @@ -155,6 +407,13 @@ class _ImagenPageState extends State { } setState(() { + if (resultMessage != null) { + _generatedContent.add(resultMessage); + } + // Apply history limit here + while (_generatedContent.length > _MAX_HISTORY) { + _generatedContent.removeAt(0); + } _loading = false; _scrollDown(); }); diff --git a/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart b/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart index b8a0f23ce03b..368dfc1fea88 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart @@ -11,13 +11,13 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +import 'dart:typed_data'; import 'package:flutter/material.dart'; import 'package:flutter_markdown/flutter_markdown.dart'; class MessageData { - MessageData({this.image, this.text, this.fromUser}); - final Image? image; + MessageData({this.imageBytes, this.text, this.fromUser}); + final Uint8List? imageBytes; final String? text; final bool? fromUser; } diff --git a/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements b/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements index b4bd9ee174a1..8560da29b687 100644 --- a/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements +++ b/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements @@ -14,5 +14,7 @@ com.apple.security.device.audio-input + com.apple.security.files.user-selected.read-only + diff --git a/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist b/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist index a81b3fd0d617..d4369e6253fa 100644 --- a/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist +++ b/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist @@ -30,5 +30,7 @@ NSApplication NSMicrophoneUsageDescription Permission to Record audio + NSPhotoLibraryUsageDescription + This app needs access to your photo library to let you select a profile picture. diff --git a/packages/firebase_ai/firebase_ai/example/pubspec.yaml b/packages/firebase_ai/firebase_ai/example/pubspec.yaml index 8cc5078bd03f..0431837579a8 100644 --- a/packages/firebase_ai/firebase_ai/example/pubspec.yaml +++ b/packages/firebase_ai/firebase_ai/example/pubspec.yaml @@ -27,6 +27,7 @@ dependencies: sdk: flutter flutter_markdown: ^0.6.20 flutter_soloud: ^3.1.6 + image_picker: ^1.1.2 path_provider: ^2.1.5 record: ^5.2.1 diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 0587c156f9a5..7c51a41a389c 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +import 'src/imagen/imagen_reference.dart'; + export 'src/api.dart' show BlockReason, @@ -51,7 +53,7 @@ export 'src/error.dart' ServerException, UnsupportedUserLocation; export 'src/firebase_ai.dart' show FirebaseAI; -export 'src/imagen_api.dart' +export 'src/imagen/imagen_api.dart' show ImagenSafetySettings, ImagenFormat, @@ -59,7 +61,32 @@ export 'src/imagen_api.dart' ImagenPersonFilterLevel, ImagenGenerationConfig, ImagenAspectRatio; -export 'src/imagen_content.dart' show ImagenInlineImage; +export 'src/imagen/imagen_content.dart' show ImagenInlineImage; +export 'src/imagen/imagen_edit.dart' + show + ImagenEditMode, + ImagenSubjectReferenceType, + ImagenControlType, + ImagenMaskMode, + ImagenMaskConfig, + ImagenSubjectConfig, + ImagenStyleConfig, + ImagenControlConfig, + ImagenEditingConfig, + ImagenDimensions, + ImagenImagePlacement; +export 'src/imagen/imagen_reference.dart' + show + ImagenReferenceImage, + ImagenMaskReference, + ImagenRawImage, + ImagenRawMask, + ImagenSemanticMask, + ImagenBackgroundMask, + ImagenForegroundMask, + ImagenSubjectReference, + ImagenStyleReference, + ImagenControlReference; export 'src/live_api.dart' show LiveGenerationConfig, diff --git a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart index 413af8ba49eb..de62a972dbfe 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart @@ -28,15 +28,17 @@ import 'api.dart'; import 'client.dart'; import 'content.dart'; import 'developer/api.dart'; -import 'imagen_api.dart'; -import 'imagen_content.dart'; +import 'imagen/imagen_api.dart'; +import 'imagen/imagen_content.dart'; +import 'imagen/imagen_edit.dart'; +import 'imagen/imagen_reference.dart'; import 'live_api.dart'; import 'live_session.dart'; import 'tool.dart'; import 'vertex_version.dart'; part 'generative_model.dart'; -part 'imagen_model.dart'; +part 'imagen/imagen_model.dart'; part 'live_model.dart'; /// [Task] enum class for [GenerativeModel] to make request. diff --git a/packages/firebase_ai/firebase_ai/lib/src/client.dart b/packages/firebase_ai/firebase_ai/lib/src/client.dart index ba3eed67b6fe..464698ac9eeb 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/client.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/client.dart @@ -63,9 +63,13 @@ final class HttpApiClient implements ApiClient { @override Future> makeRequest( Uri uri, Map body) async { + print(uri); + final headers = await _headers(); + print(headers); + print(body); final response = await (_httpClient?.post ?? http.post)( uri, - headers: await _headers(), + headers: headers, body: _utf8Json.encode(body), ); if (response.statusCode >= 500) { diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen_api.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_api.dart similarity index 100% rename from packages/firebase_ai/firebase_ai/lib/src/imagen_api.dart rename to packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_api.dart diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen_content.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_content.dart similarity index 94% rename from packages/firebase_ai/firebase_ai/lib/src/imagen_content.dart rename to packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_content.dart index 525cbeef44ed..2aafcf5e6c0b 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/imagen_content.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_content.dart @@ -13,8 +13,9 @@ // limitations under the License. import 'dart:convert'; import 'dart:typed_data'; +import 'dart:ui' as ui; import 'package:meta/meta.dart'; -import 'error.dart'; +import '../error.dart'; /// Base type of Imagen Image. sealed class ImagenImage { @@ -59,6 +60,12 @@ final class ImagenInlineImage implements ImagenImage { 'mimeType': mimeType, 'bytesBase64Encoded': base64Encode(bytesBase64Encoded), }; + // Helper to decode bytes into a dart:ui.Image. + Future asUiImage() async { + final codec = await ui.instantiateImageCodec(bytesBase64Encoded); + final frame = await codec.getNextFrame(); + return frame.image; + } } /// Represents an image stored in Google Cloud Storage. diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_edit.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_edit.dart new file mode 100644 index 000000000000..58a8769a1482 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_edit.dart @@ -0,0 +1,275 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:convert'; + +import 'package:meta/meta.dart'; + +/// The desired outcome of the image editing. +@experimental +enum ImagenEditMode { + /// The result of the editing will be an insertion of the prompt in the masked + /// region. + inpaintInsertion('EDIT_MODE_INPAINT_INSERTION'), + + /// The result of the editing will be a removal of the masked region. + inpaintRemoval('EDIT_MODE_INPAINT_REMOVAL'), + + /// The result of the editing will be an outpainting of the source image. + outpaint('EDIT_MODE_OUTPAINT'); + + const ImagenEditMode(this._jsonString); + final String _jsonString; + // ignore: public_member_api_docs + String toJson() => _jsonString; +} + +/// The type of the subject in the image. +@experimental +enum ImagenSubjectReferenceType { + /// The subject is a person. + person('SUBJECT_TYPE_PERSON'), + + /// The subject is an animal. + animal('SUBJECT_TYPE_ANIMAL'), + + /// The subject is a product. + product('SUBJECT_TYPE_PRODUCT'); + + const ImagenSubjectReferenceType(this._jsonString); + final String _jsonString; + + // ignore: public_member_api_docs + String toJson() => _jsonString; +} + +/// The type of control image. +@experimental +enum ImagenControlType { + /// Use edge detection to ensure the new image follow the same outlines. + canny('CONTROL_TYPE_CANNY'), + + /// Use enhanced edge detection to ensure the new image follow similar + /// outlines. + scribble('CONTROL_TYPE_SCRIBBLE'), + + /// Use face mesh control to ensure that the new image has the same facial + /// expressions. + faceMesh('CONTROL_TYPE_FACE_MESH'), + + /// Use color superpixels to ensure that the new image is similar in shape + /// and color to the original. + colorSuperpixel('CONTROL_TYPE_COLOR_SUPERPIXEL'); + + const ImagenControlType(this._jsonString); + final String _jsonString; + + // ignore: public_member_api_docs + String toJson() => _jsonString; +} + +/// The mode of the mask. +@experimental +enum ImagenMaskMode { + /// The mask is user provided. + userProvided('MASK_MODE_USER_PROVIDED'), + + /// The mask is the background. + background('MASK_MODE_BACKGROUND'), + + /// The mask is the foreground. + foreground('MASK_MODE_FOREGROUND'), + + /// The mask is semantic. + semantic('MASK_MODE_SEMANTIC'); + + const ImagenMaskMode(this._jsonString); + final String _jsonString; + // ignore: public_member_api_docs + String toJson() => _jsonString; +} + +sealed class ImagenReferenceConfig { + /// Convert the [ImagenReferenceConfig] content to json format. + Map toJson(); +} + +/// The configuration for the mask. +@experimental +final class ImagenMaskConfig extends ImagenReferenceConfig { + ImagenMaskConfig({ + required this.maskType, + this.maskDilation, + this.maskClasses, + }); + + final ImagenMaskMode maskType; + final double? maskDilation; + final List? maskClasses; + + @override + Map toJson() => { + 'maskImageConfig': { + 'maskMode': maskType.toJson(), + if (maskDilation != null) 'dilation': maskDilation, + if (maskClasses != null) 'maskClasses': jsonEncode(maskClasses), + }, + }; +} + +/// The configuration for the subject. +@experimental +final class ImagenSubjectConfig extends ImagenReferenceConfig { + ImagenSubjectConfig({ + this.description, + this.type, + }); + + final String? description; + final ImagenSubjectReferenceType? type; + + @override + Map toJson() => { + 'subjectImageConfig': { + if (description != null) 'subjectDescription': description, + if (type != null) 'subjectType': type!.toJson(), + }, + }; +} + +/// The configuration for the style. +@experimental +final class ImagenStyleConfig extends ImagenReferenceConfig { + ImagenStyleConfig({ + this.description, + }); + + final String? description; + @override + Map toJson() => { + 'styleImageConfig': { + if (description != null) 'styleDescription': description, + }, + }; +} + +/// The configuration for the control. +@experimental +final class ImagenControlConfig extends ImagenReferenceConfig { + ImagenControlConfig({ + required this.controlType, + this.enableComputation, + this.superpixelRegionSize, + this.superpixelRuler, + }); + + final ImagenControlType controlType; + final bool? enableComputation; + final int? superpixelRegionSize; + final int? superpixelRuler; + @override + Map toJson() => { + 'controlImageConfig': { + 'controlType': controlType.toJson(), + if (enableComputation != null) + 'enableControlImageComputation': enableComputation, + if (superpixelRegionSize != null) + 'superpixelRegionSize': superpixelRegionSize, + if (superpixelRuler != null) 'superpixelRuler': superpixelRuler, + }, + }; +} + +/// The configuration for image editing. +@experimental +final class ImagenEditingConfig { + ImagenEditingConfig({ + this.editMode, + this.editSteps, + }); + + final ImagenEditMode? editMode; + final int? editSteps; +} + +/// The dimensions of an image. +@experimental +final class ImagenDimensions { + ImagenDimensions({ + required this.width, + required this.height, + }); + + final int width; + final int height; +} + +/// The placement of an image. +@experimental +final class ImagenImagePlacement { + const ImagenImagePlacement._(this.x, this.y); + + final int? x; + final int? y; + + /// Creates a placement from a coordinate. + static ImagenImagePlacement fromCoordinate(int x, int y) => + ImagenImagePlacement._(x, y); + + /// The center of the image. + static const ImagenImagePlacement center = ImagenImagePlacement._(null, null); + + /// The top center of the image. + static const ImagenImagePlacement topCenter = + ImagenImagePlacement._(null, null); + + /// The bottom center of the image. + static const ImagenImagePlacement bottomCenter = + ImagenImagePlacement._(null, null); + + /// The left center of the image. + static const ImagenImagePlacement leftCenter = + ImagenImagePlacement._(null, null); + + /// The right center of the image. + static const ImagenImagePlacement rightCenter = + ImagenImagePlacement._(null, null); + + /// The top left of the image. + static const ImagenImagePlacement topLeft = ImagenImagePlacement._(0, 0); + + /// The top right of the image. + static const ImagenImagePlacement topRight = + ImagenImagePlacement._(null, null); + + /// The bottom left of the image. + static const ImagenImagePlacement bottomLeft = + ImagenImagePlacement._(null, null); + + /// The bottom right of the image. + static const ImagenImagePlacement bottomRight = + ImagenImagePlacement._(null, null); + + /// A mock normalization function. + ImagenImagePlacement normalizeToDimensions( + ImagenDimensions original, + ImagenDimensions newDim, + ) { + // In a real implementation, this would calculate the top-left (x, y) + // based on the placement strategy (e.g., center, top-left). + final x = (newDim.width - original.width) / 2; + final y = (newDim.height - original.height) / 2; + return ImagenImagePlacement.fromCoordinate(x.toInt(), y.toInt()); + } +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart similarity index 60% rename from packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart rename to packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart index bf4731a3b264..b4c64016824e 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -part of 'base_model.dart'; +part of '../base_model.dart'; /// Represents a remote Imagen model with the ability to generate images using /// text prompts. @@ -110,6 +110,102 @@ final class ImagenModel extends BaseApiClientModel { (jsonObject) => parseImagenGenerationResponse(jsonObject), ); + + /// Edits an image based on a prompt and a list of reference images. + @experimental + Future> editImage( + List referenceImages, + String prompt, { + ImagenEditingConfig? config, + }) => + makeRequest( + Task.predict, + _generateImagenEditRequest( + referenceImages, + prompt, + config: config, + ), + (jsonObject) => + parseImagenGenerationResponse(jsonObject), + ); + + /// Inpaints an image based on a prompt and a mask. + @experimental + Future> inpaintImage( + ImagenInlineImage image, + String prompt, + ImagenMaskReference mask, { + ImagenEditingConfig? config, + }) => + editImage( + [ + mask, + ImagenRawImage(image: image), + ], + prompt, + config: config, + ); + + /// Outpaints an image based on a prompt and new dimensions. + @experimental + Future> outpaintImage( + ImagenInlineImage image, + ImagenDimensions newDimensions, { + ImagenImagePlacement newPosition = ImagenImagePlacement.center, + String prompt = '', + ImagenEditingConfig? config, + }) async { + final referenceImages = + await ImagenMaskReference.generateMaskAndPadForOutpainting( + image: image, + newDimensions: newDimensions, + newPosition: newPosition, + ); + return editImage( + referenceImages, + prompt, + config: ImagenEditingConfig( + editMode: ImagenEditMode.outpaint, editSteps: config?.editSteps), + ); + } + + Map _generateImagenEditRequest( + List images, + String prompt, { + ImagenEditingConfig? config, + }) { + final parameters = { + 'sampleCount': _generationConfig?.numberOfImages ?? 1, + if (config?.editMode case final editMode?) 'editMode': editMode.toJson(), + if (config?.editSteps case final editSteps?) + 'editConfig': {'baseSteps': editSteps}, + if (_generationConfig?.negativePrompt case final negativePrompt?) + 'negativePrompt': negativePrompt, + if (_generationConfig?.addWatermark case final addWatermark?) + 'addWatermark': addWatermark, + if (_generationConfig?.imageFormat case final imageFormat?) + 'outputOption': imageFormat.toJson(), + if (_safetySettings?.personFilterLevel case final personFilterLevel?) + 'personGeneration': personFilterLevel.toJson(), + if (_safetySettings?.safetyFilterLevel case final safetyFilterLevel?) + 'safetySetting': safetyFilterLevel.toJson(), + }; + + return { + 'parameters': parameters, + 'instances': [ + { + 'prompt': prompt, + 'referenceImages': images.asMap().entries.map((entry) { + int index = entry.key; + var image = entry.value; + return image.toJson( + referenceIdOverrideIfNull: index + images.length); + }).toList(), + } + ], + }; + } } /// Returns a [ImagenModel] using it's private constructor. diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_reference.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_reference.dart new file mode 100644 index 000000000000..f3eef99878a6 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_reference.dart @@ -0,0 +1,320 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'dart:ui' as ui; + +import 'package:flutter/material.dart'; +import 'package:meta/meta.dart'; + +import 'imagen_content.dart'; +import 'imagen_edit.dart'; + +enum _ReferenceType { + UNSPECIFIED('REFERENCE_TYPE_UNSPECIFIED'), + RAW('REFERENCE_TYPE_RAW'), + MASK('REFERENCE_TYPE_MASK'), + CONTROL('REFERENCE_TYPE_CONTROL'), + STYLE('REFERENCE_TYPE_STYLE'), + SUBJECT('REFERENCE_TYPE_SUBJECT'), + MASKED_SUBJECT('REFERENCE_TYPE_MASKED_SUBJECT'), + PRODUCT('REFERENCE_TYPE_PRODUCT'); + + const _ReferenceType(this._jsonString); + final String _jsonString; + String toJson() => _jsonString; +} + +/// A reference image for image editing. +@experimental +sealed class ImagenReferenceImage { + ImagenReferenceImage._({ + this.referenceConfig, + this.image, + required this.referenceType, + this.referenceId, + }); + + /// A config describing the reference image. + final ImagenReferenceConfig? referenceConfig; + + /// The actual image data of the reference image. + final ImagenInlineImage? image; + + /// The type of the reference image. + final _ReferenceType referenceType; + + /// The reference ID of the image. + final int? referenceId; + + // ignore: public_member_api_docs + Map toJson({int referenceIdOverrideIfNull = 0}) { + final json = {}; + json['referenceType'] = referenceType.toJson(); + if (referenceId != null) { + json['referenceId'] = referenceId; + } else { + json['referenceId'] = referenceIdOverrideIfNull; + } + if (image != null) { + json['referenceImage'] = image!.toJson(); + } + if (referenceConfig != null) { + json.addAll(referenceConfig!.toJson()); + } + + return json; + } +} + +/// A reference image that is a mask. +@experimental +sealed class ImagenMaskReference extends ImagenReferenceImage { + ImagenMaskReference({ + ImagenMaskConfig? maskConfig, + super.image, + super.referenceId, + }) : super._( + referenceType: _ReferenceType.MASK, + referenceConfig: maskConfig, + ); + + /// Generates a mask and pads the image for outpainting. + static Future> generateMaskAndPadForOutpainting({ + required ImagenInlineImage image, + required ImagenDimensions newDimensions, + ImagenImagePlacement newPosition = ImagenImagePlacement.center, + }) async { + final originalImage = await image.asUiImage(); + + try { + // Validate that the new dimensions are strictly larger. + if (originalImage.width >= newDimensions.width || + originalImage.height >= newDimensions.height) { + throw ArgumentError( + 'New Dimensions must be strictly larger than original image dimensions. ' + 'Original image is: ${originalImage.width}x${originalImage.height}, ' + 'new dimensions are ${newDimensions.width}x${newDimensions.height}', + ); + } + + // Calculate the position of the original image on the new canvas. + final originalDimensions = ImagenDimensions( + width: originalImage.width, height: originalImage.height); + final normalizedPosition = + newPosition.normalizeToDimensions(originalDimensions, newDimensions); + + final x = normalizedPosition.x?.toDouble(); + final y = normalizedPosition.y?.toDouble(); + + if (x == null || y == null) { + throw StateError('Error normalizing position for mask and padding.'); + } + + final sourceRect = ui.Rect.fromLTWH(0, 0, originalImage.width.toDouble(), + originalImage.height.toDouble()); + final destRect = ui.Rect.fromLTWH(x, y, originalImage.width.toDouble(), + originalImage.height.toDouble()); + + final whitePaint = Paint()..color = Colors.white; + final blackPaint = Paint()..color = Colors.black; + + final maskImage = await _createImageFromPainter( + width: newDimensions.width, + height: newDimensions.height, + painter: (canvas, size) { + canvas.drawPaint(whitePaint); + canvas.drawRect(destRect, blackPaint); + }, + ); + final maskBytes = + await maskImage.toByteData(format: ui.ImageByteFormat.png); + maskImage.dispose(); // Dispose right away + + // 2. Create, encode, and immediately dispose of the padded image + final paddedImage = await _createImageFromPainter( + width: newDimensions.width, + height: newDimensions.height, + painter: (canvas, size) { + canvas.drawPaint(blackPaint); + canvas.drawImageRect(originalImage, sourceRect, destRect, Paint()); + }, + ); + final paddedBytes = + await paddedImage.toByteData(format: ui.ImageByteFormat.png); + paddedImage.dispose(); // Dispose right away + + if (paddedBytes == null || maskBytes == null) { + throw StateError('Failed to encode generated images.'); + } + + // 5. Return a cleaner, more readable list + return [ + ImagenRawImage( + image: ImagenInlineImage( + bytesBase64Encoded: paddedBytes.buffer.asUint8List(), + mimeType: image.mimeType, + ), + ), + ImagenRawMask( + mask: ImagenInlineImage( + bytesBase64Encoded: maskBytes.buffer.asUint8List(), + mimeType: image.mimeType, + ), + ), + ]; + } finally { + originalImage.dispose(); + } + } + + /// Helper function to create a ui.Image by drawing on a Canvas. + static Future _createImageFromPainter({ + required int width, + required int height, + required void Function(ui.Canvas canvas, ui.Size size) painter, + }) { + final recorder = ui.PictureRecorder(); + final canvas = ui.Canvas(recorder); + final size = ui.Size(width.toDouble(), height.toDouble()); + + painter(canvas, size); + + final picture = recorder.endRecording(); + return picture.toImage(width, height); + } +} + +/// A raw image. +@experimental +final class ImagenRawImage extends ImagenReferenceImage { + ImagenRawImage({ + required ImagenInlineImage image, + super.referenceId, + }) : super._(image: image, referenceType: _ReferenceType.RAW); +} + +/// A raw mask. +@experimental +final class ImagenRawMask extends ImagenMaskReference { + ImagenRawMask({ + required ImagenInlineImage mask, + double? dilation, + super.referenceId, + }) : super( + image: mask, + maskConfig: ImagenMaskConfig( + maskType: ImagenMaskMode.userProvided, + maskDilation: dilation, + ), + ); +} + +/// A semantic mask. +@experimental +final class ImagenSemanticMask extends ImagenMaskReference { + ImagenSemanticMask({ + required List classes, + double? dilation, + super.referenceId, + }) : super( + maskConfig: ImagenMaskConfig( + maskType: ImagenMaskMode.semantic, + maskDilation: dilation, + maskClasses: classes, + ), + ); +} + +/// A background mask. +@experimental +final class ImagenBackgroundMask extends ImagenMaskReference { + ImagenBackgroundMask({ + double? dilation, + super.referenceId, + }) : super( + maskConfig: ImagenMaskConfig( + maskType: ImagenMaskMode.background, + maskDilation: dilation, + ), + ); +} + +/// A foreground mask. +@experimental +final class ImagenForegroundMask extends ImagenMaskReference { + ImagenForegroundMask({ + double? dilation, + super.referenceId, + }) : super( + maskConfig: ImagenMaskConfig( + maskType: ImagenMaskMode.foreground, + maskDilation: dilation, + ), + ); +} + +/// A subject reference. +@experimental +final class ImagenSubjectReference extends ImagenReferenceImage { + ImagenSubjectReference({ + required ImagenInlineImage image, + String? description, + ImagenSubjectReferenceType? subjectType, + super.referenceId, + }) : super._( + image: image, + referenceConfig: ImagenSubjectConfig( + description: description, + type: subjectType, + ), + referenceType: _ReferenceType.SUBJECT, + ); +} + +/// A style reference. +@experimental +final class ImagenStyleReference extends ImagenReferenceImage { + ImagenStyleReference({ + required ImagenInlineImage image, + String? description, + super.referenceId, + }) : super._( + image: image, + referenceConfig: ImagenStyleConfig( + description: description, + ), + referenceType: _ReferenceType.STYLE, + ); +} + +/// A control reference. +@experimental +final class ImagenControlReference extends ImagenReferenceImage { + ImagenControlReference({ + required ImagenControlType controlType, + ImagenInlineImage? image, + bool? enableComputation, + int? superpixelRegionSize, + int? superpixelRuler, + super.referenceId, + }) : super._( + image: image, + referenceConfig: ImagenControlConfig( + controlType: controlType, + enableComputation: enableComputation, + superpixelRegionSize: superpixelRegionSize, + superpixelRuler: superpixelRuler, + ), + referenceType: _ReferenceType.CONTROL, + ); +} diff --git a/packages/firebase_ai/firebase_ai/test/imagen_edit_test.dart b/packages/firebase_ai/firebase_ai/test/imagen_edit_test.dart new file mode 100644 index 000000000000..6d425a09ee3d --- /dev/null +++ b/packages/firebase_ai/firebase_ai/test/imagen_edit_test.dart @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may +// obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:typed_data'; + +import 'package:firebase_ai/firebase_ai.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('ImagenReferenceImage', () { + test('ImagenRawImage toJson', () { + final image = ImagenRawImage( + image: ImagenInlineImage( + bytesBase64Encoded: Uint8List.fromList([]), + mimeType: 'image/jpeg'), + referenceId: 1); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_RAW', + 'referenceId': 1, + 'referenceImage': {'bytesBase64Encoded': '', 'mimeType': 'image/jpeg'} + }); + }); + + test('ImagenRawMask toJson', () { + final image = ImagenRawMask( + mask: ImagenInlineImage( + bytesBase64Encoded: Uint8List.fromList([]), + mimeType: 'image/jpeg'), + referenceId: 1); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_MASK', + 'referenceId': 1, + 'referenceImage': {'bytesBase64Encoded': '', 'mimeType': 'image/jpeg'}, + 'maskImageConfig': {'maskMode': 'MASK_MODE_USER_PROVIDED'} + }); + }); + + test('ImagenSemanticMask toJson', () { + final image = ImagenSemanticMask(classes: [1, 2], referenceId: 1); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_MASK', + 'referenceId': 1, + 'maskImageConfig': { + 'maskMode': 'MASK_MODE_SEMANTIC', + 'maskClasses': '[1,2]' + } + }); + }); + + test('ImagenBackgroundMask toJson', () { + final image = ImagenBackgroundMask(referenceId: 1); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_MASK', + 'referenceId': 1, + 'maskImageConfig': {'maskMode': 'MASK_MODE_BACKGROUND'} + }); + }); + + test('ImagenForegroundMask toJson', () { + final image = ImagenForegroundMask(referenceId: 1); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_MASK', + 'referenceId': 1, + 'maskImageConfig': {'maskMode': 'MASK_MODE_FOREGROUND'} + }); + }); + + test('ImagenSubjectReference toJson', () { + final image = ImagenSubjectReference( + image: ImagenInlineImage( + bytesBase64Encoded: Uint8List.fromList([]), mimeType: 'image/jpeg'), + description: 'a cat', + subjectType: ImagenSubjectReferenceType.animal, + referenceId: 1, + ); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_SUBJECT', + 'referenceId': 1, + 'referenceImage': {'bytesBase64Encoded': '', 'mimeType': 'image/jpeg'}, + 'subjectImageConfig': { + 'subjectDescription': 'a cat', + 'subjectType': 'SUBJECT_TYPE_ANIMAL' + } + }); + }); + + test('ImagenStyleReference toJson', () { + final image = ImagenStyleReference( + image: ImagenInlineImage( + bytesBase64Encoded: Uint8List.fromList([]), mimeType: 'image/jpeg'), + description: 'van gogh style', + referenceId: 1, + ); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_STYLE', + 'referenceId': 1, + 'referenceImage': {'mimeType': 'image/jpeg', 'bytesBase64Encoded': ''}, + 'styleImageConfig': {'styleDescription': 'van gogh style'} + }); + }); + + test('ImagenControlReference toJson', () { + final image = ImagenControlReference( + controlType: ImagenControlType.canny, + image: ImagenInlineImage( + bytesBase64Encoded: Uint8List.fromList([]), mimeType: 'image/jpeg'), + referenceId: 1, + ); + final json = image.toJson(); + expect(json, { + 'referenceType': 'REFERENCE_TYPE_CONTROL', + 'referenceId': 1, + 'referenceImage': {'bytesBase64Encoded': '', 'mimeType': 'image/jpeg'}, + 'controlImageConfig': {'controlType': 'CONTROL_TYPE_CANNY'} + }); + }); + }); +} diff --git a/packages/firebase_ai/firebase_ai/test/imagen_test.dart b/packages/firebase_ai/firebase_ai/test/imagen_test.dart index bdb6200593c1..6f26baffd333 100644 --- a/packages/firebase_ai/firebase_ai/test/imagen_test.dart +++ b/packages/firebase_ai/firebase_ai/test/imagen_test.dart @@ -16,8 +16,8 @@ import 'dart:convert'; import 'dart:typed_data'; import 'package:firebase_ai/src/error.dart'; -import 'package:firebase_ai/src/imagen_api.dart'; -import 'package:firebase_ai/src/imagen_content.dart'; +import 'package:firebase_ai/src/imagen/imagen_api.dart'; +import 'package:firebase_ai/src/imagen/imagen_content.dart'; import 'package:flutter_test/flutter_test.dart'; void main() {