diff --git a/.changeset/metal-ties-cry.md b/.changeset/metal-ties-cry.md new file mode 100644 index 00000000000..ffff494e6d2 --- /dev/null +++ b/.changeset/metal-ties-cry.md @@ -0,0 +1,6 @@ +--- +'firebase': minor +'@firebase/ai': minor +--- + +Add support for Server Prompt Templates. diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md index 08a12efc36e..2bf194fbaf2 100644 --- a/common/api-review/ai.api.md +++ b/common/api-review/ai.api.md @@ -100,6 +100,10 @@ export interface AudioTranscriptionConfig { export abstract class Backend { protected constructor(type: BackendType); readonly backendType: BackendType; + // @internal (undocumented) + abstract _getModelPath(project: string, model: string): string; + // @internal (undocumented) + abstract _getTemplatePath(project: string, templateId: string): string; } // @public @@ -567,9 +571,19 @@ export function getImagenModel(ai: AI, modelParams: ImagenModelParams, requestOp // @beta export function getLiveGenerativeModel(ai: AI, modelParams: LiveModelParams): LiveGenerativeModel; +// @beta +export function getTemplateGenerativeModel(ai: AI, requestOptions?: RequestOptions): TemplateGenerativeModel; + +// @beta +export function getTemplateImagenModel(ai: AI, requestOptions?: RequestOptions): TemplateImagenModel; + // @public export class GoogleAIBackend extends Backend { constructor(); + // @internal (undocumented) + _getModelPath(project: string, model: string): string; + // @internal (undocumented) + _getTemplatePath(project: string, templateId: string): string; } // Warning: (ae-internal-missing-underscore) The name "GoogleAICitationMetadata" should be prefixed with an underscore because the declaration is marked as @internal @@ -1314,6 +1328,25 @@ export class StringSchema extends Schema { toJSON(): SchemaRequest; } +// @beta +export class TemplateGenerativeModel { + constructor(ai: AI, requestOptions?: RequestOptions); + // @internal (undocumented) + _apiSettings: ApiSettings; + 
generateContent(templateId: string, templateVariables: object): Promise; + generateContentStream(templateId: string, templateVariables: object): Promise; + requestOptions?: RequestOptions; +} + +// @beta +export class TemplateImagenModel { + constructor(ai: AI, requestOptions?: RequestOptions); + // @internal (undocumented) + _apiSettings: ApiSettings; + generateImages(templateId: string, templateVariables: object): Promise>; + requestOptions?: RequestOptions; +} + // @public export interface TextPart { // (undocumented) @@ -1412,6 +1445,10 @@ export interface UsageMetadata { // @public export class VertexAIBackend extends Backend { constructor(location?: string); + // @internal (undocumented) + _getModelPath(project: string, model: string): string; + // @internal (undocumented) + _getTemplatePath(project: string, templateId: string): string; readonly location: string; } diff --git a/docs-devsite/_toc.yaml b/docs-devsite/_toc.yaml index 4f3bb1f3ca4..92633c553a3 100644 --- a/docs-devsite/_toc.yaml +++ b/docs-devsite/_toc.yaml @@ -198,6 +198,10 @@ toc: path: /docs/reference/js/ai.startchatparams.md - title: StringSchema path: /docs/reference/js/ai.stringschema.md + - title: TemplateGenerativeModel + path: /docs/reference/js/ai.templategenerativemodel.md + - title: TemplateImagenModel + path: /docs/reference/js/ai.templateimagenmodel.md - title: TextPart path: /docs/reference/js/ai.textpart.md - title: ThinkingConfig diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md index 79902cab4e7..53e4057cade 100644 --- a/docs-devsite/ai.md +++ b/docs-devsite/ai.md @@ -22,6 +22,8 @@ The Firebase AI Web SDK. | [getGenerativeModel(ai, modelParams, requestOptions)](./ai.md#getgenerativemodel_c63f46a) | Returns a [GenerativeModel](./ai.generativemodel.md#generativemodel_class) class with methods for inference and other functionality. 
| | [getImagenModel(ai, modelParams, requestOptions)](./ai.md#getimagenmodel_e1f6645) | Returns an [ImagenModel](./ai.imagenmodel.md#imagenmodel_class) class with methods for using Imagen.Only Imagen 3 models (named imagen-3.0-*) are supported. | | [getLiveGenerativeModel(ai, modelParams)](./ai.md#getlivegenerativemodel_f2099ac) | (Public Preview) Returns a [LiveGenerativeModel](./ai.livegenerativemodel.md#livegenerativemodel_class) class for real-time, bidirectional communication.The Live API is only supported in modern browser windows and Node >= 22. | +| [getTemplateGenerativeModel(ai, requestOptions)](./ai.md#gettemplategenerativemodel_9476bbc) | (Public Preview) Returns a [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) class for executing server-side templates. | +| [getTemplateImagenModel(ai, requestOptions)](./ai.md#gettemplateimagenmodel_9476bbc) | (Public Preview) Returns a [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) class for executing server-side Imagen templates. | | function(liveSession, ...) | | [startAudioConversation(liveSession, options)](./ai.md#startaudioconversation_01c8e7f) | (Public Preview) Starts a real-time, bidirectional audio conversation with the model. This helper function manages the complexities of microphone access, audio recording, playback, and interruptions. | @@ -47,6 +49,8 @@ The Firebase AI Web SDK. | [ObjectSchema](./ai.objectschema.md#objectschema_class) | Schema class for "object" types. The properties param must be a map of Schema objects. | | [Schema](./ai.schema.md#schema_class) | Parent class encompassing all Schema types, with static methods that allow building specific Schema types. This class can be converted with JSON.stringify() into a JSON string accepted by Vertex AI REST endpoints. (This string conversion is automatically done when calling SDK methods.) 
| | [StringSchema](./ai.stringschema.md#stringschema_class) | Schema class for "string" types. Can be used with or without enum values. | +| [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) | (Public Preview) [GenerativeModel](./ai.generativemodel.md#generativemodel_class) APIs that execute on a server-side template.This class should only be instantiated with [getTemplateGenerativeModel()](./ai.md#gettemplategenerativemodel_9476bbc). | +| [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) | (Public Preview) Class for Imagen model APIs that execute on a server-side template.This class should only be instantiated with [getTemplateImagenModel()](./ai.md#gettemplateimagenmodel_9476bbc). | | [VertexAIBackend](./ai.vertexaibackend.md#vertexaibackend_class) | Configuration class for the Vertex AI Gemini API.Use this with [AIOptions](./ai.aioptions.md#aioptions_interface) when initializing the AI service via [getAI()](./ai.md#getai_a94a413) to specify the Vertex AI Gemini API as the backend. | ## Interfaces @@ -341,6 +345,54 @@ export declare function getLiveGenerativeModel(ai: AI, modelParams: LiveModelPar If the `apiKey` or `projectId` fields are missing in your Firebase config. +### getTemplateGenerativeModel(ai, requestOptions) {:#gettemplategenerativemodel_9476bbc} + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Returns a [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) class for executing server-side templates. + +Signature: + +```typescript +export declare function getTemplateGenerativeModel(ai: AI, requestOptions?: RequestOptions): TemplateGenerativeModel; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | An [AI](./ai.ai.md#ai_interface) instance. 
| +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | Additional options to use when making requests. | + +Returns: + +[TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) + +### getTemplateImagenModel(ai, requestOptions) {:#gettemplateimagenmodel_9476bbc} + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Returns a [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) class for executing server-side Imagen templates. + +Signature: + +```typescript +export declare function getTemplateImagenModel(ai: AI, requestOptions?: RequestOptions): TemplateImagenModel; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | An [AI](./ai.ai.md#ai_interface) instance. | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | Additional options to use when making requests. | + +Returns: + +[TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) + ## function(liveSession, ...) ### startAudioConversation(liveSession, options) {:#startaudioconversation_01c8e7f} diff --git a/docs-devsite/ai.templategenerativemodel.md b/docs-devsite/ai.templategenerativemodel.md new file mode 100644 index 00000000000..c115af62b1e --- /dev/null +++ b/docs-devsite/ai.templategenerativemodel.md @@ -0,0 +1,125 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# TemplateGenerativeModel class +> This API is provided as a preview for developers and may change based on feedback that we receive. 
Do not use this API in a production environment. +> + +[GenerativeModel](./ai.generativemodel.md#generativemodel_class) APIs that execute on a server-side template. + +This class should only be instantiated with [getTemplateGenerativeModel()](./ai.md#gettemplategenerativemodel_9476bbc). + +Signature: + +```typescript +export declare class TemplateGenerativeModel +``` + +## Constructors + +| Constructor | Modifiers | Description | +| --- | --- | --- | +| [(constructor)(ai, requestOptions)](./ai.templategenerativemodel.md#templategenerativemodelconstructor) | | (Public Preview) Constructs a new instance of the TemplateGenerativeModel class | + +## Properties + +| Property | Modifiers | Type | Description | +| --- | --- | --- | --- | +| [requestOptions](./ai.templategenerativemodel.md#templategenerativemodelrequestoptions) | | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | (Public Preview) Additional options to use when making requests. | + +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [generateContent(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | +| [generateContentStream(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. | + +## TemplateGenerativeModel.(constructor) + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. 
+> + + Constructs a new instance of the `TemplateGenerativeModel` class + +Signature: + +```typescript +constructor(ai: AI, requestOptions?: RequestOptions); +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | | + +## TemplateGenerativeModel.requestOptions + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Additional options to use when making requests. + +Signature: + +```typescript +requestOptions?: RequestOptions; +``` + +## TemplateGenerativeModel.generateContent() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). + +Signature: + +```typescript +generateContent(templateId: string, templateVariables: object): Promise; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| templateVariables | object | A key-value map of variables to populate the template with. | + +Returns: + +Promise<[GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface)> + +## TemplateGenerativeModel.generateContentStream() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. 
+ +Signature: + +```typescript +generateContentStream(templateId: string, templateVariables: object): Promise; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| templateVariables | object | A key-value map of variables to populate the template with. | + +Returns: + +Promise<[GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface)> + diff --git a/docs-devsite/ai.templateimagenmodel.md b/docs-devsite/ai.templateimagenmodel.md new file mode 100644 index 00000000000..2d86071993f --- /dev/null +++ b/docs-devsite/ai.templateimagenmodel.md @@ -0,0 +1,100 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# TemplateImagenModel class +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Class for Imagen model APIs that execute on a server-side template. + +This class should only be instantiated with [getTemplateImagenModel()](./ai.md#gettemplateimagenmodel_9476bbc). 
+ +Signature: + +```typescript +export declare class TemplateImagenModel +``` + +## Constructors + +| Constructor | Modifiers | Description | +| --- | --- | --- | +| [(constructor)(ai, requestOptions)](./ai.templateimagenmodel.md#templateimagenmodelconstructor) | | (Public Preview) Constructs a new instance of the TemplateImagenModel class | + +## Properties + +| Property | Modifiers | Type | Description | +| --- | --- | --- | --- | +| [requestOptions](./ai.templateimagenmodel.md#templateimagenmodelrequestoptions) | | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | (Public Preview) Additional options to use when making requests. | + +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [generateImages(templateId, templateVariables)](./ai.templateimagenmodel.md#templateimagenmodelgenerateimages) | | (Public Preview) Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). | + +## TemplateImagenModel.(constructor) + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + + Constructs a new instance of the `TemplateImagenModel` class + +Signature: + +```typescript +constructor(ai: AI, requestOptions?: RequestOptions); +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | | + +## TemplateImagenModel.requestOptions + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Additional options to use when making requests. 
+ +Signature: + +```typescript +requestOptions?: RequestOptions; +``` + +## TemplateImagenModel.generateImages() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). + +Signature: + +```typescript +generateImages(templateId: string, templateVariables: object): Promise>; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| templateVariables | object | A key-value map of variables to populate the template with. | + +Returns: + +Promise<[ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface)<[ImagenInlineImage](./ai.imageninlineimage.md#imageninlineimage_interface)>> + diff --git a/packages/ai/integration/constants.ts b/packages/ai/integration/constants.ts index f4a74e75039..99a65f31c54 100644 --- a/packages/ai/integration/constants.ts +++ b/packages/ai/integration/constants.ts @@ -44,7 +44,7 @@ function formatConfigAsString(config: { ai: AI; model: string }): string { const backends: readonly Backend[] = [ new GoogleAIBackend(), - new VertexAIBackend() + new VertexAIBackend('global') ]; const backendNames: Map = new Map([ diff --git a/packages/ai/integration/prompt-templates.test.ts b/packages/ai/integration/prompt-templates.test.ts new file mode 100644 index 00000000000..3a7f9038561 --- /dev/null +++ b/packages/ai/integration/prompt-templates.test.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from 'chai'; +import { + BackendType, + getTemplateGenerativeModel, + getTemplateImagenModel +} from '../src'; +import { testConfigs } from './constants'; +import { STAGING_URL } from '../src/constants'; + +const templateBackendSuffix = ( + backendType: BackendType +): 'googleai' | 'vertexai' => + backendType === BackendType.GOOGLE_AI ? 'googleai' : 'vertexai'; + +describe('Prompt templates', function () { + this.timeout(20_000); + testConfigs.forEach(testConfig => { + describe(`${testConfig.toString()}`, () => { + describe('Generative Model', () => { + it('successfully generates content', async () => { + const model = getTemplateGenerativeModel(testConfig.ai, { + baseUrl: STAGING_URL + }); + const { response } = await model.generateContent( + `sassy-greeting-${templateBackendSuffix( + testConfig.ai.backend.backendType + )}`, + { name: 'John' } + ); + expect(response.text()).to.contain('John'); // Template asks to address directly by name + }); + }); + describe('Imagen model', () => { + it('successfully generates images', async () => { + const model = getTemplateImagenModel(testConfig.ai, { + baseUrl: STAGING_URL + }); + const { images } = await model.generateImages( + `portrait-${templateBackendSuffix( + testConfig.ai.backend.backendType + )}`, + { animal: 'Rhino' } + ); + expect(images.length).to.equal(2); // We ask for two images in the prompt template + }); + }); + }); + }); +}); diff --git a/packages/ai/src/api.test.ts b/packages/ai/src/api.test.ts index 3854f010fc7..3a56f3a8feb 100644 --- 
a/packages/ai/src/api.test.ts +++ b/packages/ai/src/api.test.ts @@ -22,7 +22,11 @@ import { LiveGenerativeModel, getGenerativeModel, getImagenModel, - getLiveGenerativeModel + getLiveGenerativeModel, + getTemplateGenerativeModel, + TemplateGenerativeModel, + getTemplateImagenModel, + TemplateImagenModel } from './api'; import { expect } from 'chai'; import { AI } from './public-types'; @@ -266,4 +270,14 @@ describe('Top level API', () => { 'publishers/google/models/my-model' ); }); + it('getTemplateGenerativeModel gets a TemplateGenerativeModel', () => { + expect(getTemplateGenerativeModel(fakeAI)).to.be.an.instanceOf( + TemplateGenerativeModel + ); + }); + it('getTemplateImagenModel gets a TemplateImagenModel', () => { + expect(getTemplateImagenModel(fakeAI)).to.be.an.instanceOf( + TemplateImagenModel + ); + }); }); diff --git a/packages/ai/src/api.ts b/packages/ai/src/api.ts index 6e56aea793c..29614d88cec 100644 --- a/packages/ai/src/api.ts +++ b/packages/ai/src/api.ts @@ -39,12 +39,22 @@ import { import { encodeInstanceIdentifier } from './helpers'; import { GoogleAIBackend } from './backend'; import { WebSocketHandlerImpl } from './websocket'; +import { TemplateGenerativeModel } from './models/template-generative-model'; +import { TemplateImagenModel } from './models/template-imagen-model'; export { ChatSession } from './methods/chat-session'; export { LiveSession } from './methods/live-session'; export * from './requests/schema-builder'; export { ImagenImageFormat } from './requests/imagen-image-format'; -export { AIModel, GenerativeModel, LiveGenerativeModel, ImagenModel, AIError }; +export { + AIModel, + GenerativeModel, + LiveGenerativeModel, + ImagenModel, + TemplateGenerativeModel, + TemplateImagenModel, + AIError +}; export { Backend, VertexAIBackend, GoogleAIBackend } from './backend'; export { startAudioConversation, @@ -202,3 +212,35 @@ export function getLiveGenerativeModel( const webSocketHandler = new WebSocketHandlerImpl(); return new 
LiveGenerativeModel(ai, modelParams, webSocketHandler); } + +/** + * Returns a {@link TemplateGenerativeModel} class for executing server-side + * templates. + * + * @param ai - An {@link AI} instance. + * @param requestOptions - Additional options to use when making requests. + * + * @beta + */ +export function getTemplateGenerativeModel( + ai: AI, + requestOptions?: RequestOptions +): TemplateGenerativeModel { + return new TemplateGenerativeModel(ai, requestOptions); +} + +/** + * Returns a {@link TemplateImagenModel} class for executing server-side + * Imagen templates. + * + * @param ai - An {@link AI} instance. + * @param requestOptions - Additional options to use when making requests. + * + * @beta + */ +export function getTemplateImagenModel( + ai: AI, + requestOptions?: RequestOptions +): TemplateImagenModel { + return new TemplateImagenModel(ai, requestOptions); +} diff --git a/packages/ai/src/backend.test.ts b/packages/ai/src/backend.test.ts index 0c6609277e3..46d6507a499 100644 --- a/packages/ai/src/backend.test.ts +++ b/packages/ai/src/backend.test.ts @@ -18,7 +18,7 @@ import { expect } from 'chai'; import { GoogleAIBackend, VertexAIBackend } from './backend'; import { BackendType } from './public-types'; -import { DEFAULT_LOCATION } from './constants'; +import { DEFAULT_API_VERSION, DEFAULT_LOCATION } from './constants'; describe('Backend', () => { describe('GoogleAIBackend', () => { @@ -26,6 +26,18 @@ describe('Backend', () => { const backend = new GoogleAIBackend(); expect(backend.backendType).to.equal(BackendType.GOOGLE_AI); }); + it('getModelPath', () => { + const backend = new GoogleAIBackend(); + expect(backend._getModelPath('my-project', 'model-name')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/model-name` + ); + }); + it('getTemplatePath', () => { + const backend = new GoogleAIBackend(); + expect(backend._getTemplatePath('my-project', 'template-id')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/templates/template-id` + 
); + }); }); describe('VertexAIBackend', () => { it('set backendType to VERTEX_AI', () => { @@ -48,5 +60,17 @@ describe('Backend', () => { expect(backend.backendType).to.equal(BackendType.VERTEX_AI); expect(backend.location).to.equal(DEFAULT_LOCATION); }); + it('getModelPath', () => { + const backend = new VertexAIBackend(); + expect(backend._getModelPath('my-project', 'model-name')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/locations/${backend.location}/model-name` + ); + }); + it('getTemplatePath', () => { + const backend = new VertexAIBackend(); + expect(backend._getTemplatePath('my-project', 'template-id')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/locations/${backend.location}/templates/template-id` + ); + }); }); }); diff --git a/packages/ai/src/backend.ts b/packages/ai/src/backend.ts index 7209828122b..2eaec59448f 100644 --- a/packages/ai/src/backend.ts +++ b/packages/ai/src/backend.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { DEFAULT_LOCATION } from './constants'; +import { DEFAULT_API_VERSION, DEFAULT_LOCATION } from './constants'; import { BackendType } from './public-types'; /** @@ -39,6 +39,16 @@ export abstract class Backend { protected constructor(type: BackendType) { this.backendType = type; } + + /** + * @internal + */ + abstract _getModelPath(project: string, model: string): string; + + /** + * @internal + */ + abstract _getTemplatePath(project: string, templateId: string): string; } /** @@ -56,6 +66,20 @@ export class GoogleAIBackend extends Backend { constructor() { super(BackendType.GOOGLE_AI); } + + /** + * @internal + */ + _getModelPath(project: string, model: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/${model}`; + } + + /** + * @internal + */ + _getTemplatePath(project: string, templateId: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/templates/${templateId}`; + } } /** @@ -89,4 +113,18 @@ export class VertexAIBackend extends Backend { 
this.location = location; } } + + /** + * @internal + */ + _getModelPath(project: string, model: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/${model}`; + } + + /** + * @internal + */ + _getTemplatePath(project: string, templateId: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/templates/${templateId}`; + } } diff --git a/packages/ai/src/constants.ts b/packages/ai/src/constants.ts index 82482527f3b..0a6f7e91436 100644 --- a/packages/ai/src/constants.ts +++ b/packages/ai/src/constants.ts @@ -23,6 +23,9 @@ export const DEFAULT_LOCATION = 'us-central1'; export const DEFAULT_DOMAIN = 'firebasevertexai.googleapis.com'; +export const STAGING_URL = + 'https://staging-firebasevertexai.sandbox.googleapis.com'; + export const DEFAULT_API_VERSION = 'v1beta'; export const PACKAGE_VERSION = version; diff --git a/packages/ai/src/methods/count-tokens.test.ts b/packages/ai/src/methods/count-tokens.test.ts index 84976d00ac9..b3ed7f7fa4d 100644 --- a/packages/ai/src/methods/count-tokens.test.ts +++ b/packages/ai/src/methods/count-tokens.test.ts @@ -71,16 +71,17 @@ describe('countTokens()', () => { fakeChromeAdapter ); expect(result.totalTokens).to.equal(6); - expect(result.totalBillableCharacters).to.equal(16); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('contents'); - }), - undefined + }) ); }); it('total tokens with modality details', async () => { @@ -102,14 +103,16 @@ describe('countTokens()', () => { expect(result.promptTokensDetails?.[0].modality).to.equal('IMAGE'); expect(result.promptTokensDetails?.[0].tokenCount).to.equal(1806); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + { + 
model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('contents'); - }), - undefined + }) ); }); it('total tokens no billable characters', async () => { @@ -129,14 +132,16 @@ describe('countTokens()', () => { expect(result.totalTokens).to.equal(258); expect(result).to.not.have.property('totalBillableCharacters'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('contents'); - }), - undefined + }) ); }); it('model not found', async () => { @@ -181,12 +186,14 @@ describe('countTokens()', () => { ); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeGoogleAIApiSettings, - false, - JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')), - undefined + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeGoogleAIApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')) ); }); }); diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts index c6041a0bb99..20c633ee703 100644 --- a/packages/ai/src/methods/count-tokens.ts +++ b/packages/ai/src/methods/count-tokens.ts @@ -23,7 +23,7 @@ import { RequestOptions, AIErrorCode } from '../types'; -import { Task, makeRequest } from '../requests/request'; +import { makeRequest, Task } from '../requests/request'; import { ApiSettings } from '../types/internal'; import * as GoogleAIMapper from '../googleai-mappers'; import { BackendType } from '../public-types'; @@ -43,12 +43,14 @@ export async function countTokensOnCloud( body = JSON.stringify(params); } const response = await makeRequest( - model, - Task.COUNT_TOKENS, - 
apiSettings, - false, - body, - requestOptions + { + model, + task: Task.COUNT_TOKENS, + apiSettings, + stream: false, + requestOptions + }, + body ); return response.json(); } diff --git a/packages/ai/src/methods/generate-content.test.ts b/packages/ai/src/methods/generate-content.test.ts index 33a9ae5f5e3..8a274c24417 100644 --- a/packages/ai/src/methods/generate-content.test.ts +++ b/packages/ai/src/methods/generate-content.test.ts @@ -19,9 +19,16 @@ import { expect, use } from 'chai'; import Sinon, { match, restore, stub } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { getMockResponse } from '../../test-utils/mock-response'; +import { + getMockResponse, + getMockResponseStreaming +} from '../../test-utils/mock-response'; import * as request from '../requests/request'; -import { generateContent } from './generate-content'; +import { + generateContent, + templateGenerateContent, + templateGenerateContentStream +} from './generate-content'; import { AIErrorCode, GenerateContentRequest, @@ -103,12 +110,14 @@ describe('generateContent()', () => { ); expect(result.response.text()).to.include('Mountain View, California'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - JSON.stringify(fakeRequestParams), - undefined + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('long response', async () => { @@ -127,11 +136,14 @@ describe('generateContent()', () => { expect(result.response.text()).to.include('Use Freshly Ground Coffee'); expect(result.response.text()).to.include('30 minutes of brewing'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + 
requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('long response with token details', async () => { @@ -162,11 +174,14 @@ describe('generateContent()', () => { result.response.usageMetadata?.candidatesTokensDetails?.[0].tokenCount ).to.equal(76); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('citations', async () => { @@ -189,11 +204,14 @@ describe('generateContent()', () => { result.response.candidates?.[0].citationMetadata?.citations.length ).to.equal(3); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('google search grounding', async () => { @@ -236,11 +254,14 @@ describe('generateContent()', () => { .undefined; expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); it('url context', async () => { @@ -286,10 +307,12 @@ describe('generateContent()', () => { .be.undefined; expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, match.any ); }); @@ -328,11 +351,14 @@ describe('generateContent()', () => { ); expect(result.response.text).to.throw('SAFETY'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, 
- match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('finishReason safety', async () => { @@ -350,11 +376,14 @@ describe('generateContent()', () => { ); expect(result.response.text).to.throw('SAFETY'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('empty content', async () => { @@ -372,11 +401,14 @@ describe('generateContent()', () => { ); expect(result.response.text()).to.equal(''); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('empty part', async () => { @@ -410,11 +442,14 @@ describe('generateContent()', () => { ); expect(result.response.text()).to.include('Some text'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('image rejected (400)', async () => { @@ -502,12 +537,14 @@ describe('generateContent()', () => { ); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeGoogleAIApiSettings, - false, - JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)), - undefined + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeGoogleAIApiSettings, + stream: false, + requestOptions: match.any + }, + 
JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)) ); }); }); @@ -533,3 +570,83 @@ describe('generateContent()', () => { expect(generateContentStub).to.be.calledWith(fakeRequestParams); }); }); + +describe('templateGenerateContent', () => { + afterEach(() => { + restore(); + }); + it('should call makeRequest with correct parameters and process the response', async () => { + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const templateId = 'my-template'; + const templateParams = { name: 'world' }; + const requestOptions = { timeout: 5000 }; + + const result = await templateGenerateContent( + fakeApiSettings, + templateId, + templateParams, + requestOptions + ); + + expect(makeRequestStub).to.have.been.calledOnceWith( + { + task: 'templateGenerateContent', + templateId, + apiSettings: fakeApiSettings, + stream: false, + requestOptions + }, + JSON.stringify(templateParams) + ); + expect(result.response.text()).to.include('Mountain View, California'); + }); +}); + +describe('templateGenerateContentStream', () => { + afterEach(() => { + restore(); + }); + it('should call makeRequest with correct parameters for streaming', async () => { + const mockResponse = getMockResponseStreaming( + 'vertexAI', + 'streaming-success-basic-reply-short.txt' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const templateId = 'my-stream-template'; + const templateParams = { name: 'streaming world' }; + const requestOptions = { timeout: 10000 }; + + const result = await templateGenerateContentStream( + fakeApiSettings, + templateId, + templateParams, + requestOptions + ); + + expect(makeRequestStub).to.have.been.calledOnceWith( + { + task: 'templateStreamGenerateContent', + templateId, + apiSettings: fakeApiSettings, + stream: true, + requestOptions + }, + 
JSON.stringify(templateParams) + ); + + // Verify the stream processing part + for await (const item of result.stream) { + expect(item.text()).to.not.be.empty; + } + const response = await result.response; + expect(response.text()).to.include('Cheyenne'); + }); +}); diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts index a2fb29e20d1..fc6eac15c74 100644 --- a/packages/ai/src/methods/generate-content.ts +++ b/packages/ai/src/methods/generate-content.ts @@ -22,7 +22,11 @@ import { GenerateContentStreamResult, RequestOptions } from '../types'; -import { Task, makeRequest } from '../requests/request'; +import { + makeRequest, + ServerPromptTemplateTask, + Task +} from '../requests/request'; import { createEnhancedContentResponse } from '../requests/response-helpers'; import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; @@ -41,12 +45,14 @@ async function generateContentStreamOnCloud( params = GoogleAIMapper.mapGenerateContentRequest(params); } return makeRequest( - model, - Task.STREAM_GENERATE_CONTENT, - apiSettings, - /* stream */ true, - JSON.stringify(params), - requestOptions + { + task: Task.STREAM_GENERATE_CONTENT, + model, + apiSettings, + stream: true, + requestOptions + }, + JSON.stringify(params) ); } @@ -77,13 +83,62 @@ async function generateContentOnCloud( params = GoogleAIMapper.mapGenerateContentRequest(params); } return makeRequest( - model, - Task.GENERATE_CONTENT, - apiSettings, - /* stream */ false, - JSON.stringify(params), - requestOptions + { + model, + task: Task.GENERATE_CONTENT, + apiSettings, + stream: false, + requestOptions + }, + JSON.stringify(params) + ); +} + +export async function templateGenerateContent( + apiSettings: ApiSettings, + templateId: string, + templateParams: object, + requestOptions?: RequestOptions +): Promise { + const response = await makeRequest( + { + task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT, + 
templateId, + apiSettings, + stream: false, + requestOptions + }, + JSON.stringify(templateParams) + ); + const generateContentResponse = await processGenerateContentResponse( + response, + apiSettings + ); + const enhancedResponse = createEnhancedContentResponse( + generateContentResponse + ); + return { + response: enhancedResponse + }; +} + +export async function templateGenerateContentStream( + apiSettings: ApiSettings, + templateId: string, + templateParams: object, + requestOptions?: RequestOptions +): Promise { + const response = await makeRequest( + { + task: ServerPromptTemplateTask.TEMPLATE_STREAM_GENERATE_CONTENT, + templateId, + apiSettings, + stream: true, + requestOptions + }, + JSON.stringify(templateParams) ); + return processStream(response, apiSettings); } export async function generateContent( diff --git a/packages/ai/src/models/ai-model.test.ts b/packages/ai/src/models/ai-model.test.ts index 2e8f8998c58..4786adc8546 100644 --- a/packages/ai/src/models/ai-model.test.ts +++ b/packages/ai/src/models/ai-model.test.ts @@ -15,13 +15,10 @@ * limitations under the License. 
*/ import { use, expect } from 'chai'; -import { AI, AIErrorCode } from '../public-types'; +import { AI } from '../public-types'; import sinonChai from 'sinon-chai'; -import { stub } from 'sinon'; import { AIModel } from './ai-model'; -import { AIError } from '../errors'; import { VertexAIBackend } from '../backend'; -import { AIService } from '../service'; use(sinonChai); @@ -69,105 +66,4 @@ describe('AIModel', () => { const testModel = new TestModel(fakeAI, 'tunedModels/my-model'); expect(testModel.model).to.equal('tunedModels/my-model'); }); - it('calls regular app check token when option is set', async () => { - const getTokenStub = stub().resolves(); - const getLimitedUseTokenStub = stub().resolves(); - const testModel = new TestModel( - //@ts-ignore - { - ...fakeAI, - options: { useLimitedUseAppCheckTokens: false }, - appCheck: { - getToken: getTokenStub, - getLimitedUseToken: getLimitedUseTokenStub - } - } as AIService, - 'models/my-model' - ); - if (testModel._apiSettings?.getAppCheckToken) { - await testModel._apiSettings.getAppCheckToken(); - } - expect(getTokenStub).to.be.called; - expect(getLimitedUseTokenStub).to.not.be.called; - getTokenStub.reset(); - getLimitedUseTokenStub.reset(); - }); - it('calls limited use token when option is set', async () => { - const getTokenStub = stub().resolves(); - const getLimitedUseTokenStub = stub().resolves(); - const testModel = new TestModel( - //@ts-ignore - { - ...fakeAI, - options: { useLimitedUseAppCheckTokens: true }, - appCheck: { - getToken: getTokenStub, - getLimitedUseToken: getLimitedUseTokenStub - } - } as AIService, - 'models/my-model' - ); - if (testModel._apiSettings?.getAppCheckToken) { - await testModel._apiSettings.getAppCheckToken(); - } - expect(getTokenStub).to.not.be.called; - expect(getLimitedUseTokenStub).to.be.called; - getTokenStub.reset(); - getLimitedUseTokenStub.reset(); - }); - it('throws if not passed an api key', () => { - const fakeAI: AI = { - app: { - name: 'DEFAULT', - 
automaticDataCollectionEnabled: true, - options: { - projectId: 'my-project' - } - }, - backend: new VertexAIBackend('us-central1'), - location: 'us-central1' - }; - try { - new TestModel(fakeAI, 'my-model'); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.NO_API_KEY); - } - }); - it('throws if not passed a project ID', () => { - const fakeAI: AI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - apiKey: 'key' - } - }, - backend: new VertexAIBackend('us-central1'), - location: 'us-central1' - }; - try { - new TestModel(fakeAI, 'my-model'); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.NO_PROJECT_ID); - } - }); - it('throws if not passed an app ID', () => { - const fakeAI: AI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - apiKey: 'key', - projectId: 'my-project' - } - }, - backend: new VertexAIBackend('us-central1'), - location: 'us-central1' - }; - try { - new TestModel(fakeAI, 'my-model'); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.NO_APP_ID); - } - }); }); diff --git a/packages/ai/src/models/ai-model.ts b/packages/ai/src/models/ai-model.ts index 3fe202d5eb2..e2bc70319d8 100644 --- a/packages/ai/src/models/ai-model.ts +++ b/packages/ai/src/models/ai-model.ts @@ -15,11 +15,9 @@ * limitations under the License. */ -import { AIError } from '../errors'; -import { AIErrorCode, AI, BackendType } from '../public-types'; -import { AIService } from '../service'; +import { AI, BackendType } from '../public-types'; import { ApiSettings } from '../types/internal'; -import { _isFirebaseServerApp } from '@firebase/app'; +import { initApiSettings } from './utils'; /** * Base class for Firebase AI model APIs. 
@@ -59,56 +57,11 @@ export abstract class AIModel { * @internal */ protected constructor(ai: AI, modelName: string) { - if (!ai.app?.options?.apiKey) { - throw new AIError( - AIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.` - ); - } else if (!ai.app?.options?.projectId) { - throw new AIError( - AIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.` - ); - } else if (!ai.app?.options?.appId) { - throw new AIError( - AIErrorCode.NO_APP_ID, - `The "appId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid app ID.` - ); - } else { - this._apiSettings = { - apiKey: ai.app.options.apiKey, - project: ai.app.options.projectId, - appId: ai.app.options.appId, - automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, - location: ai.location, - backend: ai.backend - }; - - if (_isFirebaseServerApp(ai.app) && ai.app.settings.appCheckToken) { - const token = ai.app.settings.appCheckToken; - this._apiSettings.getAppCheckToken = () => { - return Promise.resolve({ token }); - }; - } else if ((ai as AIService).appCheck) { - if (ai.options?.useLimitedUseAppCheckTokens) { - this._apiSettings.getAppCheckToken = () => - (ai as AIService).appCheck!.getLimitedUseToken(); - } else { - this._apiSettings.getAppCheckToken = () => - (ai as AIService).appCheck!.getToken(); - } - } - - if ((ai as AIService).auth) { - this._apiSettings.getAuthToken = () => - (ai as AIService).auth!.getToken(); - } - - this.model = AIModel.normalizeModelName( - modelName, - this._apiSettings.backend.backendType - ); - } + this._apiSettings = initApiSettings(ai); + this.model = AIModel.normalizeModelName( + modelName, + this._apiSettings.backend.backendType + ); } /** diff --git a/packages/ai/src/models/generative-model.test.ts 
b/packages/ai/src/models/generative-model.test.ts index 90399e6811b..45430cb5f59 100644 --- a/packages/ai/src/models/generative-model.test.ts +++ b/packages/ai/src/models/generative-model.test.ts @@ -92,10 +92,13 @@ describe('GenerativeModel', () => { ); await genModel.generateContent('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('myfunc') && @@ -104,8 +107,7 @@ describe('GenerativeModel', () => { value.includes(FunctionCallingMode.NONE) && value.includes('be friendly') ); - }), - {} + }) ); restore(); }); @@ -129,14 +131,16 @@ describe('GenerativeModel', () => { ); await genModel.generateContent('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return value.includes('be friendly'); - }), - {} + }) ); restore(); }); @@ -190,10 +194,13 @@ describe('GenerativeModel', () => { systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] } }); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('otherfunc') && @@ -202,8 +209,7 @@ describe('GenerativeModel', () => { value.includes(FunctionCallingMode.AUTO) && value.includes('be formal') ); - }), - {} + }) ); restore(); }); @@ -281,10 +287,13 @@ 
describe('GenerativeModel', () => { ); await genModel.startChat().sendMessage('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('myfunc') && @@ -294,8 +303,7 @@ describe('GenerativeModel', () => { value.includes('be friendly') && value.includes('topK') ); - }), - {} + }) ); restore(); }); @@ -319,14 +327,16 @@ describe('GenerativeModel', () => { ); await genModel.startChat().sendMessage('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return value.includes('be friendly'); - }), - {} + }) ); restore(); }); @@ -382,10 +392,13 @@ describe('GenerativeModel', () => { }) .sendMessage('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('otherfunc') && @@ -396,8 +409,7 @@ describe('GenerativeModel', () => { value.includes('image/png') && !value.includes('image/jpeg') ); - }), - {} + }) ); restore(); }); @@ -417,10 +429,13 @@ describe('GenerativeModel', () => { ); await genModel.countTokens('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.COUNT_TOKENS, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: 
request.Task.COUNT_TOKENS, + apiSettings: match.any, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('hello'); }) diff --git a/packages/ai/src/models/imagen-model.test.ts b/packages/ai/src/models/imagen-model.test.ts index f4121e18f2d..68b6caca098 100644 --- a/packages/ai/src/models/imagen-model.test.ts +++ b/packages/ai/src/models/imagen-model.test.ts @@ -62,17 +62,19 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + requestOptions: undefined + }, match((value: string) => { return ( value.includes(`"prompt":"${prompt}"`) && value.includes(`"sampleCount":1`) ); - }), - undefined + }) ); restore(); }); @@ -102,10 +104,13 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + requestOptions: undefined + }, match((value: string) => { return ( value.includes( @@ -130,8 +135,7 @@ describe('ImagenModel', () => { JSON.stringify(imagenModel.safetySettings?.personFilterLevel) ) ); - }), - undefined + }) ); restore(); }); diff --git a/packages/ai/src/models/imagen-model.ts b/packages/ai/src/models/imagen-model.ts index a41a03f25cf..567333ee64f 100644 --- a/packages/ai/src/models/imagen-model.ts +++ b/packages/ai/src/models/imagen-model.ts @@ -16,7 +16,7 @@ */ import { AI } from '../public-types'; -import { Task, makeRequest } from '../requests/request'; 
+import { makeRequest, Task } from '../requests/request'; import { createPredictRequestBody } from '../requests/request-helpers'; import { handlePredictResponse } from '../requests/response-helpers'; import { @@ -109,12 +109,14 @@ export class ImagenModel extends AIModel { ...this.safetySettings }); const response = await makeRequest( - this.model, - Task.PREDICT, - this._apiSettings, - /* stream */ false, - JSON.stringify(body), - this.requestOptions + { + task: Task.PREDICT, + model: this.model, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions + }, + JSON.stringify(body) ); return handlePredictResponse(response); } @@ -148,12 +150,14 @@ export class ImagenModel extends AIModel { ...this.safetySettings }); const response = await makeRequest( - this.model, - Task.PREDICT, - this._apiSettings, - /* stream */ false, - JSON.stringify(body), - this.requestOptions + { + task: Task.PREDICT, + model: this.model, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions + }, + JSON.stringify(body) ); return handlePredictResponse(response); } diff --git a/packages/ai/src/models/template-generative-model.test.ts b/packages/ai/src/models/template-generative-model.test.ts new file mode 100644 index 00000000000..c3eb43af491 --- /dev/null +++ b/packages/ai/src/models/template-generative-model.test.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { use, expect } from 'chai'; +import sinonChai from 'sinon-chai'; +import { restore, stub } from 'sinon'; +import { AI } from '../public-types'; +import { VertexAIBackend } from '../backend'; +import { TemplateGenerativeModel } from './template-generative-model'; +import * as generateContentMethods from '../methods/generate-content'; + +use(sinonChai); + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' +}; + +const TEMPLATE_ID = 'my-template'; +const TEMPLATE_VARS = { a: 1, b: '2' }; + +describe('TemplateGenerativeModel', () => { + afterEach(() => { + restore(); + }); + + describe('constructor', () => { + it('should initialize _apiSettings correctly', () => { + const model = new TemplateGenerativeModel(fakeAI); + expect(model._apiSettings.apiKey).to.equal('key'); + expect(model._apiSettings.project).to.equal('my-project'); + expect(model._apiSettings.appId).to.equal('my-appid'); + }); + }); + + describe('generateContent', () => { + it('should call templateGenerateContent with correct parameters', async () => { + const templateGenerateContentStub = stub( + generateContentMethods, + 'templateGenerateContent' + ).resolves({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + + await model.generateContent(TEMPLATE_ID, TEMPLATE_VARS); + + expect(templateGenerateContentStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 5000 } + ); + }); + }); + + describe('generateContentStream', () => { + it('should call templateGenerateContentStream with correct parameters', async () => { + const templateGenerateContentStreamStub = stub( + generateContentMethods, + 
'templateGenerateContentStream' + ).resolves({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + + await model.generateContentStream(TEMPLATE_ID, TEMPLATE_VARS); + + expect(templateGenerateContentStreamStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 5000 } + ); + }); + }); +}); diff --git a/packages/ai/src/models/template-generative-model.ts b/packages/ai/src/models/template-generative-model.ts new file mode 100644 index 00000000000..ec9e653618d --- /dev/null +++ b/packages/ai/src/models/template-generative-model.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + templateGenerateContent, + templateGenerateContentStream +} from '../methods/generate-content'; +import { GenerateContentResult, RequestOptions } from '../types'; +import { AI, GenerateContentStreamResult } from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { initApiSettings } from './utils'; + +/** + * {@link GenerativeModel} APIs that execute on a server-side template. + * + * This class should only be instantiated with {@link getTemplateGenerativeModel}. + * + * @beta + */ +export class TemplateGenerativeModel { + /** + * @internal + */ + _apiSettings: ApiSettings; + + /** + * Additional options to use when making requests. 
+ */ + requestOptions?: RequestOptions; + + /** + * @hideconstructor + */ + constructor(ai: AI, requestOptions?: RequestOptions) { + this.requestOptions = requestOptions || {}; + this._apiSettings = initApiSettings(ai); + } + + /** + * Makes a single non-streaming call to the model and returns an object + * containing a single {@link GenerateContentResponse}. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateContent( + templateId: string, + templateVariables: object // anything! + ): Promise { + return templateGenerateContent( + this._apiSettings, + templateId, + { inputs: templateVariables }, + this.requestOptions + ); + } + + /** + * Makes a single streaming call to the model and returns an object + * containing an iterable stream that iterates over all chunks in the + * streaming response as well as a promise that returns the final aggregated + * response. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateContentStream( + templateId: string, + templateVariables: object + ): Promise { + return templateGenerateContentStream( + this._apiSettings, + templateId, + { inputs: templateVariables }, + this.requestOptions + ); + } +} diff --git a/packages/ai/src/models/template-imagen-model.test.ts b/packages/ai/src/models/template-imagen-model.test.ts new file mode 100644 index 00000000000..c053753ea0f --- /dev/null +++ b/packages/ai/src/models/template-imagen-model.test.ts @@ -0,0 +1,139 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { use, expect } from 'chai'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { restore, stub } from 'sinon'; +import { AI } from '../public-types'; +import { VertexAIBackend } from '../backend'; +import { TemplateImagenModel } from './template-imagen-model'; +import { AIError } from '../errors'; +import * as request from '../requests/request'; + +use(sinonChai); +use(chaiAsPromised); + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' +}; + +const TEMPLATE_ID = 'my-imagen-template'; +const TEMPLATE_VARS = { a: 1, b: '2' }; + +describe('TemplateImagenModel', () => { + afterEach(() => { + restore(); + }); + + describe('constructor', () => { + it('should initialize _apiSettings correctly', () => { + const model = new TemplateImagenModel(fakeAI); + expect(model._apiSettings.apiKey).to.equal('key'); + expect(model._apiSettings.project).to.equal('my-project'); + expect(model._apiSettings.appId).to.equal('my-appid'); + }); + }); + + describe('generateImages', () => { + it('should call makeRequest with correct parameters', async () => { + const makeRequestStub = stub(request, 'makeRequest').resolves({ + json: () => + Promise.resolve({ + predictions: [ + { + bytesBase64Encoded: + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + mimeType: 'image/png' + 
} + ] + }) + } as Response); + const model = new TemplateImagenModel(fakeAI, { timeout: 5000 }); + + await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(makeRequestStub).to.have.been.calledOnceWith( + { + task: 'templatePredict', + templateId: TEMPLATE_ID, + apiSettings: model._apiSettings, + stream: false, + requestOptions: { timeout: 5000 } + }, + JSON.stringify({ inputs: TEMPLATE_VARS }) + ); + }); + + it('should return the result of handlePredictResponse', async () => { + const mockPrediction = { + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + 'mimeType': 'image/png' + }; + stub(request, 'makeRequest').resolves({ + json: () => Promise.resolve({ predictions: [mockPrediction] }) + } as Response); + + const model = new TemplateImagenModel(fakeAI); + const result = await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(result.images).to.deep.equal([mockPrediction]); + }); + + it('should throw an AIError if the prompt is blocked', async () => { + const error = new AIError('fetch-error', 'Request failed'); + stub(request, 'makeRequest').rejects(error); + + const model = new TemplateImagenModel(fakeAI); + await expect( + model.generateImages(TEMPLATE_ID, TEMPLATE_VARS) + ).to.be.rejectedWith(error); + }); + + it('should handle responses with filtered images', async () => { + const mockPrediction = { + bytesBase64Encoded: 'iVBOR...ggg==', + mimeType: 'image/png' + }; + const filteredReason = 'This image was filtered for safety reasons.'; + stub(request, 'makeRequest').resolves({ + json: () => + Promise.resolve({ + predictions: [mockPrediction, { raiFilteredReason: filteredReason }] + }) + } as Response); + + const model = new TemplateImagenModel(fakeAI); + const result = await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(result.images).to.have.lengthOf(1); + expect(result.images[0]).to.deep.equal(mockPrediction); + 
expect(result.filteredReason).to.equal(filteredReason); + }); + }); +}); diff --git a/packages/ai/src/models/template-imagen-model.ts b/packages/ai/src/models/template-imagen-model.ts new file mode 100644 index 00000000000..34325c711b3 --- /dev/null +++ b/packages/ai/src/models/template-imagen-model.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { RequestOptions } from '../types'; +import { + AI, + ImagenGenerationResponse, + ImagenInlineImage +} from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { makeRequest, ServerPromptTemplateTask } from '../requests/request'; +import { handlePredictResponse } from '../requests/response-helpers'; +import { initApiSettings } from './utils'; + +/** + * Class for Imagen model APIs that execute on a server-side template. + * + * This class should only be instantiated with {@link getTemplateImagenModel}. + * + * @beta + */ +export class TemplateImagenModel { + /** + * @internal + */ + _apiSettings: ApiSettings; + + /** + * Additional options to use when making requests. 
+ */ + requestOptions?: RequestOptions; + + /** + * @hideconstructor + */ + constructor(ai: AI, requestOptions?: RequestOptions) { + this.requestOptions = requestOptions || {}; + this._apiSettings = initApiSettings(ai); + } + + /** + * Makes a single call to the model and returns an object containing a single + * {@link ImagenGenerationResponse}. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateImages( + templateId: string, + templateVariables: object + ): Promise> { + const response = await makeRequest( + { + task: ServerPromptTemplateTask.TEMPLATE_PREDICT, + templateId, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions + }, + JSON.stringify({ inputs: templateVariables }) + ); + return handlePredictResponse(response); + } +} diff --git a/packages/ai/src/models/utils.test.ts b/packages/ai/src/models/utils.test.ts new file mode 100644 index 00000000000..42d19007275 --- /dev/null +++ b/packages/ai/src/models/utils.test.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import { use, expect } from 'chai'; +import { AI, AIErrorCode } from '../public-types'; +import sinonChai from 'sinon-chai'; +import { stub } from 'sinon'; +import { AIError } from '../errors'; +import { VertexAIBackend } from '../backend'; +import { AIService } from '../service'; +import { initApiSettings } from './utils'; + +use(sinonChai); + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' +}; + +describe('initApiSettings', () => { + it('calls regular app check token when option is set', async () => { + const getTokenStub = stub().resolves(); + const getLimitedUseTokenStub = stub().resolves(); + const apiSettings = initApiSettings( + //@ts-ignore + { + ...fakeAI, + options: { useLimitedUseAppCheckTokens: false }, + appCheck: { + getToken: getTokenStub, + getLimitedUseToken: getLimitedUseTokenStub + } + } as AIService + ); + if (apiSettings?.getAppCheckToken) { + await apiSettings.getAppCheckToken(); + } + expect(getTokenStub).to.be.called; + expect(getLimitedUseTokenStub).to.not.be.called; + getTokenStub.reset(); + getLimitedUseTokenStub.reset(); + }); + it('calls limited use token when option is set', async () => { + const getTokenStub = stub().resolves(); + const getLimitedUseTokenStub = stub().resolves(); + const apiSettings = initApiSettings( + //@ts-ignore + { + ...fakeAI, + options: { useLimitedUseAppCheckTokens: true }, + appCheck: { + getToken: getTokenStub, + getLimitedUseToken: getLimitedUseTokenStub + } + } as AIService + ); + if (apiSettings?.getAppCheckToken) { + await apiSettings.getAppCheckToken(); + } + expect(getTokenStub).to.not.be.called; + expect(getLimitedUseTokenStub).to.be.called; + getTokenStub.reset(); + getLimitedUseTokenStub.reset(); + }); + it('throws if not passed an api key', () => { + const fakeAI: AI = { + app: { + name: 
'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + projectId: 'my-project' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' + }; + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.NO_API_KEY); + } + }); + it('throws if not passed a project ID', () => { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' + }; + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.NO_PROJECT_ID); + } + }); + it('throws if not passed an app ID', () => { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' + }; + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.NO_APP_ID); + } + }); +}); diff --git a/packages/ai/src/models/utils.ts b/packages/ai/src/models/utils.ts new file mode 100644 index 00000000000..035ed3f734d --- /dev/null +++ b/packages/ai/src/models/utils.ts @@ -0,0 +1,78 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { _isFirebaseServerApp } from '@firebase/app'; +import { AIError } from '../errors'; +import { AI, AIErrorCode } from '../public-types'; +import { AIService } from '../service'; +import { ApiSettings } from '../types/internal'; + +/** + * Initializes an {@link ApiSettings} object from an {@link AI} instance. + * + * If this is a Server App, the {@link ApiSettings} object's `getAppCheckToken()` will resolve + * with the `FirebaseServerAppSettings.appCheckToken`, instead of requiring that an App Check + * instance is initialized. + */ +export function initApiSettings(ai: AI): ApiSettings { + if (!ai.app?.options?.apiKey) { + throw new AIError( + AIErrorCode.NO_API_KEY, + `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.` + ); + } else if (!ai.app?.options?.projectId) { + throw new AIError( + AIErrorCode.NO_PROJECT_ID, + `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.` + ); + } else if (!ai.app?.options?.appId) { + throw new AIError( + AIErrorCode.NO_APP_ID, + `The "appId" field is empty in the local Firebase config. 
Firebase AI requires this field to contain a valid app ID.` + ); + } + + const apiSettings: ApiSettings = { + apiKey: ai.app.options.apiKey, + project: ai.app.options.projectId, + appId: ai.app.options.appId, + automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, + location: ai.location, + backend: ai.backend + }; + + if (_isFirebaseServerApp(ai.app) && ai.app.settings.appCheckToken) { + const token = ai.app.settings.appCheckToken; + apiSettings.getAppCheckToken = () => { + return Promise.resolve({ token }); + }; + } else if ((ai as AIService).appCheck) { + if (ai.options?.useLimitedUseAppCheckTokens) { + apiSettings.getAppCheckToken = () => + (ai as AIService).appCheck!.getLimitedUseToken(); + } else { + apiSettings.getAppCheckToken = () => + (ai as AIService).appCheck!.getToken(); + } + } + + if ((ai as AIService).auth) { + apiSettings.getAuthToken = () => (ai as AIService).auth!.getToken(); + } + + return apiSettings; +} diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts index 0d162906fdc..a54ff521bea 100644 --- a/packages/ai/src/requests/request.test.ts +++ b/packages/ai/src/requests/request.test.ts @@ -19,7 +19,13 @@ import { expect, use } from 'chai'; import { match, restore, stub } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { RequestUrl, Task, getHeaders, makeRequest } from './request'; +import { + RequestURL, + ServerPromptTemplateTask, + Task, + getHeaders, + makeRequest +} from './request'; import { ApiSettings } from '../types/internal'; import { DEFAULT_API_VERSION } from '../constants'; import { AIErrorCode } from '../types'; @@ -42,65 +48,77 @@ describe('request methods', () => { afterEach(() => { restore(); }); - describe('RequestUrl', () => { + describe('RequestURL', () => { it('stream', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const url 
= new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); expect(url.toString()).to.include('models/model-name:generateContent'); - expect(url.toString()).to.not.include(fakeApiSettings); expect(url.toString()).to.include('alt=sse'); }); it('non-stream', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {} - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: {} + }); expect(url.toString()).to.include('models/model-name:generateContent'); expect(url.toString()).to.not.include(fakeApiSettings); expect(url.toString()).to.not.include('alt=sse'); }); it('default apiVersion', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {} - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: {} + }); expect(url.toString()).to.include(DEFAULT_API_VERSION); }); it('custom baseUrl', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - { baseUrl: 'https://my.special.endpoint' } - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { baseUrl: 'https://my.special.endpoint' } + }); expect(url.toString()).to.include('https://my.special.endpoint'); }); it('non-stream - tunedModels/', async () => { - const url = new RequestUrl( - 'tunedModels/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {} - ); + const url = new RequestURL({ + model: 'tunedModels/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + 
requestOptions: {} + }); expect(url.toString()).to.include( 'tunedModels/model-name:generateContent' ); expect(url.toString()).to.not.include(fakeApiSettings); expect(url.toString()).to.not.include('alt=sse'); }); + it('prompt server template', async () => { + const url = new RequestURL({ + templateId: 'my-template', + task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: {} + }); + expect(url.toString()).to.include( + 'templates/my-template:templateGenerateContent' + ); + expect(url.toString()).to.not.include(fakeApiSettings); + }); }); describe('getHeaders', () => { const fakeApiSettings: ApiSettings = { @@ -112,13 +130,13 @@ describe('request methods', () => { getAuthToken: () => Promise.resolve({ accessToken: 'authtoken' }), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }) }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); it('adds client headers', async () => { const headers = await getHeaders(fakeUrl); expect(headers.get('x-goog-api-client')).to.match( @@ -140,13 +158,13 @@ describe('request methods', () => { getAuthToken: () => Promise.resolve({ accessToken: 'authtoken' }), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }) }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-Appid')).to.equal('my-appid'); }); @@ -165,13 +183,13 @@ describe('request methods', () => { getAuthToken: () => Promise.resolve({ 
accessToken: 'authtoken' }), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }) }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-Appid')).to.be.null; }); @@ -180,44 +198,44 @@ describe('request methods', () => { expect(headers.get('X-Firebase-AppCheck')).to.equal('appchecktoken'); }); it('ignores app check token header if no appcheck service', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', appId: 'my-appid', location: 'moon', backend: new VertexAIBackend() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('X-Firebase-AppCheck')).to.be.false; }); it('ignores app check token header if returned token was undefined', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', location: 'moon', //@ts-ignore getAppCheckToken: () => Promise.resolve() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('X-Firebase-AppCheck')).to.be.false; }); it('ignores app check token header if returned token had error', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + 
apiSettings: { apiKey: 'key', project: 'myproject', appId: 'my-appid', @@ -226,9 +244,9 @@ describe('request methods', () => { getAppCheckToken: () => Promise.resolve({ token: 'dummytoken', error: Error('oops') }) }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const warnStub = stub(console, 'warn'); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-AppCheck')).to.equal('dummytoken'); @@ -242,36 +260,36 @@ describe('request methods', () => { expect(headers.get('Authorization')).to.equal('Firebase authtoken'); }); it('ignores auth token header if no auth service', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', appId: 'my-appid', location: 'moon', backend: new VertexAIBackend() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('Authorization')).to.be.false; }); it('ignores auth token header if returned token was undefined', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', location: 'moon', //@ts-ignore getAppCheckToken: () => Promise.resolve() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('Authorization')).to.be.false; }); @@ -282,10 +300,12 @@ describe('request methods', () => { ok: true } as Response); const response = await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); expect(fetchStub).to.be.calledOnce; @@ 
-300,14 +320,16 @@ describe('request methods', () => { try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - '', { - timeout: 180000 - } + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { + timeout: 180000 + } + }, + '' ); } catch (e) { expect((e as AIError).code).to.equal(AIErrorCode.FETCH_ERROR); @@ -328,10 +350,12 @@ describe('request methods', () => { } as Response); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { @@ -353,10 +377,12 @@ describe('request methods', () => { } as Response); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { @@ -391,10 +417,12 @@ describe('request methods', () => { } as Response); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { @@ -420,10 +448,12 @@ describe('request methods', () => { ); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { diff --git a/packages/ai/src/requests/request.ts b/packages/ai/src/requests/request.ts index 90195b4b788..7664765ab03 100644 --- a/packages/ai/src/requests/request.ts +++ b/packages/ai/src/requests/request.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache 
License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,62 +19,87 @@ import { ErrorDetails, RequestOptions, AIErrorCode } from '../types'; import { AIError } from '../errors'; import { ApiSettings } from '../types/internal'; import { - DEFAULT_API_VERSION, DEFAULT_DOMAIN, DEFAULT_FETCH_TIMEOUT_MS, LANGUAGE_TAG, PACKAGE_VERSION } from '../constants'; import { logger } from '../logger'; -import { GoogleAIBackend, VertexAIBackend } from '../backend'; import { BackendType } from '../public-types'; -export enum Task { +export const enum Task { GENERATE_CONTENT = 'generateContent', STREAM_GENERATE_CONTENT = 'streamGenerateContent', COUNT_TOKENS = 'countTokens', PREDICT = 'predict' } -export class RequestUrl { +export const enum ServerPromptTemplateTask { + TEMPLATE_GENERATE_CONTENT = 'templateGenerateContent', + TEMPLATE_STREAM_GENERATE_CONTENT = 'templateStreamGenerateContent', + TEMPLATE_PREDICT = 'templatePredict' +} + +interface BaseRequestURLParams { + apiSettings: ApiSettings; + stream: boolean; + requestOptions?: RequestOptions; +} + +/** + * Parameters used to construct the URL of a request to use a model. + */ +interface ModelRequestURLParams extends BaseRequestURLParams { + task: Task; + model: string; + templateId?: never; +} + +/** + * Parameters used to construct the URL of a request to use server side prompt templates. 
+ */ +interface TemplateRequestURLParams extends BaseRequestURLParams { + task: ServerPromptTemplateTask; + templateId: string; + model?: never; +} + +export class RequestURL { constructor( - public model: string, - public task: Task, - public apiSettings: ApiSettings, - public stream: boolean, - public requestOptions?: RequestOptions + public readonly params: ModelRequestURLParams | TemplateRequestURLParams ) {} + toString(): string { const url = new URL(this.baseUrl); // Throws if the URL is invalid - url.pathname = `/${this.apiVersion}/${this.modelPath}:${this.task}`; + url.pathname = this.pathname; url.search = this.queryParams.toString(); return url.toString(); } - private get baseUrl(): string { - return this.requestOptions?.baseUrl || `https://${DEFAULT_DOMAIN}`; - } - - private get apiVersion(): string { - return DEFAULT_API_VERSION; // TODO: allow user-set options if that feature becomes available - } - - private get modelPath(): string { - if (this.apiSettings.backend instanceof GoogleAIBackend) { - return `projects/${this.apiSettings.project}/${this.model}`; - } else if (this.apiSettings.backend instanceof VertexAIBackend) { - return `projects/${this.apiSettings.project}/locations/${this.apiSettings.backend.location}/${this.model}`; + private get pathname(): string { + // We need to construct a different URL if the request is for server side prompt templates, + // since the URL patterns are different. Server side prompt templates expect a templateId + // instead of a model name. 
+ if (this.params.templateId) { + return `${this.params.apiSettings.backend._getTemplatePath( + this.params.apiSettings.project, + this.params.templateId + )}:${this.params.task}`; } else { - throw new AIError( - AIErrorCode.ERROR, - `Invalid backend: ${JSON.stringify(this.apiSettings.backend)}` - ); + return `${this.params.apiSettings.backend._getModelPath( + this.params.apiSettings.project, + (this.params as ModelRequestURLParams).model + )}:${this.params.task}`; } } + private get baseUrl(): string { + return this.params.requestOptions?.baseUrl ?? `https://${DEFAULT_DOMAIN}`; + } + private get queryParams(): URLSearchParams { const params = new URLSearchParams(); - if (this.stream) { + if (this.params.stream) { params.set('alt', 'sse'); } @@ -114,16 +139,16 @@ function getClientHeaders(): string { return loggingTags.join(' '); } -export async function getHeaders(url: RequestUrl): Promise { +export async function getHeaders(url: RequestURL): Promise { const headers = new Headers(); headers.append('Content-Type', 'application/json'); headers.append('x-goog-api-client', getClientHeaders()); - headers.append('x-goog-api-key', url.apiSettings.apiKey); - if (url.apiSettings.automaticDataCollectionEnabled) { - headers.append('X-Firebase-Appid', url.apiSettings.appId); + headers.append('x-goog-api-key', url.params.apiSettings.apiKey); + if (url.params.apiSettings.automaticDataCollectionEnabled) { + headers.append('X-Firebase-Appid', url.params.apiSettings.appId); } - if (url.apiSettings.getAppCheckToken) { - const appCheckToken = await url.apiSettings.getAppCheckToken(); + if (url.params.apiSettings.getAppCheckToken) { + const appCheckToken = await url.params.apiSettings.getAppCheckToken(); if (appCheckToken) { headers.append('X-Firebase-AppCheck', appCheckToken.token); if (appCheckToken.error) { @@ -134,8 +159,8 @@ export async function getHeaders(url: RequestUrl): Promise { } } - if (url.apiSettings.getAuthToken) { - const authToken = await 
url.apiSettings.getAuthToken(); + if (url.params.apiSettings.getAuthToken) { + const authToken = await url.params.apiSettings.getAuthToken(); if (authToken) { headers.append('Authorization', `Firebase ${authToken.accessToken}`); } @@ -144,55 +169,31 @@ export async function getHeaders(url: RequestUrl): Promise { return headers; } -export async function constructRequest( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions -): Promise<{ url: string; fetchOptions: RequestInit }> { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); - return { - url: url.toString(), - fetchOptions: { - method: 'POST', - headers: await getHeaders(url), - body - } - }; -} - export async function makeRequest( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions + requestUrlParams: TemplateRequestURLParams | ModelRequestURLParams, + body: string ): Promise { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + const url = new RequestURL(requestUrlParams); let response; let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; try { - const request = await constructRequest( - model, - task, - apiSettings, - stream, - body, - requestOptions - ); - // Timeout is 180s by default + const fetchOptions: RequestInit = { + method: 'POST', + headers: await getHeaders(url), + body + }; + + // Timeout is 180s by default. const timeoutMillis = - requestOptions?.timeout != null && requestOptions.timeout >= 0 - ? requestOptions.timeout + requestUrlParams.requestOptions?.timeout != null && + requestUrlParams.requestOptions.timeout >= 0 + ? 
requestUrlParams.requestOptions.timeout : DEFAULT_FETCH_TIMEOUT_MS; const abortController = new AbortController(); fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); - request.fetchOptions.signal = abortController.signal; + fetchOptions.signal = abortController.signal; - response = await fetch(request.url, request.fetchOptions); + response = await fetch(url.toString(), fetchOptions); if (!response.ok) { let message = ''; let errorDetails; @@ -225,7 +226,7 @@ export async function makeRequest( `The Firebase AI SDK requires the Firebase AI ` + `API ('firebasevertexai.googleapis.com') to be enabled in your ` + `Firebase project. Enable this API by visiting the Firebase Console ` + - `at https://console.firebase.google.com/project/${url.apiSettings.project}/genai/ ` + + `at https://console.firebase.google.com/project/${url.params.apiSettings.project}/genai/ ` + `and clicking "Get started". If you enabled this API recently, ` + `wait a few minutes for the action to propagate to our systems and ` + `then retry.`,