diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt index 4e6bb018dac..c4df8a889a8 100644 --- a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt @@ -24,9 +24,6 @@ import com.google.firebase.ai.type.PublicPreviewAPI class AIModels { companion object { - private val API_KEY: String = "" - private val APP_ID: String = "" - private val PROJECT_ID: String = "fireescape-integ-tests" // General purpose models var app: FirebaseApp? = null lateinit var vertexAIFlashModel: GenerativeModel @@ -49,14 +46,6 @@ class AIModels { ) } - /** Returns a list of template models to test */ - fun getTemplateModels(): List { - if (app == null) { - setup() - } - return listOf(vertexAITemplateModel, googleAITemplateModel) - } - fun app(): FirebaseApp { if (app == null) { setup() @@ -80,7 +69,7 @@ class AIModels { googleAIFlashModel = FirebaseAI.getInstance(app!!, GenerativeBackend.googleAI()) .generativeModel( - modelName = "gemini-2.5-flash", + modelName = "gemini-3.1-flash-lite", ) googleAIFlashLiteModel = FirebaseAI.getInstance(app!!, GenerativeBackend.googleAI()) @@ -94,3 +83,6 @@ class AIModels { } } } + +@OptIn(PublicPreviewAPI::class) +data class TemplateModel(val backend: String, val model: TemplateGenerativeModel) diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt index 2971b5fc32c..45364afb09f 100644 --- a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt @@ -16,13 +16,13 @@ package com.google.firebase.ai -import com.google.firebase.ai.AIModels.Companion.getTemplateModels import com.google.firebase.ai.type.PublicPreviewAPI import com.google.firebase.ai.type.content import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContainIgnoringCase import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking +import org.junit.Before import org.junit.Test @OptIn(PublicPreviewAPI::class) @@ -59,55 +59,87 @@ class ChatTemplateIntegrationTests { private val topic = "Firebase" private val inputs = mapOf("customerName" to customerName, "topic" to topic) + @Before + fun setup() { + AIModels.setup() + } + + @Test + fun testTemplateChat_sendMessage_googleAI(): Unit = runBlocking { + val chat = AIModels.googleAITemplateModel.startChat("$templateId-google-ai", inputs) + val response = chat.sendMessage("which number is higher, one or ten?") + + response.candidates.isNotEmpty() shouldBe true + response.text shouldContainIgnoringCase "ten" + + chat.history.size shouldBe 2 + } + @Test + fun testTemplateChat_sendMessage_vertexAI(): Unit = runBlocking { + val chat = AIModels.vertexAITemplateModel.startChat("$templateId-vertex-ai", inputs) + val response = chat.sendMessage("which number is higher, one or ten?") + + response.candidates.isNotEmpty() shouldBe true + response.text shouldContainIgnoringCase "ten" + + chat.history.size shouldBe 2 + } + @Test - fun testTemplateChat_sendMessage() { - for (model in getTemplateModels()) { - runBlocking { - val chat = model.startChat(templateId, inputs) - val response = chat.sendMessage("which number is higher, one or ten?") - - response.candidates.isNotEmpty() shouldBe true - response.text shouldContainIgnoringCase "ten" - - chat.history.size shouldBe 2 - } - } + fun testTemplateChat_sendMessageStream_googleAI(): Unit = runBlocking { + val chat = AIModels.googleAITemplateModel.startChat("$templateId-google-ai", inputs) + val responses = chat.sendMessageStream("which number is higher, one or ten?").toList() + responses.isNotEmpty() shouldBe true + responses.joinToString { it.text ?: "" } shouldContainIgnoringCase "ten" + chat.history.size shouldBe 2 } @Test - fun testTemplateChat_sendMessageStream() { - for (model in getTemplateModels()) { - runBlocking { - val chat = model.startChat(templateId, inputs) - val responses = chat.sendMessageStream("which number is higher, one or ten?").toList() - responses.isNotEmpty() shouldBe true - responses.joinToString { it.text ?: "" } shouldContainIgnoringCase "ten" - chat.history.size shouldBe 2 - } - } + fun testTemplateChat_sendMessageStream_vertexAI(): Unit = runBlocking { + val chat = AIModels.vertexAITemplateModel.startChat("$templateId-vertex-ai", inputs) + val responses = chat.sendMessageStream("which number is higher, one or ten?").toList() + responses.isNotEmpty() shouldBe true + responses.joinToString { it.text ?: "" } shouldContainIgnoringCase "ten" + chat.history.size shouldBe 2 } @Test - fun testTemplateChat_withHistory() { - for (model in getTemplateModels()) { - runBlocking { - val history = - listOf( - content("user") { text("which number is higher, one or ten?") }, - content("model") { text("Ten.") } - ) - val chat = model.startChat(templateId, inputs, history) - chat.history.size shouldBe 2 - val response = - chat.sendMessage( - "Please concatenate them both, first the smaller one, then the bigger one." - ) - - response.candidates.isNotEmpty() shouldBe true - response.text shouldContainIgnoringCase "oneten" - - chat.history.size shouldBe 4 - } - } + fun testTemplateChat_withHistory_googleAI(): Unit = runBlocking { + val history = + listOf( + content("user") { text("which number is higher, one or ten?") }, + content("model") { text("Ten.") } + ) + val chat = AIModels.googleAITemplateModel.startChat("$templateId-google-ai", inputs, history) + chat.history.size shouldBe 2 + val response = + chat.sendMessage( + "Please concatenate them both, first the smaller one, then the bigger one. Do not use punctuation or spaces." + ) + + response.candidates.isNotEmpty() shouldBe true + response.text?.replace(" ", "") shouldContainIgnoringCase "oneten" + + chat.history.size shouldBe 4 + } + + @Test + fun testTemplateChat_withHistory_vertexAI(): Unit = runBlocking { + val history = + listOf( + content("user") { text("which number is higher, one or ten?") }, + content("model") { text("Ten.") } + ) + val chat = AIModels.googleAITemplateModel.startChat("$templateId-google-ai", inputs, history) + chat.history.size shouldBe 2 + val response = + chat.sendMessage( + "Please concatenate them both, first the smaller one, then the bigger one. Do not use punctuation or spaces." + ) + + response.candidates.isNotEmpty() shouldBe true + response.text?.replace(" ", "") shouldContainIgnoringCase "oneten" + + chat.history.size shouldBe 4 } } diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt index 6f9b595aeb9..f4a2939ba17 100644 --- a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt @@ -16,13 +16,12 @@ package com.google.firebase.ai -import com.google.firebase.ai.AIModels.Companion.getTemplateModels import com.google.firebase.ai.type.PublicPreviewAPI import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.string.shouldContainIgnoringCase -import io.kotest.matchers.string.shouldNotBeEmpty import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking +import org.junit.Before import org.junit.Test @OptIn(PublicPreviewAPI::class) @@ -51,32 +50,56 @@ class TemplateIntegrationTests { private val topic = "Firebase" private val inputs = mapOf("customerName" to customerName, "topic" to topic) + @Before + fun setup() { + AIModels.setup() + } + @Test - fun testTemplateGenerateContent() { - for (model in getTemplateModels()) { - runBlocking { - val response = model.generateContent(templateId, inputs) + fun testTemplateGenerateContent_googleAI(): Unit = runBlocking { + val response = AIModels.googleAITemplateModel.generateContent("$templateId-google-ai", inputs) - response.candidates.shouldNotBeEmpty() - response.text shouldContainIgnoringCase customerName - response.text shouldContainIgnoringCase topic + response.candidates.shouldNotBeEmpty() + response.text shouldContainIgnoringCase customerName + response.text shouldContainIgnoringCase topic + } + + @Test + fun testTemplateGenerateContent_vertexAI(): Unit = runBlocking { + val response = AIModels.vertexAITemplateModel.generateContent("$templateId-vertex-ai", inputs) + + response.candidates.shouldNotBeEmpty() + response.text shouldContainIgnoringCase customerName + response.text shouldContainIgnoringCase topic + } + + @Test + fun testTemplateGenerateContentStream_googleAI(): Unit = runBlocking { + val responses = + AIModels.googleAITemplateModel.generateContentStream("$templateId-google-ai", inputs).toList() + responses + .joinToString { it.text ?: "" } + .lowercase() + .replace(",", "") + .replace(" ", " ") // Model sometimes doubles spacing + .let { + it shouldContainIgnoringCase customerName + it shouldContainIgnoringCase topic } - } } @Test - fun testTemplateGenerateContentStream() { - for (model in getTemplateModels()) { - runBlocking { - val responses = model.generateContentStream(templateId, inputs).toList() - responses - .joinToString { it.text ?: "" } - .lowercase() - .let { - it shouldContainIgnoringCase customerName - it shouldContainIgnoringCase topic - } + fun testTemplateGenerateContentStream_vertexAI(): Unit = runBlocking { + val responses = + AIModels.vertexAITemplateModel.generateContentStream("$templateId-vertex-ai", inputs).toList() + responses + .joinToString { it.text ?: "" } + .lowercase() + .replace(",", "") + .replace(" ", " ") // Model sometimes doubles spacing + .let { + it shouldContainIgnoringCase customerName + it shouldContainIgnoringCase topic } - } } }