Skip to content

Commit 55c05a0

Browse files
dlarocquegsiddh
authored andcommitted
Hybrid inference code changes
1 parent 39505cc commit 55c05a0

15 files changed

+1430
-105
lines changed

packages/ai/src/api.test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,21 @@ describe('Top level API', () => {
102102
expect(genModel).to.be.an.instanceOf(GenerativeModel);
103103
expect(genModel.model).to.equal('publishers/google/models/my-model');
104104
});
105+
it('getGenerativeModel with HybridParams sets a default model', () => {
106+
const genModel = getGenerativeModel(fakeAI, {
107+
mode: 'only_on_device'
108+
});
109+
expect(genModel.model).to.equal(
110+
`publishers/google/models/${GenerativeModel.DEFAULT_HYBRID_IN_CLOUD_MODEL}`
111+
);
112+
});
113+
it('getGenerativeModel with HybridParams honors a model override', () => {
114+
const genModel = getGenerativeModel(fakeAI, {
115+
mode: 'prefer_on_device',
116+
inCloudParams: { model: 'my-model' }
117+
});
118+
expect(genModel.model).to.equal('publishers/google/models/my-model');
119+
});
105120
it('getImagenModel throws if no model is provided', () => {
106121
try {
107122
getImagenModel(fakeAI, {} as ImagenModelParams);

packages/ai/src/api.ts

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import { AIService } from './service';
2323
import { AI, AIOptions, VertexAI, VertexAIOptions } from './public-types';
2424
import {
2525
ImagenModelParams,
26+
HybridParams,
2627
ModelParams,
2728
RequestOptions,
2829
AIErrorCode
@@ -31,6 +32,8 @@ import { AIError } from './errors';
3132
import { AIModel, GenerativeModel, ImagenModel } from './models';
3233
import { encodeInstanceIdentifier } from './helpers';
3334
import { GoogleAIBackend, VertexAIBackend } from './backend';
35+
import { ChromeAdapter } from './methods/chrome-adapter';
36+
import { LanguageModel } from './types/language-model';
3437

3538
export { ChatSession } from './methods/chat-session';
3639
export * from './requests/schema-builder';
@@ -147,16 +150,36 @@ export function getAI(
147150
*/
148151
export function getGenerativeModel(
149152
ai: AI,
150-
modelParams: ModelParams,
153+
modelParams: ModelParams | HybridParams,
151154
requestOptions?: RequestOptions
152155
): GenerativeModel {
153-
if (!modelParams.model) {
156+
// Uses the existence of HybridParams.mode to clarify the type of the modelParams input.
157+
const hybridParams = modelParams as HybridParams;
158+
let inCloudParams: ModelParams;
159+
if (hybridParams.mode) {
160+
inCloudParams = hybridParams.inCloudParams || {
161+
model: GenerativeModel.DEFAULT_HYBRID_IN_CLOUD_MODEL
162+
};
163+
} else {
164+
inCloudParams = modelParams as ModelParams;
165+
}
166+
167+
if (!inCloudParams.model) {
154168
throw new AIError(
155169
AIErrorCode.NO_MODEL,
156170
`Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`
157171
);
158172
}
159-
return new GenerativeModel(ai, modelParams, requestOptions);
173+
return new GenerativeModel(
174+
ai,
175+
inCloudParams,
176+
new ChromeAdapter(
177+
window.LanguageModel as LanguageModel,
178+
hybridParams.mode,
179+
hybridParams.onDeviceParams
180+
),
181+
requestOptions
182+
);
160183
}
161184

162185
/**

packages/ai/src/backwards-compatbility.test.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import {
2828
} from './api';
2929
import { AI, VertexAI, AIErrorCode } from './public-types';
3030
import { VertexAIBackend } from './backend';
31+
import { ChromeAdapter } from './methods/chrome-adapter';
3132

3233
function assertAssignable<T, _U extends T>(): void {}
3334

@@ -65,7 +66,11 @@ describe('backwards-compatible types', () => {
6566
it('AIModel is backwards compatible with VertexAIModel', () => {
6667
assertAssignable<typeof VertexAIModel, typeof AIModel>();
6768

68-
const model = new GenerativeModel(fakeAI, { model: 'model-name' });
69+
const model = new GenerativeModel(
70+
fakeAI,
71+
{ model: 'model-name' },
72+
new ChromeAdapter()
73+
);
6974
expect(model).to.be.instanceOf(AIModel);
7075
expect(model).to.be.instanceOf(VertexAIModel);
7176
});

packages/ai/src/methods/chat-session.test.ts

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import { GenerateContentStreamResult } from '../types';
2424
import { ChatSession } from './chat-session';
2525
import { ApiSettings } from '../types/internal';
2626
import { VertexAIBackend } from '../backend';
27+
import { ChromeAdapter } from './chrome-adapter';
2728

2829
use(sinonChai);
2930
use(chaiAsPromised);
@@ -46,7 +47,11 @@ describe('ChatSession', () => {
4647
generateContentMethods,
4748
'generateContent'
4849
).rejects('generateContent failed');
49-
const chatSession = new ChatSession(fakeApiSettings, 'a-model');
50+
const chatSession = new ChatSession(
51+
fakeApiSettings,
52+
'a-model',
53+
new ChromeAdapter()
54+
);
5055
await expect(chatSession.sendMessage('hello')).to.be.rejected;
5156
expect(generateContentStub).to.be.calledWith(
5257
fakeApiSettings,
@@ -63,7 +68,11 @@ describe('ChatSession', () => {
6368
generateContentMethods,
6469
'generateContentStream'
6570
).rejects('generateContentStream failed');
66-
const chatSession = new ChatSession(fakeApiSettings, 'a-model');
71+
const chatSession = new ChatSession(
72+
fakeApiSettings,
73+
'a-model',
74+
new ChromeAdapter()
75+
);
6776
await expect(chatSession.sendMessageStream('hello')).to.be.rejected;
6877
expect(generateContentStreamStub).to.be.calledWith(
6978
fakeApiSettings,
@@ -82,7 +91,11 @@ describe('ChatSession', () => {
8291
generateContentMethods,
8392
'generateContentStream'
8493
).resolves({} as unknown as GenerateContentStreamResult);
85-
const chatSession = new ChatSession(fakeApiSettings, 'a-model');
94+
const chatSession = new ChatSession(
95+
fakeApiSettings,
96+
'a-model',
97+
new ChromeAdapter()
98+
);
8699
await chatSession.sendMessageStream('hello');
87100
expect(generateContentStreamStub).to.be.calledWith(
88101
fakeApiSettings,

packages/ai/src/methods/chat-session.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import { validateChatHistory } from './chat-session-helpers';
3030
import { generateContent, generateContentStream } from './generate-content';
3131
import { ApiSettings } from '../types/internal';
3232
import { logger } from '../logger';
33+
import { ChromeAdapter } from './chrome-adapter';
3334

3435
/**
3536
* Do not log a message for this error.
@@ -50,6 +51,7 @@ export class ChatSession {
5051
constructor(
5152
apiSettings: ApiSettings,
5253
public model: string,
54+
private chromeAdapter: ChromeAdapter,
5355
public params?: StartChatParams,
5456
public requestOptions?: RequestOptions
5557
) {
@@ -95,6 +97,7 @@ export class ChatSession {
9597
this._apiSettings,
9698
this.model,
9799
generateContentRequest,
100+
this.chromeAdapter,
98101
this.requestOptions
99102
)
100103
)
@@ -146,6 +149,7 @@ export class ChatSession {
146149
this._apiSettings,
147150
this.model,
148151
generateContentRequest,
152+
this.chromeAdapter,
149153
this.requestOptions
150154
);
151155

0 commit comments

Comments
 (0)