diff --git a/src/collections/generate/config.ts b/src/collections/generate/config.ts index 1ab12b04..2f5eb714 100644 --- a/src/collections/generate/config.ts +++ b/src/collections/generate/config.ts @@ -15,7 +15,7 @@ import { GenerativeOpenAIConfigRuntime, } from '../index.js'; -export const generativeConfigRuntime = { +export const generativeParameters = { /** * Create a `ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-anthropic` module. * diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index bf010fd9..3a5fb1b3 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -396,5 +396,5 @@ class GenerateManager implements Generate { export default GenerateManager.use; -export { generativeConfigRuntime } from './config.js'; +export { generativeParameters } from './config.js'; export { Generate } from './types.js'; diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 50a7bed9..39e8e351 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -1,9 +1,10 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ import { WeaviateUnsupportedFeatureError } from '../../errors.js'; -import weaviate, { WeaviateClient, generativeConfigRuntime } from '../../index.js'; +import weaviate, { WeaviateClient } from '../../index.js'; import { Collection } from '../collection/index.js'; import { GenerateOptions, GroupByOptions } from '../types/index.js'; +import { generativeParameters } from './config.js'; const maybe = process.env.OPENAI_APIKEY ? describe : describe.skip; @@ -493,7 +494,7 @@ maybe('Testing of the collection.generate methods with runtime generative config nonBlobProperties: ['testProp'], metadata: true, }, - config: generativeConfigRuntime.openAI({ + config: generativeParameters.openAI({ stop: ['\n'], }), }); diff --git a/src/collections/generate/mock.test.ts b/src/collections/generate/mock.test.ts new file mode 100644 index 00000000..d4a0c025 --- /dev/null +++ b/src/collections/generate/mock.test.ts @@ -0,0 +1,169 @@ +import express from 'express'; +import { Server as HttpServer } from 'http'; +import { Server as GrpcServer, createServer } from 'nice-grpc'; +import weaviate, { Collection, GenerativeConfigRuntime, WeaviateClient } from '../..'; +import { + HealthCheckRequest, + HealthCheckResponse, + HealthCheckResponse_ServingStatus, + HealthDefinition, + HealthServiceImplementation, +} from '../../proto/google/health/v1/health'; +import { GenerativeResult } from '../../proto/v1/generative'; +import { SearchReply, SearchRequest, SearchResult } from '../../proto/v1/search_get'; +import { WeaviateDefinition, WeaviateServiceImplementation } from '../../proto/v1/weaviate'; +import { generativeParameters } from './config'; + +const mockedSingleGenerative = 'Mocked single response'; +const mockedGroupedGenerative = 'Mocked group response'; + +class GenerateMock { + private grpc: GrpcServer; + private http: HttpServer; + + constructor(grpc: GrpcServer, http: HttpServer) { + this.grpc = grpc; + this.http = http; + } + + public static use = async (version: string, httpPort: number, grpcPort: number) => { + const httpApp = express(); + // Meta endpoint required for client instantiation + httpApp.get('/v1/meta', (req, res) => res.send({ version })); + + // gRPC health check required for client instantiation + const healthMockImpl: HealthServiceImplementation = { + check: (request: HealthCheckRequest): Promise => + Promise.resolve(HealthCheckResponse.create({ status: HealthCheckResponse_ServingStatus.SERVING })), + watch: jest.fn(), + }; + + const grpc = createServer(); + grpc.add(HealthDefinition, healthMockImpl); + + // Search endpoint returning generative mock data + const weaviateMockImpl: WeaviateServiceImplementation = { + aggregate: jest.fn(), + tenantsGet: jest.fn(), + search: (req: SearchRequest): Promise => { + expect(req.generative?.grouped?.queries.length).toBeGreaterThan(0); + expect(req.generative?.single?.queries.length).toBeGreaterThan(0); + return Promise.resolve( + SearchReply.fromPartial({ + results: [ + SearchResult.fromPartial({ + properties: { + nonRefProps: { fields: { name: { textValue: 'thing' } } }, + }, + generative: GenerativeResult.fromPartial({ + values: [ + { + result: mockedSingleGenerative, + }, + ], + }), + metadata: { + id: 'b602a271-d5a9-4324-921d-5abe4748d6b5', + }, + }), + ], + generativeGroupedResults: GenerativeResult.fromPartial({ + values: [ + { + result: mockedGroupedGenerative, + }, + ], + }), + }) + ); + }, + batchDelete: jest.fn(), + batchObjects: jest.fn(), + }; + grpc.add(WeaviateDefinition, weaviateMockImpl); + + await grpc.listen(`localhost:${grpcPort}`); + const http = await httpApp.listen(httpPort); + return new GenerateMock(grpc, http); + }; + + public close = () => Promise.all([this.http.close(), this.grpc.shutdown()]); +} + +describe('Mock testing of generate with runtime config', () => { + let client: WeaviateClient; + let collection: Collection; + let mock: GenerateMock; + + beforeAll(async () => { + mock = await GenerateMock.use('1.30.0-rc.1', 8958, 8959); + client = await weaviate.connectToLocal({ port: 8958, grpcPort: 8959 }); + collection = client.collections.use('Whatever'); + }); + + afterAll(() => mock.close()); + + const stringTest = (config: GenerativeConfigRuntime) => + collection.generate + .fetchObjects({ + singlePrompt: 'What is the meaning of life?', + groupedTask: 'What is the meaning of life?', + config: config, + }) + .then((res) => { + expect(res.generative?.text).toEqual(mockedGroupedGenerative); + expect(res.objects[0].generative?.text).toEqual(mockedSingleGenerative); + }); + + const objectTest = (config: GenerativeConfigRuntime) => + collection.generate + .fetchObjects({ + singlePrompt: { + prompt: 'What is the meaning of life?', + }, + groupedTask: { + prompt: 'What is the meaning of life?', + }, + config: config, + }) + .then((res) => { + expect(res.generative?.text).toEqual(mockedGroupedGenerative); + expect(res.objects[0].generative?.text).toEqual(mockedSingleGenerative); + }); + + const model = { model: 'llama-2' }; + + const tests: GenerativeConfigRuntime[] = [ + generativeParameters.anthropic(), + generativeParameters.anthropic(model), + generativeParameters.anyscale(), + generativeParameters.anyscale(model), + generativeParameters.aws(), + generativeParameters.aws(model), + generativeParameters.azureOpenAI(), + generativeParameters.azureOpenAI(model), + generativeParameters.cohere(), + generativeParameters.cohere(model), + generativeParameters.databricks(), + generativeParameters.databricks(model), + generativeParameters.friendliai(), + generativeParameters.friendliai(model), + generativeParameters.google(), + generativeParameters.google(model), + generativeParameters.mistral(), + generativeParameters.mistral(model), + generativeParameters.nvidia(), + generativeParameters.nvidia(model), + generativeParameters.ollama(), + generativeParameters.ollama(model), + generativeParameters.openAI(), + generativeParameters.openAI(model), + ]; + + tests.forEach((conf) => { + it(`should get the mocked response for ${conf.name} with config: ${conf.config}`, async () => { + await stringTest(conf); + await objectTest(conf); + }); + }); +}); diff --git a/src/collections/generate/unit.test.ts b/src/collections/generate/unit.test.ts index 63ff17b9..54ecb335 100644 --- a/src/collections/generate/unit.test.ts +++ b/src/collections/generate/unit.test.ts @@ -1,11 +1,11 @@ import { GenerativeConfigRuntimeType, ModuleConfig } from '../types'; -import { generativeConfigRuntime } from './config'; +import { generativeParameters } from './config'; // only tests fields that must be mapped from some public name to a gRPC name, e.g. baseURL -> baseUrl and stop: string[] -> stop: TextArray -describe('Unit testing of the generativeConfigRuntime factory methods', () => { +describe('Unit testing of the generativeParameters factory methods', () => { describe('anthropic', () => { it('with defaults', () => { - const config = generativeConfigRuntime.anthropic(); + const config = generativeParameters.anthropic(); expect(config).toEqual< ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> >({ @@ -14,7 +14,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.anthropic({ + const config = generativeParameters.anthropic({ baseURL: 'http://localhost:8080', stopSequences: ['a', 'b', 'c'], }); @@ -32,7 +32,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('anyscale', () => { it('with defaults', () => { - const config = generativeConfigRuntime.anyscale(); + const config = generativeParameters.anyscale(); expect(config).toEqual< ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> >({ @@ -41,7 +41,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.anyscale({ + const config = generativeParameters.anyscale({ baseURL: 'http://localhost:8080', }); expect(config).toEqual< @@ -57,7 +57,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('aws', () => { it('with defaults', () => { - const config = generativeConfigRuntime.aws(); + const config = generativeParameters.aws(); expect(config).toEqual< ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> >({ @@ -69,7 +69,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('azure-openai', () => { it('with defaults', () => { - const config = generativeConfigRuntime.azureOpenAI(); + const config = generativeParameters.azureOpenAI(); expect(config).toEqual< ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> >({ @@ -78,7 +78,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.azureOpenAI({ + const config = generativeParameters.azureOpenAI({ baseURL: 'http://localhost:8080', model: 'model', stop: ['a', 'b', 'c'], @@ -99,7 +99,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('cohere', () => { it('with defaults', () => { - const config = generativeConfigRuntime.cohere(); + const config = generativeParameters.cohere(); expect(config).toEqual< ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> >({ @@ -108,7 +108,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.cohere({ + const config = generativeParameters.cohere({ baseURL: 'http://localhost:8080', stopSequences: ['a', 'b', 'c'], }); @@ -126,7 +126,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('databricks', () => { it('with defaults', () => { - const config = generativeConfigRuntime.databricks(); + const config = generativeParameters.databricks(); expect(config).toEqual< ModuleConfig< 'generative-databricks', @@ -138,7 +138,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.databricks({ + const config = generativeParameters.databricks({ stop: ['a', 'b', 'c'], }); expect(config).toEqual< @@ -157,7 +157,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('friendliai', () => { it('with defaults', () => { - const config = generativeConfigRuntime.friendliai(); + const config = generativeParameters.friendliai(); expect(config).toEqual< ModuleConfig< 'generative-friendliai', @@ -169,7 +169,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.friendliai({ + const config = generativeParameters.friendliai({ baseURL: 'http://localhost:8080', }); expect(config).toEqual< @@ -188,7 +188,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('mistral', () => { it('with defaults', () => { - const config = generativeConfigRuntime.mistral(); + const config = generativeParameters.mistral(); expect(config).toEqual< ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> >({ @@ -197,7 +197,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.mistral({ + const config = generativeParameters.mistral({ baseURL: 'http://localhost:8080', }); expect(config).toEqual< @@ -213,7 +213,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('nvidia', () => { it('with defaults', () => { - const config = generativeConfigRuntime.nvidia(); + const config = generativeParameters.nvidia(); expect(config).toEqual< ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> >({ @@ -222,7 +222,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.nvidia({ + const config = generativeParameters.nvidia({ baseURL: 'http://localhost:8080', }); expect(config).toEqual< @@ -238,7 +238,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('ollama', () => { it('with defaults', () => { - const config = generativeConfigRuntime.ollama(); + const config = generativeParameters.ollama(); expect(config).toEqual< ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> >({ @@ -250,7 +250,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { describe('openai', () => { it('with defaults', () => { - const config = generativeConfigRuntime.openAI(); + const config = generativeParameters.openAI(); expect(config).toEqual< ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> >({ @@ -259,7 +259,7 @@ describe('Unit testing of the generativeConfigRuntime factory methods', () => { }); }); it('with values', () => { - const config = generativeConfigRuntime.openAI({ + const config = generativeParameters.openAI({ baseURL: 'http://localhost:8080', model: 'model', stop: ['a', 'b', 'c'], diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index 3bb06477..e9a74973 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -852,40 +852,40 @@ export class Serialize { const provider = GenerativeProvider.fromPartial({ returnMetadata: opts?.metadata }); switch (generative.name) { case 'generative-anthropic': - provider.anthropic = await withImages(generative.config, opts?.images, opts?.imageProperties); + provider.anthropic = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); break; case 'generative-anyscale': - provider.anyscale = generative.config; + provider.anyscale = generative.config || {}; break; case 'generative-aws': - provider.aws = await withImages(generative.config, opts?.images, opts?.imageProperties); + provider.aws = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); break; case 'generative-cohere': - provider.cohere = generative.config; + provider.cohere = generative.config || {}; break; case 'generative-databricks': - provider.databricks = generative.config; + provider.databricks = generative.config || {}; break; case 'generative-dummy': - provider.dummy = generative.config; + provider.dummy = generative.config || {}; break; case 'generative-friendliai': - provider.friendliai = generative.config; + provider.friendliai = generative.config || {}; break; case 'generative-google': - provider.google = await withImages(generative.config, opts?.images, opts?.imageProperties); + provider.google = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); break; case 'generative-mistral': - provider.mistral = generative.config; + provider.mistral = generative.config || {}; break; case 'generative-nvidia': - provider.nvidia = generative.config; + provider.nvidia = generative.config || {}; break; case 'generative-ollama': - provider.ollama = await withImages(generative.config, opts?.images, opts?.imageProperties); + provider.ollama = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); break; case 'generative-openai': - provider.openai = await withImages(generative.config, opts?.images, opts?.imageProperties); + provider.openai = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); break; } return provider; diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index 53bf208c..edd16e71 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -122,18 +122,18 @@ export type GroupedTask = { type omitFields = 'images' | 'imageProperties'; export type GenerativeConfigRuntime = - | ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'>> - | ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'>> - | ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'>> + | ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> + | ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> + | ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> | ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> - | ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'>> - | ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'>> - | ModuleConfig<'generative-dummy', GenerativeConfigRuntimeType<'generative-dummy'>> - | ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'>> - | ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'>> - | ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'>> - | ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'>> - | ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'>> + | ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> + | ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'> | undefined> + | ModuleConfig<'generative-dummy', GenerativeConfigRuntimeType<'generative-dummy'> | undefined> + | ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined> + | ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'> | undefined> + | ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> + | ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> + | ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> | ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>; export type GenerativeConfigRuntimeType = G extends 'generative-anthropic'