Skip to content

Commit 184f8d6

Browse files
authored
Improvements: (#279)
- Fixes bad typing of `GenerativeConfigRuntime` - adds mocks for this usage - renames `generativeConfigRuntime` to `generativeParameters` to match Python naming
1 parent 6751bd1 commit 184f8d6

File tree

7 files changed

+219
-49
lines changed

7 files changed

+219
-49
lines changed

src/collections/generate/config.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import {
1515
GenerativeOpenAIConfigRuntime,
1616
} from '../index.js';
1717

18-
export const generativeConfigRuntime = {
18+
export const generativeParameters = {
1919
/**
2020
* Create a `ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-anthropic` module.
2121
*

src/collections/generate/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,5 +396,5 @@ class GenerateManager<T> implements Generate<T> {
396396

397397
export default GenerateManager.use;
398398

399-
export { generativeConfigRuntime } from './config.js';
399+
export { generativeParameters } from './config.js';
400400
export { Generate } from './types.js';

src/collections/generate/integration.test.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
/* eslint-disable @typescript-eslint/no-non-null-assertion */
22
/* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */
33
import { WeaviateUnsupportedFeatureError } from '../../errors.js';
4-
import weaviate, { WeaviateClient, generativeConfigRuntime } from '../../index.js';
4+
import weaviate, { WeaviateClient } from '../../index.js';
55
import { Collection } from '../collection/index.js';
66
import { GenerateOptions, GroupByOptions } from '../types/index.js';
7+
import { generativeParameters } from './config.js';
78

89
const maybe = process.env.OPENAI_APIKEY ? describe : describe.skip;
910

@@ -493,7 +494,7 @@ maybe('Testing of the collection.generate methods with runtime generative config
493494
nonBlobProperties: ['testProp'],
494495
metadata: true,
495496
},
496-
config: generativeConfigRuntime.openAI({
497+
config: generativeParameters.openAI({
497498
stop: ['\n'],
498499
}),
499500
});

src/collections/generate/mock.test.ts

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import express from 'express';
2+
import { Server as HttpServer } from 'http';
3+
import { Server as GrpcServer, createServer } from 'nice-grpc';
4+
import weaviate, { Collection, GenerativeConfigRuntime, WeaviateClient } from '../..';
5+
import {
6+
HealthCheckRequest,
7+
HealthCheckResponse,
8+
HealthCheckResponse_ServingStatus,
9+
HealthDefinition,
10+
HealthServiceImplementation,
11+
} from '../../proto/google/health/v1/health';
12+
import { GenerativeResult } from '../../proto/v1/generative';
13+
import { SearchReply, SearchRequest, SearchResult } from '../../proto/v1/search_get';
14+
import { WeaviateDefinition, WeaviateServiceImplementation } from '../../proto/v1/weaviate';
15+
import { generativeParameters } from './config';
16+
17+
const mockedSingleGenerative = 'Mocked single response';
18+
const mockedGroupedGenerative = 'Mocked group response';
19+
20+
class GenerateMock {
21+
private grpc: GrpcServer;
22+
private http: HttpServer;
23+
24+
constructor(grpc: GrpcServer, http: HttpServer) {
25+
this.grpc = grpc;
26+
this.http = http;
27+
}
28+
29+
public static use = async (version: string, httpPort: number, grpcPort: number) => {
30+
const httpApp = express();
31+
// Meta endpoint required for client instantiation
32+
httpApp.get('/v1/meta', (req, res) => res.send({ version }));
33+
34+
// gRPC health check required for client instantiation
35+
const healthMockImpl: HealthServiceImplementation = {
36+
check: (request: HealthCheckRequest): Promise<HealthCheckResponse> =>
37+
Promise.resolve(HealthCheckResponse.create({ status: HealthCheckResponse_ServingStatus.SERVING })),
38+
watch: jest.fn(),
39+
};
40+
41+
const grpc = createServer();
42+
grpc.add(HealthDefinition, healthMockImpl);
43+
44+
// Search endpoint returning generative mock data
45+
const weaviateMockImpl: WeaviateServiceImplementation = {
46+
aggregate: jest.fn(),
47+
tenantsGet: jest.fn(),
48+
search: (req: SearchRequest): Promise<SearchReply> => {
49+
expect(req.generative?.grouped?.queries.length).toBeGreaterThan(0);
50+
expect(req.generative?.single?.queries.length).toBeGreaterThan(0);
51+
return Promise.resolve(
52+
SearchReply.fromPartial({
53+
results: [
54+
SearchResult.fromPartial({
55+
properties: {
56+
nonRefProps: { fields: { name: { textValue: 'thing' } } },
57+
},
58+
generative: GenerativeResult.fromPartial({
59+
values: [
60+
{
61+
result: mockedSingleGenerative,
62+
},
63+
],
64+
}),
65+
metadata: {
66+
id: 'b602a271-d5a9-4324-921d-5abe4748d6b5',
67+
},
68+
}),
69+
],
70+
generativeGroupedResults: GenerativeResult.fromPartial({
71+
values: [
72+
{
73+
result: mockedGroupedGenerative,
74+
},
75+
],
76+
}),
77+
})
78+
);
79+
},
80+
batchDelete: jest.fn(),
81+
batchObjects: jest.fn(),
82+
};
83+
grpc.add(WeaviateDefinition, weaviateMockImpl);
84+
85+
await grpc.listen(`localhost:${grpcPort}`);
86+
const http = await httpApp.listen(httpPort);
87+
return new GenerateMock(grpc, http);
88+
};
89+
90+
public close = () => Promise.all([this.http.close(), this.grpc.shutdown()]);
91+
}
92+
93+
describe('Mock testing of generate with runtime config', () => {
94+
let client: WeaviateClient;
95+
let collection: Collection;
96+
let mock: GenerateMock;
97+
98+
beforeAll(async () => {
99+
mock = await GenerateMock.use('1.30.0-rc.1', 8958, 8959);
100+
client = await weaviate.connectToLocal({ port: 8958, grpcPort: 8959 });
101+
collection = client.collections.use('Whatever');
102+
});
103+
104+
afterAll(() => mock.close());
105+
106+
const stringTest = (config: GenerativeConfigRuntime) =>
107+
collection.generate
108+
.fetchObjects({
109+
singlePrompt: 'What is the meaning of life?',
110+
groupedTask: 'What is the meaning of life?',
111+
config: config,
112+
})
113+
.then((res) => {
114+
expect(res.generative?.text).toEqual(mockedGroupedGenerative);
115+
expect(res.objects[0].generative?.text).toEqual(mockedSingleGenerative);
116+
});
117+
118+
const objectTest = (config: GenerativeConfigRuntime) =>
119+
collection.generate
120+
.fetchObjects({
121+
singlePrompt: {
122+
prompt: 'What is the meaning of life?',
123+
},
124+
groupedTask: {
125+
prompt: 'What is the meaning of life?',
126+
},
127+
config: config,
128+
})
129+
.then((res) => {
130+
expect(res.generative?.text).toEqual(mockedGroupedGenerative);
131+
expect(res.objects[0].generative?.text).toEqual(mockedSingleGenerative);
132+
});
133+
134+
const model = { model: 'llama-2' };
135+
136+
const tests: GenerativeConfigRuntime[] = [
137+
generativeParameters.anthropic(),
138+
generativeParameters.anthropic(model),
139+
generativeParameters.anyscale(),
140+
generativeParameters.anyscale(model),
141+
generativeParameters.aws(),
142+
generativeParameters.aws(model),
143+
generativeParameters.azureOpenAI(),
144+
generativeParameters.azureOpenAI(model),
145+
generativeParameters.cohere(),
146+
generativeParameters.cohere(model),
147+
generativeParameters.databricks(),
148+
generativeParameters.databricks(model),
149+
generativeParameters.friendliai(),
150+
generativeParameters.friendliai(model),
151+
generativeParameters.google(),
152+
generativeParameters.google(model),
153+
generativeParameters.mistral(),
154+
generativeParameters.mistral(model),
155+
generativeParameters.nvidia(),
156+
generativeParameters.nvidia(model),
157+
generativeParameters.ollama(),
158+
generativeParameters.ollama(model),
159+
generativeParameters.openAI(),
160+
generativeParameters.openAI(model),
161+
];
162+
163+
tests.forEach((conf) => {
164+
it(`should get the mocked response for ${conf.name} with config: ${conf.config}`, async () => {
165+
await stringTest(conf);
166+
await objectTest(conf);
167+
});
168+
});
169+
});

src/collections/generate/unit.test.ts

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import { GenerativeConfigRuntimeType, ModuleConfig } from '../types';
2-
import { generativeConfigRuntime } from './config';
2+
import { generativeParameters } from './config';
33

44
// only tests fields that must be mapped from some public name to a gRPC name, e.g. baseURL -> baseUrl and stop: string[] -> stop: TextArray
5-
describe('Unit testing of the generativeConfigRuntime factory methods', () => {
5+
describe('Unit testing of the generativeParameters factory methods', () => {
66
describe('anthropic', () => {
77
it('with defaults', () => {
8-
const config = generativeConfigRuntime.anthropic();
8+
const config = generativeParameters.anthropic();
99
expect(config).toEqual<
1010
ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined>
1111
>({
@@ -14,7 +14,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
1414
});
1515
});
1616
it('with values', () => {
17-
const config = generativeConfigRuntime.anthropic({
17+
const config = generativeParameters.anthropic({
1818
baseURL: 'http://localhost:8080',
1919
stopSequences: ['a', 'b', 'c'],
2020
});
@@ -32,7 +32,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
3232

3333
describe('anyscale', () => {
3434
it('with defaults', () => {
35-
const config = generativeConfigRuntime.anyscale();
35+
const config = generativeParameters.anyscale();
3636
expect(config).toEqual<
3737
ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined>
3838
>({
@@ -41,7 +41,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
4141
});
4242
});
4343
it('with values', () => {
44-
const config = generativeConfigRuntime.anyscale({
44+
const config = generativeParameters.anyscale({
4545
baseURL: 'http://localhost:8080',
4646
});
4747
expect(config).toEqual<
@@ -57,7 +57,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
5757

5858
describe('aws', () => {
5959
it('with defaults', () => {
60-
const config = generativeConfigRuntime.aws();
60+
const config = generativeParameters.aws();
6161
expect(config).toEqual<
6262
ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined>
6363
>({
@@ -69,7 +69,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
6969

7070
describe('azure-openai', () => {
7171
it('with defaults', () => {
72-
const config = generativeConfigRuntime.azureOpenAI();
72+
const config = generativeParameters.azureOpenAI();
7373
expect(config).toEqual<
7474
ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>>
7575
>({
@@ -78,7 +78,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
7878
});
7979
});
8080
it('with values', () => {
81-
const config = generativeConfigRuntime.azureOpenAI({
81+
const config = generativeParameters.azureOpenAI({
8282
baseURL: 'http://localhost:8080',
8383
model: 'model',
8484
stop: ['a', 'b', 'c'],
@@ -99,7 +99,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
9999

100100
describe('cohere', () => {
101101
it('with defaults', () => {
102-
const config = generativeConfigRuntime.cohere();
102+
const config = generativeParameters.cohere();
103103
expect(config).toEqual<
104104
ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined>
105105
>({
@@ -108,7 +108,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
108108
});
109109
});
110110
it('with values', () => {
111-
const config = generativeConfigRuntime.cohere({
111+
const config = generativeParameters.cohere({
112112
baseURL: 'http://localhost:8080',
113113
stopSequences: ['a', 'b', 'c'],
114114
});
@@ -126,7 +126,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
126126

127127
describe('databricks', () => {
128128
it('with defaults', () => {
129-
const config = generativeConfigRuntime.databricks();
129+
const config = generativeParameters.databricks();
130130
expect(config).toEqual<
131131
ModuleConfig<
132132
'generative-databricks',
@@ -138,7 +138,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
138138
});
139139
});
140140
it('with values', () => {
141-
const config = generativeConfigRuntime.databricks({
141+
const config = generativeParameters.databricks({
142142
stop: ['a', 'b', 'c'],
143143
});
144144
expect(config).toEqual<
@@ -157,7 +157,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
157157

158158
describe('friendliai', () => {
159159
it('with defaults', () => {
160-
const config = generativeConfigRuntime.friendliai();
160+
const config = generativeParameters.friendliai();
161161
expect(config).toEqual<
162162
ModuleConfig<
163163
'generative-friendliai',
@@ -169,7 +169,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
169169
});
170170
});
171171
it('with values', () => {
172-
const config = generativeConfigRuntime.friendliai({
172+
const config = generativeParameters.friendliai({
173173
baseURL: 'http://localhost:8080',
174174
});
175175
expect(config).toEqual<
@@ -188,7 +188,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
188188

189189
describe('mistral', () => {
190190
it('with defaults', () => {
191-
const config = generativeConfigRuntime.mistral();
191+
const config = generativeParameters.mistral();
192192
expect(config).toEqual<
193193
ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined>
194194
>({
@@ -197,7 +197,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
197197
});
198198
});
199199
it('with values', () => {
200-
const config = generativeConfigRuntime.mistral({
200+
const config = generativeParameters.mistral({
201201
baseURL: 'http://localhost:8080',
202202
});
203203
expect(config).toEqual<
@@ -213,7 +213,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
213213

214214
describe('nvidia', () => {
215215
it('with defaults', () => {
216-
const config = generativeConfigRuntime.nvidia();
216+
const config = generativeParameters.nvidia();
217217
expect(config).toEqual<
218218
ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined>
219219
>({
@@ -222,7 +222,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
222222
});
223223
});
224224
it('with values', () => {
225-
const config = generativeConfigRuntime.nvidia({
225+
const config = generativeParameters.nvidia({
226226
baseURL: 'http://localhost:8080',
227227
});
228228
expect(config).toEqual<
@@ -238,7 +238,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
238238

239239
describe('ollama', () => {
240240
it('with defaults', () => {
241-
const config = generativeConfigRuntime.ollama();
241+
const config = generativeParameters.ollama();
242242
expect(config).toEqual<
243243
ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined>
244244
>({
@@ -250,7 +250,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
250250

251251
describe('openai', () => {
252252
it('with defaults', () => {
253-
const config = generativeConfigRuntime.openAI();
253+
const config = generativeParameters.openAI();
254254
expect(config).toEqual<
255255
ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>
256256
>({
@@ -259,7 +259,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => {
259259
});
260260
});
261261
it('with values', () => {
262-
const config = generativeConfigRuntime.openAI({
262+
const config = generativeParameters.openAI({
263263
baseURL: 'http://localhost:8080',
264264
model: 'model',
265265
stop: ['a', 'b', 'c'],

0 commit comments

Comments
 (0)