From b88806b7125d6f0e29a5450a9a5a9382275b539b Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 14 Apr 2025 14:31:44 +0300 Subject: [PATCH 1/2] Add support for: - `reranker-nvidia` when configuring collections - missing `baseURL` options in certain modules - unit tests for reranker configuration --- src/collections/config/types/generative.ts | 2 + src/collections/config/types/reranker.ts | 10 ++ src/collections/config/types/vectorizer.ts | 2 + src/collections/configure/reranker.ts | 17 ++++ src/collections/configure/unit.test.ts | 105 +++++++++++++++++++++ 5 files changed, 136 insertions(+) diff --git a/src/collections/config/types/generative.ts b/src/collections/config/types/generative.ts index 8e32d634..cd8811e3 100644 --- a/src/collections/config/types/generative.ts +++ b/src/collections/config/types/generative.ts @@ -25,6 +25,7 @@ export type GenerativeAnthropicConfig = { }; export type GenerativeAnyscaleConfig = { + baseURL?: string; model?: string; temperature?: number; }; @@ -54,6 +55,7 @@ export type GenerativeFriendliAIConfig = { }; export type GenerativeMistralConfig = { + baseURL?: string; maxTokens?: number; model?: string; temperature?: number; diff --git a/src/collections/config/types/reranker.ts b/src/collections/config/types/reranker.ts index f1cd64de..2357e490 100644 --- a/src/collections/config/types/reranker.ts +++ b/src/collections/config/types/reranker.ts @@ -5,6 +5,7 @@ export type RerankerCohereConfig = { }; export type RerankerVoyageAIConfig = { + baseURL?: string; model?: 'rerank-lite-1' | string; }; @@ -18,9 +19,15 @@ export type RerankerJinaAIConfig = { | string; }; +export type RerankerNvidiaConfig = { + baseURL?: string; + model?: 'nvidia/rerank-qa-mistral-4b' | string; +}; + export type RerankerConfig = | RerankerCohereConfig | RerankerJinaAIConfig + | RerankerNvidiaConfig | RerankerTransformersConfig | RerankerVoyageAIConfig | Record @@ -29,6 +36,7 @@ export type RerankerConfig = export type Reranker = | 'reranker-cohere' | 'reranker-jinaai' + | 'reranker-nvidia' | 'reranker-transformers' | 'reranker-voyageai' | 'none' @@ -38,6 +46,8 @@ export type RerankerConfigType = R extends 'reranker-cohere' ? RerankerCohereConfig : R extends 'reranker-jinaai' ? RerankerJinaAIConfig + : R extends 'reranker-nvidia' + ? RerankerNvidiaConfig : R extends 'reranker-transformers' ? RerankerTransformersConfig : R extends 'reranker-voyageai' diff --git a/src/collections/config/types/vectorizer.ts b/src/collections/config/types/vectorizer.ts index 99dd41f8..6beed48e 100644 --- a/src/collections/config/types/vectorizer.ts +++ b/src/collections/config/types/vectorizer.ts @@ -390,6 +390,8 @@ export type Text2VecNvidiaConfig = { * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/mistral/embeddings) for detailed usage. */ export type Text2VecMistralConfig = { + /** The base URL to use where API requests should go. */ + baseURL?: string; /** The model to use. */ model?: 'mistral-embed' | string; /** Whether to vectorize the collection name. */ diff --git a/src/collections/configure/reranker.ts b/src/collections/configure/reranker.ts index 0070e3ed..3c750866 100644 --- a/src/collections/configure/reranker.ts +++ b/src/collections/configure/reranker.ts @@ -2,6 +2,7 @@ import { ModuleConfig, RerankerCohereConfig, RerankerJinaAIConfig, + RerankerNvidiaConfig, RerankerVoyageAIConfig, } from '../config/types/index.js'; @@ -38,6 +39,22 @@ export default { config: config, }; }, + /** + * Create a `ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig>` object for use when reranking using the `reranker-nvidia` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/reranker) for detailed usage. + * + * @param {RerankerNvidiaConfig} [config] The configuration for the `reranker-nvidia` module. + * @returns {ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig | undefined>} The configuration object. + */ + nvidia: ( + config?: RerankerNvidiaConfig + ): ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig | undefined> => { + return { + name: 'reranker-nvidia', + config: config, + }; + }, /** * Create a `ModuleConfig<'reranker-transformers', Record>` object for use when reranking using the `reranker-transformers` module. * diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index e75cd518..8f34fa6b 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -12,6 +12,11 @@ import { GenerativeOpenAIConfig, GenerativeXAIConfig, ModuleConfig, + RerankerCohereConfig, + RerankerJinaAIConfig, + RerankerNvidiaConfig, + RerankerTransformersConfig, + RerankerVoyageAIConfig, VectorConfigCreate, } from '../types/index.js'; import { configure } from './index.js'; @@ -1220,6 +1225,7 @@ describe('Unit testing of the vectorizer factory class', () => { it('should create the correct Text2VecMistralConfig type with all values', () => { const config = configure.vectorizer.text2VecMistral({ + baseURL: 'base-url', name: 'test', model: 'model', vectorizeCollectionName: true, @@ -1233,6 +1239,7 @@ describe('Unit testing of the vectorizer factory class', () => { vectorizer: { name: 'text2vec-mistral', config: { + baseURL: 'base-url', model: 'model', vectorizeCollectionName: true, }, @@ -1567,12 +1574,14 @@ describe('Unit testing of the generative factory class', () => { it('should create the correct GenerativeAnyscaleConfig type with all values', () => { const config = configure.generative.anyscale({ + baseURL: 'base-url', model: 'model', temperature: 0.5, }); expect(config).toEqual>({ name: 'generative-anyscale', config: { + baseURL: 'base-url', model: 'model', temperature: 0.5, }, @@ -1749,6 +1758,7 @@ describe('Unit testing of the generative factory class', () => { it('should create the correct GenerativeMistralConfig type with all values', () => { const config = configure.generative.mistral({ + baseURL: 'base-url', maxTokens: 100, model: 'model', temperature: 0.5, @@ -1756,6 +1766,7 @@ describe('Unit testing of the generative factory class', () => { expect(config).toEqual>({ name: 'generative-mistral', config: { + baseURL: 'base-url', maxTokens: 100, model: 'model', temperature: 0.5, @@ -1909,3 +1920,97 @@ describe('Unit testing of the generative factory class', () => { }); }); }); + +describe('Unit testing of the reranker factory class', () => { + it('should create the correct RerankerCohereConfig type using required & default values', () => { + const config = configure.reranker.cohere(); + expect(config).toEqual>({ + name: 'reranker-cohere', + config: undefined, + }); + }); + + it('should create the correct RerankerCohereConfig type with all values', () => { + const config = configure.reranker.cohere({ + model: 'model', + }); + expect(config).toEqual>({ + name: 'reranker-cohere', + config: { + model: 'model', + }, + }); + }); + + it('should create the correct RerankerJinaAIConfig type using required & default values', () => { + const config = configure.reranker.jinaai(); + expect(config).toEqual>({ + name: 'reranker-jinaai', + config: undefined, + }); + }); + + it('should create the correct RerankerJinaAIConfig type with all values', () => { + const config = configure.reranker.jinaai({ + model: 'model', + }); + expect(config).toEqual>({ + name: 'reranker-jinaai', + config: { + model: 'model', + }, + }); + }); + + it('should create the correct RerankerNvidiaConfig type with required & default values', () => { + const config = configure.reranker.nvidia(); + expect(config).toEqual>({ + name: 'reranker-nvidia', + config: undefined, + }); + }); + + it('should create the correct RerankerNvidiaConfig type with all values', () => { + const config = configure.reranker.nvidia({ + baseURL: 'base-url', + model: 'model', + }); + expect(config).toEqual>({ + name: 'reranker-nvidia', + config: { + baseURL: 'base-url', + model: 'model', + }, + }); + }); + + it('should create the correct RerankerTransformersConfig type using required & default values', () => { + const config = configure.reranker.transformers(); + expect(config).toEqual>({ + name: 'reranker-transformers', + config: undefined, + }); + }); + + it('should create the correct RerankerVoyageAIConfig with required & default values', () => { + const config = configure.reranker.voyageAI(); + expect(config).toEqual>({ + name: 'reranker-voyageai', + config: undefined, + }); + }); + + it('should create the correct RerankerVoyageAIConfig type with all values', () => { + const config = configure.reranker.voyageAI({ + baseURL: 'base-url', + model: 'model', + }); + expect(config).toEqual>({ + name: 'reranker-voyageai', + config: { + baseURL: 'base-url', + model: 'model', + }, + }); + }); +}); From 626b5a993e841aa9580a94a786666c7eccafd72c Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 14 Apr 2025 14:40:51 +0300 Subject: [PATCH 2/2] Fix transformers unit test --- src/collections/configure/unit.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index 8f34fa6b..0eddb278 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -1986,9 +1986,9 @@ describe('Unit testing of the reranker factory class', () => { it('should create the correct RerankerTransformersConfig type using required & default values', () => { const config = configure.reranker.transformers(); - expect(config).toEqual>({ + expect(config).toEqual>({ name: 'reranker-transformers', - config: undefined, + config: {}, }); });