diff --git a/e2e/sample-apps/modular.js b/e2e/sample-apps/modular.js index 9e943e04494..4c5238d44dc 100644 --- a/e2e/sample-apps/modular.js +++ b/e2e/sample-apps/modular.js @@ -58,7 +58,12 @@ import { onValue, off } from 'firebase/database'; -import { getGenerativeModel, getVertexAI, VertexAI } from 'firebase/vertexai'; +import { + getGenerativeModel, + getVertexAI, + InferenceMode, + VertexAI +} from 'firebase/vertexai'; import { getDataConnect, DataConnect } from 'firebase/data-connect'; /** diff --git a/packages/vertexai/src/api.ts b/packages/vertexai/src/api.ts index 236ca73ce87..2f6de198608 100644 --- a/packages/vertexai/src/api.ts +++ b/packages/vertexai/src/api.ts @@ -31,6 +31,7 @@ import { import { VertexAIError } from './errors'; import { VertexAIModel, GenerativeModel, ImagenModel } from './models'; import { ChromeAdapter } from './methods/chrome-adapter'; +import { LanguageModel } from './types/language-model'; export { ChatSession } from './methods/chat-session'; export * from './requests/schema-builder'; @@ -95,7 +96,11 @@ export function getGenerativeModel( return new GenerativeModel( vertexAI, inCloudParams, - new ChromeAdapter(hybridParams.mode, hybridParams.onDeviceParams), + new ChromeAdapter( + window.LanguageModel as LanguageModel, + hybridParams.mode, + hybridParams.onDeviceParams + ), requestOptions ); } diff --git a/packages/vertexai/src/methods/chrome-adapter.test.ts b/packages/vertexai/src/methods/chrome-adapter.test.ts new file mode 100644 index 00000000000..b11fb9c937e --- /dev/null +++ b/packages/vertexai/src/methods/chrome-adapter.test.ts @@ -0,0 +1,310 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect, use } from 'chai'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { ChromeAdapter } from './chrome-adapter'; +import { + Availability, + LanguageModel, + LanguageModelCreateOptions +} from '../types/language-model'; +import { stub } from 'sinon'; +import { GenerateContentRequest } from '../types'; + +use(sinonChai); +use(chaiAsPromised); + +describe('ChromeAdapter', () => { + describe('isAvailable', () => { + it('returns false if mode is only cloud', async () => { + const adapter = new ChromeAdapter(undefined, 'only_in_cloud'); + expect( + await adapter.isAvailable({ + contents: [] + }) + ).to.be.false; + }); + it('returns false if AI API is undefined', async () => { + const adapter = new ChromeAdapter(undefined, 'prefer_on_device'); + expect( + await adapter.isAvailable({ + contents: [] + }) + ).to.be.false; + }); + it('returns false if LanguageModel API is undefined', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [] + }) + ).to.be.false; + }); + it('returns false if request contents empty', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [] + }) + ).to.be.false; + }); + it('returns false if request content has function role', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [ + { + role: 'function', + parts: [] + } + ] + }) + ).to.be.false; + }); + it('returns false if request content has multiple parts', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [ + { + role: 'user', + parts: [{ text: 'a' }, { text: 'b' }] + } + ] + }) + ).to.be.false; + }); + it('returns false if request content has non-text part', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [ + { + role: 'user', + parts: [{ inlineData: { mimeType: 'a', data: 'b' } }] + } + ] + }) + ).to.be.false; + }); + it('returns false if request system instruction has function role', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [], + systemInstruction: { + role: 'function', + parts: [] + } + }) + ).to.be.false; + }); + it('returns false if request system instruction has multiple parts', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [], + systemInstruction: { + role: 'function', + parts: [{ text: 'a' }, { text: 'b' }] + } + }) + ).to.be.false; + }); + it('returns false if request system instruction has non-text part', async () => { + const adapter = new ChromeAdapter( + {} as LanguageModel, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [], + systemInstruction: { + role: 'function', + parts: [{ inlineData: { mimeType: 'a', data: 'b' } }] + } + }) + ).to.be.false; + }); + it('returns true if model is readily available', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.available) + } as LanguageModel; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }) + ).to.be.true; + }); + it('returns false and triggers download when model is available after download', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.downloadable), + create: () => Promise.resolve({}) + } as LanguageModel; + const createStub = stub(languageModelProvider, 'create').resolves( + {} as LanguageModel + ); + const onDeviceParams = {} as LanguageModelCreateOptions; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device', + onDeviceParams + ); + expect( + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }) + ).to.be.false; + expect(createStub).to.have.been.calledOnceWith(onDeviceParams); + }); + it('avoids redundant downloads', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.downloadable), + create: () => Promise.resolve({}) + } as LanguageModel; + const downloadPromise = new Promise(() => { + /* never resolves */ + }); + const createStub = stub(languageModelProvider, 'create').returns( + downloadPromise + ); + const adapter = new ChromeAdapter(languageModelProvider); + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }); + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }); + expect(createStub).to.have.been.calledOnce; + }); + it('clears state when download completes', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.downloadable), + create: () => Promise.resolve({}) + } as LanguageModel; + let resolveDownload; + const downloadPromise = new Promise(resolveCallback => { + resolveDownload = resolveCallback; + }); + const createStub = stub(languageModelProvider, 'create').returns( + downloadPromise + ); + const adapter = new ChromeAdapter(languageModelProvider); + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }); + resolveDownload!(); + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }); + expect(createStub).to.have.been.calledTwice; + }); + it('returns false when model is never available', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.unavailable), + create: () => Promise.resolve({}) + } as LanguageModel; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device' + ); + expect( + await adapter.isAvailable({ + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }) + ).to.be.false; + }); + }); + describe('generateContentOnDevice', () => { + it('generates content', async () => { + const languageModelProvider = { + create: () => Promise.resolve({}) + } as LanguageModel; + const languageModel = { + prompt: i => Promise.resolve(i) + } as LanguageModel; + const createStub = stub(languageModelProvider, 'create').resolves( + languageModel + ); + const promptOutput = 'hi'; + const promptStub = stub(languageModel, 'prompt').resolves(promptOutput); + const onDeviceParams = { + systemPrompt: 'be yourself' + } as LanguageModelCreateOptions; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device', + onDeviceParams + ); + const request = { + contents: [{ role: 'user', parts: [{ text: 'anything' }] }] + } as GenerateContentRequest; + const response = await adapter.generateContentOnDevice(request); + // Asserts initialization params are proxied. + expect(createStub).to.have.been.calledOnceWith(onDeviceParams); + // Asserts Vertex input type is mapped to Chrome type. + expect(promptStub).to.have.been.calledOnceWith([ + { + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] + } + ]); + // Asserts expected output. + expect(await response.json()).to.deep.equal({ + candidates: [ + { + content: { + parts: [{ text: promptOutput }] + } + } + ] + }); + }); + }); +}); diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts index 26ecd55c2da..10844079c03 100644 --- a/packages/vertexai/src/methods/chrome-adapter.ts +++ b/packages/vertexai/src/methods/chrome-adapter.ts @@ -15,37 +15,213 @@ * limitations under the License. */ -import { GenerateContentRequest, InferenceMode } from '../types'; -import { LanguageModelCreateOptions } from '../types/language-model'; +import { + Content, + GenerateContentRequest, + InferenceMode, + Part, + Role +} from '../types'; +import { + Availability, + LanguageModel, + LanguageModelCreateOptions, + LanguageModelMessage, + LanguageModelMessageRole, + LanguageModelMessageContent +} from '../types/language-model'; /** * Defines an inference "backend" that uses Chrome's on-device model, * and encapsulates logic for detecting when on-device is possible. */ export class ChromeAdapter { + private isDownloading = false; + private downloadPromise: Promise | undefined; + private oldSession: LanguageModel | undefined; constructor( + private languageModelProvider?: LanguageModel, private mode?: InferenceMode, private onDeviceParams?: LanguageModelCreateOptions ) {} - // eslint-disable-next-line @typescript-eslint/no-unused-vars + + /** + * Checks if a given request can be made on-device. + * + *
    Encapsulates a few concerns: + *
  1. the mode
  2. + *
  3. API existence
  4. + *
  5. prompt formatting
  6. + *
  7. model availability, including triggering download if necessary
  8. + *
+ * + *

Pros: callers needn't be concerned with details of on-device availability.

+ *

Cons: this method spans a few concerns and splits request validation from usage. + * If instance variables weren't already part of the API, we could consider a better + * separation of concerns.

+ */ async isAvailable(request: GenerateContentRequest): Promise { - return false; + // Returns false if we should only use in-cloud inference. + if (this.mode === 'only_in_cloud') { + return false; + } + // Returns false if the on-device inference API is undefined.; + if (!this.languageModelProvider) { + return false; + } + // Returns false if the request can't be run on-device. + if (!ChromeAdapter.isOnDeviceRequest(request)) { + return false; + } + const availability = await this.languageModelProvider.availability(); + switch (availability) { + case Availability.available: + // Returns true only if a model is immediately available. + return true; + case Availability.downloadable: + // Triggers async download if model is downloadable. + this.download(); + default: + return false; + } } + + /** + * Generates content on device. + * + *

This is comparable to {@link GenerativeModel.generateContent} for generating content in + * Cloud.

+ * @param request a standard Vertex {@link GenerateContentRequest} + * @returns {@link Response}, so we can reuse common response formatting. + */ async generateContentOnDevice( - // eslint-disable-next-line @typescript-eslint/no-unused-vars request: GenerateContentRequest ): Promise { + const session = await this.createSession( + // TODO: normalize on-device params during construction. + this.onDeviceParams || {} + ); + const messages = ChromeAdapter.toLanguageModelMessages(request.contents); + const text = await session.prompt(messages); return { json: () => Promise.resolve({ candidates: [ { content: { - parts: [{ text: '' }] + parts: [{ text }] } } ] }) } as Response; } + + /** + * Asserts inference for the given request can be performed by an on-device model. + */ + private static isOnDeviceRequest(request: GenerateContentRequest): boolean { + // Returns false if the prompt is empty. + if (request.contents.length === 0) { + return false; + } + + // Applies the same checks as above, but for each content item. + for (const content of request.contents) { + if (content.role === 'function') { + return false; + } + + if (content.parts.length > 1) { + return false; + } + + if (!content.parts[0].text) { + return false; + } + } + + return true; + } + + /** + * Triggers the download of an on-device model. + * + *

Chrome only downloads models as needed. Chrome knows a model is needed when code calls + * LanguageModel.create.

+ * + *

Since Chrome manages the download, the SDK can only avoid redundant download requests by + * tracking if a download has previously been requested.

+ */ + private download(): void { + if (this.isDownloading) { + return; + } + this.isDownloading = true; + this.downloadPromise = this.languageModelProvider + ?.create(this.onDeviceParams) + .then(() => { + this.isDownloading = false; + }); + } + + /** + * Converts a Vertex role string to a Chrome role string. + */ + private static toOnDeviceRole(role: Role): LanguageModelMessageRole { + return role === 'model' ? 'assistant' : 'user'; + } + + /** + * Converts a Vertex Content object to a Chrome LanguageModelMessage object. + */ + private static toLanguageModelMessages( + contents: Content[] + ): LanguageModelMessage[] { + return contents.map(c => ({ + role: ChromeAdapter.toOnDeviceRole(c.role), + content: c.parts.map(ChromeAdapter.toLanguageModelMessageContent) + })); + } + + /** + * Converts a Vertex Part object to a Chrome LanguageModelMessageContent object. + */ + private static toLanguageModelMessageContent( + part: Part + ): LanguageModelMessageContent { + if (part.text) { + return { + type: 'text', + content: part.text + }; + } + // Assumes contents have been verified to contain only a single TextPart. + // TODO: support other input types + throw new Error('Not yet implemented'); + } + + /** + * Abstracts Chrome session creation. + * + *

Chrome uses a multi-turn session for all inference. Vertex uses single-turn for all + * inference. To map the Vertex API to Chrome's API, the SDK creates a new session for all + * inference.

+ * + *

Chrome will remove a model from memory if it's no longer in use, so this method ensures a + * new session is created before an old session is destroyed.

+ */ + private async createSession( + // TODO: define a default value, since these are optional. + options: LanguageModelCreateOptions + ): Promise { + // TODO: could we use this.onDeviceParams instead of passing in options? + const newSession = await this.languageModelProvider!.create(options); + if (this.oldSession) { + this.oldSession.destroy(); + } + // Holds session reference, so model isn't unloaded from memory. + this.oldSession = newSession; + return newSession; + } } diff --git a/packages/vertexai/src/types/language-model.ts b/packages/vertexai/src/types/language-model.ts index e564ca467b4..88354d0aeec 100644 --- a/packages/vertexai/src/types/language-model.ts +++ b/packages/vertexai/src/types/language-model.ts @@ -32,7 +32,7 @@ export interface LanguageModel extends EventTarget { ): Promise; destroy(): undefined; } -enum Availability { +export enum Availability { 'unavailable', 'downloadable', 'downloading', @@ -56,14 +56,14 @@ interface LanguageModelExpectedInput { type: LanguageModelMessageType; languages?: string[]; } -type LanguageModelPrompt = +export type LanguageModelPrompt = | LanguageModelMessage[] | LanguageModelMessageShorthand[] | string; type LanguageModelInitialPrompts = | LanguageModelMessage[] | LanguageModelMessageShorthand[]; -interface LanguageModelMessage { +export interface LanguageModelMessage { role: LanguageModelMessageRole; content: LanguageModelMessageContent[]; } @@ -71,11 +71,11 @@ interface LanguageModelMessageShorthand { role: LanguageModelMessageRole; content: string; } -interface LanguageModelMessageContent { +export interface LanguageModelMessageContent { type: LanguageModelMessageType; content: LanguageModelMessageContentValue; } -type LanguageModelMessageRole = 'system' | 'user' | 'assistant'; +export type LanguageModelMessageRole = 'system' | 'user' | 'assistant'; type LanguageModelMessageType = 'text' | 'image' | 'audio'; type LanguageModelMessageContentValue = | ImageBitmapSource diff --git a/repo-scripts/changelog-generator/tsconfig.json b/repo-scripts/changelog-generator/tsconfig.json index 38bdb7035e4..cffe622284d 100644 --- a/repo-scripts/changelog-generator/tsconfig.json +++ b/repo-scripts/changelog-generator/tsconfig.json @@ -3,7 +3,8 @@ "strict": true, "outDir": "dist", "lib": [ - "ESNext" + "ESNext", + "dom" ], "module": "CommonJS", "moduleResolution": "node", diff --git a/yarn.lock b/yarn.lock index 51ede769d03..d5ea91a7093 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2938,11 +2938,9 @@ "@types/node" "*" "@types/cors@^2.8.12": - version "2.8.17" - resolved "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz#5d718a5e494a8166f569d986794e49c48b216b2b" - integrity sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA== - dependencies: - "@types/node" "*" + version "2.8.12" + resolved "https://registry.npmjs.org/@types/cors/-/cors-2.8.12.tgz" + integrity sha512-vt+kDhq/M2ayberEtJcIN/hxXy1Pk+59g2FV/ZQceeaTyCtCucjL2Q7FXlFjtWn4n15KCr1NE2lNNFhp0lEThw== "@types/deep-eql@*": version "4.0.2"