Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,28 @@ val FIREBASE_AI_SAMPLES = listOf(
)
},
),
Sample(
title = "Imagen 3 - Server Template Generation",
description = "Generate an image using a server prompt template.",
navRoute = "imagen",
categories = listOf(Category.IMAGE),
initialPrompt = content { text("List of things that should be in the image") },
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use a more actionable prompt here? Or no initial prompt at all? (assuming the prompt comes from the template)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prompt does come from the template unfortunately. Let me know what you think is best here.

allowEmptyPrompt = false,
editingMode = EditingMode.TEMPLATE,
// To make this work on your project, create an `Imagen (Basic)` template in your project with this name
templateId = "imagen-basic",
templateKey = "prompt"
),
Sample(
title = "Server Prompt Templates",
description = "Generate an invoice using server prompt templates",
navRoute = "text",
categories = listOf(Category.TEXT),
initialPrompt = content { text("Customer Name") },
allowEmptyPrompt = false,
// To make this work on your project, create an `Input + System Instructions` template in your project with this name
templateId = "input-system-instructions",
templateKey = "customerName"
),

)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenRoute
import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenScreen
import com.google.firebase.quickstart.ai.feature.text.ChatRoute
import com.google.firebase.quickstart.ai.feature.text.ChatScreen
import com.google.firebase.quickstart.ai.feature.text.TextGenRoute
import com.google.firebase.quickstart.ai.feature.text.TextGenScreen
import com.google.firebase.quickstart.ai.ui.navigation.MainMenuScreen
import com.google.firebase.quickstart.ai.ui.theme.FirebaseAILogicTheme

Expand Down Expand Up @@ -90,6 +92,9 @@ class MainActivity : ComponentActivity() {
"stream" -> {
navController.navigate(StreamRealtimeRoute(it.id))
}
"text" -> {
navController.navigate(TextGenRoute(it.id))
}
}
}
)
Expand All @@ -106,6 +111,9 @@ class MainActivity : ComponentActivity() {
composable<StreamRealtimeRoute> {
StreamRealtimeScreen()
}
composable<TextGenRoute> {
TextGenScreen()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ enum class EditingMode {
OUTPAINTING,
SUBJECT_REFERENCE,
STYLE_TRANSFER,
TEMPLATE,
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import androidx.core.graphics.scale
import com.google.firebase.ai.TemplateImagenModel
import com.google.firebase.ai.type.Dimensions
import com.google.firebase.ai.type.ImagenBackgroundMask
import com.google.firebase.ai.type.ImagenEditMode
Expand Down Expand Up @@ -67,6 +68,10 @@ class ImagenViewModel(

val additionalImage = sample.additionalImage

val templateId = sample.templateId

val templateKey = sample.templateKey

private val _attachedImage = MutableStateFlow<Bitmap?>(null)
val attachedImage: StateFlow<Bitmap?> = _attachedImage

Expand All @@ -75,6 +80,7 @@ class ImagenViewModel(

// Firebase AI Logic
private val imagenModel: ImagenModel
private val templateImagenModel: TemplateImagenModel

init {
val config = imagenGenerationConfig {
Expand All @@ -92,23 +98,31 @@ class ImagenViewModel(
generationConfig = config,
safetySettings = settings
)
templateImagenModel = Firebase.ai.templateImagenModel()
}

fun generateImages(inputText: String) {
viewModelScope.launch {
_isLoading.value = true
_errorMessage.value = null // clear error message
try {
val imageResponse = when(sample.editingMode) {
EditingMode.INPAINTING -> inpaint(imagenModel, inputText)
EditingMode.OUTPAINTING -> outpaint(imagenModel, inputText)
EditingMode.SUBJECT_REFERENCE -> drawReferenceSubject(imagenModel, inputText)
EditingMode.STYLE_TRANSFER -> transferStyle(imagenModel, inputText)
EditingMode.TEMPLATE -> generateWithTemplate(templateImagenModel, templateId!!, mapOf(templateKey!! to inputText))
else -> generate(imagenModel, inputText)
}
_generatedBitmaps.value = imageResponse.images.map { it.asBitmap() }
_errorMessage.value = null // clear error message
} catch (e: Exception) {
_errorMessage.value = e.localizedMessage
val errorMessage =
if ((e.localizedMessage?.contains("not found") == true) && sample.editingMode == EditingMode.TEMPLATE) {
"Template was not found, please verify that your project contains a template named \"$templateId\"."
} else {
e.localizedMessage
}
_errorMessage.value = errorMessage
} finally {
_isLoading.value = false
}
Expand Down Expand Up @@ -212,4 +226,12 @@ class ImagenViewModel(
inputText
)
}

suspend fun generateWithTemplate(
model: TemplateImagenModel,
templateId: String,
inputMap: Map<String, String>
): ImagenGenerationResponse<ImagenInlineImage> {
return model.generateImages(templateId, inputMap)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package com.google.firebase.quickstart.ai.feature.text

import android.net.Uri
import android.provider.OpenableColumns
import android.text.format.Formatter
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.Image
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyHorizontalGrid
import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ElevatedCard
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.saveable.rememberSaveable
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.unit.dp
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.firebase.quickstart.ai.R
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable

@Serializable
class TextGenRoute(val sampleId: String)

@Composable
fun TextGenScreen(
textGenViewModel: TextGenViewModel = viewModel<TextGenViewModel>()
) {
var textPrompt by rememberSaveable { mutableStateOf(textGenViewModel.initialPrompt) }
val errorMessage by textGenViewModel.errorMessage.collectAsStateWithLifecycle()
val isLoading by textGenViewModel.isLoading.collectAsStateWithLifecycle()
val generatedText by textGenViewModel.generatedText.collectAsStateWithLifecycle()

Column(
modifier = Modifier.verticalScroll(rememberScrollState())
) {
ElevatedCard(
modifier = Modifier
.padding(all = 16.dp)
.fillMaxWidth(),
shape = MaterialTheme.shapes.large
) {
OutlinedTextField(
value = textPrompt,
label = { Text("Prompt") },
placeholder = { Text("Enter text to generate") },
onValueChange = { textPrompt = it },
modifier = Modifier
.padding(16.dp)
.fillMaxWidth()
)
Row() {
TextButton(
onClick = {
if (textGenViewModel.allowEmptyPrompt || textPrompt.isNotBlank()) {
textGenViewModel.generate(textPrompt)
}
},
modifier = Modifier.padding(end = 16.dp, bottom = 16.dp)
) {
Text("Generate")
}
}

}

if (isLoading) {
Box(
contentAlignment = Alignment.Center,
modifier = Modifier
.padding(all = 8.dp)
.align(Alignment.CenterHorizontally)
) {
CircularProgressIndicator()
}
}
errorMessage?.let {
Card(
modifier = Modifier
.padding(horizontal = 16.dp)
.fillMaxWidth(),
shape = MaterialTheme.shapes.large,
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.errorContainer
)
) {
Text(
text = it,
color = MaterialTheme.colorScheme.error,
modifier = Modifier.padding(all = 16.dp)
)
}
}
generatedText?.let {
Card(
modifier = Modifier
.padding(horizontal = 16.dp)
.fillMaxWidth(),
shape = MaterialTheme.shapes.large,
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
)
) {
Text(
text = it,
color = MaterialTheme.colorScheme.primary,
modifier = Modifier.padding(all = 16.dp)
)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.google.firebase.quickstart.ai.feature.text

import androidx.lifecycle.SavedStateHandle
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import androidx.navigation.toRoute
import com.google.firebase.Firebase
import com.google.firebase.ai.ai
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.asTextOrNull
import com.google.firebase.quickstart.ai.FIREBASE_AI_SAMPLES
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch
import com.google.firebase.ai.GenerativeModel
import com.google.firebase.ai.TemplateGenerativeModel

@OptIn(PublicPreviewAPI::class)
class TextGenViewModel(
savedStateHandle: SavedStateHandle
) : ViewModel() {
private val sampleId = savedStateHandle.toRoute<TextGenRoute>().sampleId
private val sample = FIREBASE_AI_SAMPLES.first { it.id == sampleId }
val initialPrompt = sample.initialPrompt?.parts?.first()?.asTextOrNull().orEmpty()

private val _errorMessage: MutableStateFlow<String?> = MutableStateFlow(null)
val errorMessage: StateFlow<String?> = _errorMessage

private val _isLoading = MutableStateFlow(false)
val isLoading: StateFlow<Boolean> = _isLoading

val allowEmptyPrompt = sample.allowEmptyPrompt

val templateId = sample.templateId

val templateKey = sample.templateKey

private val _generatedText = MutableStateFlow<String?>(null)
val generatedText: StateFlow<String?> = _generatedText

// Firebase AI Logic
private val generativeModel: GenerativeModel
private val templateGenerativeModel: TemplateGenerativeModel

init {
generativeModel = Firebase.ai(
backend = sample.backend // GenerativeBackend.googleAI() by default
).generativeModel(
modelName = sample.modelName ?: "gemini-2.5-flash",
systemInstruction = sample.systemInstructions,
generationConfig = sample.generationConfig,
tools = sample.tools
)
templateGenerativeModel = Firebase.ai.templateGenerativeModel()
}

fun generate(inputText: String) {
viewModelScope.launch {
_isLoading.value = true
_errorMessage.value = null // clear error message
try {
val generativeResponse = if (templateId != null) {
templateGenerativeModel
.generateContent(templateId, mapOf(templateKey!! to inputText))
} else {
generativeModel.generateContent(inputText)
}
_generatedText.value = generativeResponse.text
} catch (e: Exception) {
val errorMessage =
if ((e.localizedMessage?.contains("not found") == true) && (templateId != null)) {
"Template was not found, please verify that your project contains a template named \"$templateId\"."
} else {
e.localizedMessage
}
_errorMessage.value = errorMessage
} finally {
_isLoading.value = false
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ data class Sample(
val imageLabels: List<String> = emptyList(),
val selectionOptions: List<String> = emptyList(),
val editingMode: EditingMode? = null,
val templateId: String? = null,
val templateKey: String? = null,
)
Loading