Skip to content

Commit 00ad8d2

Browse files
authored
Merge c22ab66 into 07c75ea
2 parents 07c75ea + c22ab66 commit 00ad8d2

File tree

3 files changed

+162
-45
lines changed

3 files changed

+162
-45
lines changed

packages/vertexai/src/methods/chrome-adapter.test.ts

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {
2424
Availability,
2525
LanguageModel,
2626
LanguageModelCreateOptions,
27-
LanguageModelMessageContent
27+
LanguageModelMessage
2828
} from '../types/language-model';
2929
import { match, stub } from 'sinon';
3030
import { GenerateContentRequest, AIErrorCode } from '../types';
@@ -138,7 +138,7 @@ describe('ChromeAdapter', () => {
138138
})
139139
).to.be.false;
140140
});
141-
it('returns false if request content has non-user role', async () => {
141+
it('returns false if request content has "function" role', async () => {
142142
const adapter = new ChromeAdapter(
143143
{
144144
availability: async () => Availability.available
@@ -149,7 +149,7 @@ describe('ChromeAdapter', () => {
149149
await adapter.isAvailable({
150150
contents: [
151151
{
152-
role: 'model',
152+
role: 'function',
153153
parts: []
154154
}
155155
]
@@ -306,7 +306,7 @@ describe('ChromeAdapter', () => {
306306
} as LanguageModel;
307307
const languageModel = {
308308
// eslint-disable-next-line @typescript-eslint/no-unused-vars
309-
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
309+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
310310
} as LanguageModel;
311311
const createStub = stub(languageModelProvider, 'create').resolves(
312312
languageModel
@@ -331,8 +331,13 @@ describe('ChromeAdapter', () => {
331331
// Asserts Vertex input type is mapped to Chrome type.
332332
expect(promptStub).to.have.been.calledOnceWith([
333333
{
334-
type: 'text',
335-
content: request.contents[0].parts[0].text
334+
role: request.contents[0].role,
335+
content: [
336+
{
337+
type: 'text',
338+
content: request.contents[0].parts[0].text
339+
}
340+
]
336341
}
337342
]);
338343
// Asserts expected output.
@@ -352,7 +357,7 @@ describe('ChromeAdapter', () => {
352357
} as LanguageModel;
353358
const languageModel = {
354359
// eslint-disable-next-line @typescript-eslint/no-unused-vars
355-
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
360+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
356361
} as LanguageModel;
357362
const createStub = stub(languageModelProvider, 'create').resolves(
358363
languageModel
@@ -390,12 +395,17 @@ describe('ChromeAdapter', () => {
390395
// Asserts Vertex input type is mapped to Chrome type.
391396
expect(promptStub).to.have.been.calledOnceWith([
392397
{
393-
type: 'text',
394-
content: request.contents[0].parts[0].text
395-
},
396-
{
397-
type: 'image',
398-
content: match.instanceOf(ImageBitmap)
398+
role: request.contents[0].role,
399+
content: [
400+
{
401+
type: 'text',
402+
content: request.contents[0].parts[0].text
403+
},
404+
{
405+
type: 'image',
406+
content: match.instanceOf(ImageBitmap)
407+
}
408+
]
399409
}
400410
]);
401411
// Asserts expected output.
@@ -412,7 +422,7 @@ describe('ChromeAdapter', () => {
412422
it('honors prompt options', async () => {
413423
const languageModel = {
414424
// eslint-disable-next-line @typescript-eslint/no-unused-vars
415-
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
425+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
416426
} as LanguageModel;
417427
const languageModelProvider = {
418428
create: () => Promise.resolve(languageModel)
@@ -436,13 +446,48 @@ describe('ChromeAdapter', () => {
436446
expect(promptStub).to.have.been.calledOnceWith(
437447
[
438448
{
439-
type: 'text',
440-
content: request.contents[0].parts[0].text
449+
role: request.contents[0].role,
450+
content: [
451+
{
452+
type: 'text',
453+
content: request.contents[0].parts[0].text
454+
}
455+
]
441456
}
442457
],
443458
promptOptions
444459
);
445460
});
461+
it('normalizes roles', async () => {
462+
const languageModel = {
463+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
464+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('unused')
465+
} as LanguageModel;
466+
const promptStub = stub(languageModel, 'prompt').resolves('unused');
467+
const languageModelProvider = {
468+
create: () => Promise.resolve(languageModel)
469+
} as LanguageModel;
470+
const adapter = new ChromeAdapter(
471+
languageModelProvider,
472+
'prefer_on_device'
473+
);
474+
const request = {
475+
contents: [{ role: 'model', parts: [{ text: 'unused' }] }]
476+
} as GenerateContentRequest;
477+
await adapter.generateContent(request);
478+
expect(promptStub).to.have.been.calledOnceWith([
479+
{
480+
// Asserts that Vertex's "model" role is normalized to Chrome's "assistant" role.
481+
role: 'assistant',
482+
content: [
483+
{
484+
type: 'text',
485+
content: request.contents[0].parts[0].text
486+
}
487+
]
488+
}
489+
]);
490+
});
446491
});
447492
describe('countTokens', () => {
448493
it('counts tokens is not yet available', async () => {
@@ -514,8 +559,13 @@ describe('ChromeAdapter', () => {
514559
expect(createStub).to.have.been.calledOnceWith(createOptions);
515560
expect(promptStub).to.have.been.calledOnceWith([
516561
{
517-
type: 'text',
518-
content: request.contents[0].parts[0].text
562+
role: request.contents[0].role,
563+
content: [
564+
{
565+
type: 'text',
566+
content: request.contents[0].parts[0].text
567+
}
568+
]
519569
}
520570
]);
521571
const actual = await toStringArray(response.body!);
@@ -570,12 +620,17 @@ describe('ChromeAdapter', () => {
570620
expect(createStub).to.have.been.calledOnceWith(createOptions);
571621
expect(promptStub).to.have.been.calledOnceWith([
572622
{
573-
type: 'text',
574-
content: request.contents[0].parts[0].text
575-
},
576-
{
577-
type: 'image',
578-
content: match.instanceOf(ImageBitmap)
623+
role: request.contents[0].role,
624+
content: [
625+
{
626+
type: 'text',
627+
content: request.contents[0].parts[0].text
628+
},
629+
{
630+
type: 'image',
631+
content: match.instanceOf(ImageBitmap)
632+
}
633+
]
579634
}
580635
]);
581636
const actual = await toStringArray(response.body!);
@@ -611,13 +666,50 @@ describe('ChromeAdapter', () => {
611666
expect(promptStub).to.have.been.calledOnceWith(
612667
[
613668
{
614-
type: 'text',
615-
content: request.contents[0].parts[0].text
669+
role: request.contents[0].role,
670+
content: [
671+
{
672+
type: 'text',
673+
content: request.contents[0].parts[0].text
674+
}
675+
]
616676
}
617677
],
618678
promptOptions
619679
);
620680
});
681+
it('normalizes roles', async () => {
682+
const languageModel = {
683+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
684+
promptStreaming: p => new ReadableStream()
685+
} as LanguageModel;
686+
const promptStub = stub(languageModel, 'promptStreaming').returns(
687+
new ReadableStream()
688+
);
689+
const languageModelProvider = {
690+
create: () => Promise.resolve(languageModel)
691+
} as LanguageModel;
692+
const adapter = new ChromeAdapter(
693+
languageModelProvider,
694+
'prefer_on_device'
695+
);
696+
const request = {
697+
contents: [{ role: 'model', parts: [{ text: 'unused' }] }]
698+
} as GenerateContentRequest;
699+
await adapter.generateContentStream(request);
700+
expect(promptStub).to.have.been.calledOnceWith([
701+
{
702+
// Asserts that Vertex's "model" role is normalized to Chrome's "assistant" role.
703+
role: 'assistant',
704+
content: [
705+
{
706+
type: 'text',
707+
content: request.contents[0].parts[0].text
708+
}
709+
]
710+
}
711+
]);
712+
});
621713
});
622714
});
623715

packages/vertexai/src/methods/chrome-adapter.ts

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@ import {
2323
InferenceMode,
2424
Part,
2525
AIErrorCode,
26-
OnDeviceParams
26+
OnDeviceParams,
27+
Content,
28+
Role
2729
} from '../types';
2830
import {
2931
Availability,
3032
LanguageModel,
31-
LanguageModelMessageContent
33+
LanguageModelMessage,
34+
LanguageModelMessageContent,
35+
LanguageModelMessageRole
3236
} from '../types/language-model';
3337

3438
/**
@@ -109,10 +113,8 @@ export class ChromeAdapter {
109113
*/
110114
async generateContent(request: GenerateContentRequest): Promise<Response> {
111115
const session = await this.createSession();
112-
// TODO: support multiple content objects when Chrome supports
113-
// sequence<LanguageModelMessage>
114116
const contents = await Promise.all(
115-
request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
117+
request.contents.map(ChromeAdapter.toLanguageModelMessage)
116118
);
117119
const text = await session.prompt(
118120
contents,
@@ -133,10 +135,8 @@ export class ChromeAdapter {
133135
request: GenerateContentRequest
134136
): Promise<Response> {
135137
const session = await this.createSession();
136-
// TODO: support multiple content objects when Chrome supports
137-
// sequence<LanguageModelMessage>
138138
const contents = await Promise.all(
139-
request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
139+
request.contents.map(ChromeAdapter.toLanguageModelMessage)
140140
);
141141
const stream = await session.promptStreaming(
142142
contents,
@@ -163,12 +163,8 @@ export class ChromeAdapter {
163163
}
164164

165165
for (const content of request.contents) {
166-
// Returns false if the request contains multiple roles, eg a chat history.
167-
// TODO: remove this guard once LanguageModelMessage is supported.
168-
if (content.role !== 'user') {
169-
logger.debug(
170-
`Non-user role "${content.role}" rejected for on-device inference.`
171-
);
166+
if (content.role === 'function') {
167+
logger.debug(`"Function" role rejected for on-device inference.`);
172168
return false;
173169
}
174170

@@ -227,6 +223,21 @@ export class ChromeAdapter {
227223
});
228224
}
229225

226+
/**
227+
* Converts a Vertex {@link Content} object to a Chrome {@link LanguageModelMessage} object.
228+
*/
229+
private static async toLanguageModelMessage(
230+
content: Content
231+
): Promise<LanguageModelMessage> {
232+
const languageModelMessageContents = await Promise.all(
233+
content.parts.map(ChromeAdapter.toLanguageModelMessageContent)
234+
);
235+
return {
236+
role: ChromeAdapter.toLanguageModelMessageRole(content.role),
237+
content: languageModelMessageContents
238+
};
239+
}
240+
230241
/**
231242
* Converts a Vertex Part object to a Chrome LanguageModelMessageContent object.
232243
*/
@@ -254,6 +265,16 @@ export class ChromeAdapter {
254265
throw new Error('Not yet implemented');
255266
}
256267

268+
/**
269+
* Converts a Vertex {@link Role} string to a {@link LanguageModelMessageRole} string.
270+
*/
271+
private static toLanguageModelMessageRole(
272+
role: Role
273+
): LanguageModelMessageRole {
274+
// Assumes the 'function' role has already been filtered out by isOnDeviceRequest.
275+
return role === 'model' ? 'assistant' : 'user';
276+
}
277+
257278
/**
258279
* Abstracts Chrome session creation.
259280
*

packages/vertexai/src/types/language-model.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
17+
/**
18+
* @see {@link https://github.com/webmachinelearning/prompt-api#full-api-surface-in-web-idl}
19+
*/
1820
export interface LanguageModel extends EventTarget {
1921
create(options?: LanguageModelCreateOptions): Promise<LanguageModel>;
2022
availability(options?: LanguageModelCreateCoreOptions): Promise<Availability>;
@@ -57,12 +59,14 @@ interface LanguageModelExpectedInput {
5759
type: LanguageModelMessageType;
5860
languages?: string[];
5961
}
60-
// TODO: revert to type from Prompt API explainer once it's supported.
61-
export type LanguageModelPrompt = LanguageModelMessageContent[];
62+
export type LanguageModelPrompt =
63+
| LanguageModelMessage[]
64+
| LanguageModelMessageShorthand[]
65+
| string;
6266
type LanguageModelInitialPrompts =
6367
| LanguageModelMessage[]
6468
| LanguageModelMessageShorthand[];
65-
interface LanguageModelMessage {
69+
export interface LanguageModelMessage {
6670
role: LanguageModelMessageRole;
6771
content: LanguageModelMessageContent[];
6872
}
@@ -74,7 +78,7 @@ export interface LanguageModelMessageContent {
7478
type: LanguageModelMessageType;
7579
content: LanguageModelMessageContentValue;
7680
}
77-
type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
81+
export type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
7882
type LanguageModelMessageType = 'text' | 'image' | 'audio';
7983
type LanguageModelMessageContentValue =
8084
| ImageBitmapSource

0 commit comments

Comments
 (0)