diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index b30fa219..4db465ef 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -8,11 +8,16 @@ on: env: WEAVIATE_124: 1.24.26 - WEAVIATE_125: 1.25.30 - WEAVIATE_126: 1.26.14 - WEAVIATE_127: 1.27.11 - WEAVIATE_128: 1.28.4 - WEAVIATE_129: 1.29.0 + WEAVIATE_125: 1.25.34 + WEAVIATE_126: 1.26.17 + WEAVIATE_127: 1.27.15 + WEAVIATE_128: 1.28.11 + WEAVIATE_129: 1.29.1 + WEAVIATE_130: 1.30.0-rc.0-6b9a01c + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: checks: @@ -41,9 +46,10 @@ jobs: { node: "22.x", weaviate: $WEAVIATE_126}, { node: "22.x", weaviate: $WEAVIATE_127}, { node: "22.x", weaviate: $WEAVIATE_128}, - { node: "18.x", weaviate: $WEAVIATE_129}, - { node: "20.x", weaviate: $WEAVIATE_129}, - { node: "22.x", weaviate: $WEAVIATE_129} + { node: "22.x", weaviate: $WEAVIATE_129}, + { node: "18.x", weaviate: $WEAVIATE_130}, + { node: "20.x", weaviate: $WEAVIATE_130}, + { node: "22.x", weaviate: $WEAVIATE_130} ] steps: - uses: actions/checkout@v3 @@ -74,7 +80,7 @@ jobs: fail-fast: false matrix: versions: [ - { node: "22.x", weaviate: $WEAVIATE_129} + { node: "22.x", weaviate: $WEAVIATE_130} ] steps: - uses: actions/checkout@v3 @@ -133,4 +139,4 @@ jobs: uses: softprops/action-gh-release@v1 with: generate_release_notes: true - draft: true \ No newline at end of file + draft: true diff --git a/ci/compose.sh b/ci/compose.sh old mode 100644 new mode 100755 diff --git a/ci/docker-compose-rbac.yml b/ci/docker-compose-rbac.yml index 57f2b13a..6091f498 100644 --- a/ci/docker-compose-rbac.yml +++ b/ci/docker-compose-rbac.yml @@ -28,4 +28,10 @@ services: AUTHORIZATION_RBAC_ENABLED: "true" AUTHORIZATION_ADMIN_USERS: "admin-user" AUTHORIZATION_VIEWER_USERS: "viewer-user" + AUTHENTICATION_DB_USERS_ENABLED: "true" + AUTHENTICATION_OIDC_ENABLED: "true" + AUTHENTICATION_OIDC_CLIENT_ID: "wcs" + AUTHENTICATION_OIDC_ISSUER: "https://auth.wcs.api.weaviate.io/auth/realms/SeMI" + AUTHENTICATION_OIDC_USERNAME_CLAIM: "email" + AUTHENTICATION_OIDC_GROUPS_CLAIM: "groups" ... diff --git a/package-lock.json b/package-lock.json index 892179b8..fe6c0ab7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "weaviate-client", - "version": "3.4.2", + "version": "3.5.0-beta.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "weaviate-client", - "version": "3.4.2", + "version": "3.5.0-beta.2", "license": "SEE LICENSE IN LICENSE", "dependencies": { "abort-controller-x": "^0.4.3", diff --git a/package.json b/package.json index 9d0656ff..fce830e7 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "weaviate-client", - "version": "3.4.2", + "version": "3.5.0-beta.2", "description": "JS/TS client for Weaviate", "main": "dist/node/cjs/index.js", "type": "module", diff --git a/src/collections/backup/unit.test.ts b/src/collections/backup/unit.test.ts index 8cf971a6..f5131d03 100644 --- a/src/collections/backup/unit.test.ts +++ b/src/collections/backup/unit.test.ts @@ -109,8 +109,8 @@ describe('Mock testing of backup cancellation', () => { let mock: CancelMock; beforeAll(async () => { - mock = await CancelMock.use('1.27.0', 8958, 8959); - client = await weaviate.connectToLocal({ port: 8958, grpcPort: 8959 }); + mock = await CancelMock.use('1.27.0', 8912, 8913); + client = await weaviate.connectToLocal({ port: 8912, grpcPort: 8913 }); }); it('should throw while waiting for creation if backup is cancelled in the meantime', async () => { @@ -133,7 +133,7 @@ describe('Mock testing of backup cancellation', () => { }); it('should return false if creation backup does not exist', async () => { - const success = await client.backup.cancel({ backupId: `${BACKUP_ID}4`, backend: BACKEND }); + const success = await client.backup.cancel({ backupId: `${BACKUP_ID}-unknown`, backend: BACKEND }); expect(success).toBe(false); }); diff --git a/src/collections/config/types/generative.ts b/src/collections/config/types/generative.ts index 667bc347..d275f201 100644 --- a/src/collections/config/types/generative.ts +++ b/src/collections/config/types/generative.ts @@ -15,6 +15,7 @@ export type GenerativeAWSConfig = { }; export type GenerativeAnthropicConfig = { + baseURL?: string; maxTokens?: number; model?: string; stopSequences?: string[]; @@ -58,6 +59,13 @@ export type GenerativeMistralConfig = { temperature?: number; }; +export type GenerativeNvidiaConfig = { + baseURL?: string; + maxTokens?: number; + model?: string; + temperature?: number; +}; + export type GenerativeOllamaConfig = { apiEndpoint?: string; model?: string; diff --git a/src/collections/config/types/vectorizer.ts b/src/collections/config/types/vectorizer.ts index 475c941d..99dd41f8 100644 --- a/src/collections/config/types/vectorizer.ts +++ b/src/collections/config/types/vectorizer.ts @@ -35,6 +35,7 @@ export type Vectorizer = | 'text2vec-gpt4all' | 'text2vec-huggingface' | 'text2vec-jinaai' + | 'text2vec-nvidia' | 'text2vec-mistral' | 'text2vec-ollama' | 'text2vec-openai' @@ -169,6 +170,8 @@ export type Multi2VecGoogleConfig = { textFields?: string[]; /** The video fields used when vectorizing. */ videoFields?: string[]; + /** Length of a video interval in seconds. */ + videoIntervalSeconds?: number; /** The model ID in use. */ modelId?: string; /** The dimensionality of the vector once embedded. */ @@ -223,6 +226,8 @@ export type Multi2VecVoyageAIConfig = { imageFields?: string[]; /** The model to use. */ model?: string; + /** How the output from the model should be encoded on return. */ + outputEncoding?: string; /** The text fields used when vectorizing. */ textFields?: string[]; /** Whether the input should be truncated to fit in the context window. */ @@ -363,6 +368,22 @@ export type Text2VecJinaAIConfig = { /** @deprecated Use `Text2VecJinaAIConfig` instead. */ export type Text2VecJinaConfig = Text2VecJinaAIConfig; +/** + * The configuration for text vectorization using the Nvidia module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/embeddings) for detailed usage. + */ +export type Text2VecNvidiaConfig = { + /** The base URL to use where API requests should go. */ + baseURL?: string; + /** The model to use. */ + model?: string; + /** Whether to truncate when vectorising. */ + truncate?: boolean; + /** Whether to vectorize the collection name. */ + vectorizeCollectionName?: boolean; +}; + /** * The configuration for text vectorization using the Mistral module. * @@ -541,6 +562,8 @@ export type VectorizerConfigType = V extends 'img2vec-neural' ? Text2VecHuggingFaceConfig | undefined : V extends 'text2vec-jinaai' ? Text2VecJinaAIConfig | undefined + : V extends 'text2vec-nvidia' + ? Text2VecNvidiaConfig | undefined : V extends 'text2vec-mistral' ? Text2VecMistralConfig | undefined : V extends 'text2vec-ollama' diff --git a/src/collections/configure/generative.ts b/src/collections/configure/generative.ts index 730f2bcb..d4ee3154 100644 --- a/src/collections/configure/generative.ts +++ b/src/collections/configure/generative.ts @@ -8,6 +8,7 @@ import { GenerativeFriendliAIConfig, GenerativeGoogleConfig, GenerativeMistralConfig, + GenerativeNvidiaConfig, GenerativeOllamaConfig, GenerativeOpenAIConfig, GenerativePaLMConfig, @@ -22,6 +23,7 @@ import { GenerativeDatabricksConfigCreate, GenerativeFriendliAIConfigCreate, GenerativeMistralConfigCreate, + GenerativeNvidiaConfigCreate, GenerativeOllamaConfigCreate, GenerativeOpenAIConfigCreate, GenerativePaLMConfigCreate, @@ -169,6 +171,22 @@ export default { config, }; }, + /** + * Create a `ModuleConfig<'generative-nvidia', GenerativeNvidiaConfig | undefined>` object for use when performing AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/generative) for detailed usage. + * + * @param {GenerativeNvidiaConfigCreate} [config] The configuration for the `generative-nvidia` module. + * @returns {ModuleConfig<'generative-nvidia', GenerativeNvidiaConfig | undefined>} The configuration object. + */ + nvidia( + config?: GenerativeNvidiaConfigCreate + ): ModuleConfig<'generative-nvidia', GenerativeNvidiaConfig | undefined> { + return { + name: 'generative-nvidia', + config, + }; + }, /** * Create a `ModuleConfig<'generative-ollama', GenerativeOllamaConfig | undefined>` object for use when performing AI generation using the `generative-ollama` module. * diff --git a/src/collections/configure/types/generative.ts b/src/collections/configure/types/generative.ts index 2b1a18cf..ccf22ec6 100644 --- a/src/collections/configure/types/generative.ts +++ b/src/collections/configure/types/generative.ts @@ -5,6 +5,7 @@ import { GenerativeDatabricksConfig, GenerativeFriendliAIConfig, GenerativeMistralConfig, + GenerativeNvidiaConfig, GenerativeOllamaConfig, GenerativePaLMConfig, } from '../../index.js'; @@ -44,6 +45,8 @@ export type GenerativeFriendliAIConfigCreate = GenerativeFriendliAIConfig; export type GenerativeMistralConfigCreate = GenerativeMistralConfig; +export type GenerativeNvidiaConfigCreate = GenerativeNvidiaConfig; + export type GenerativeOllamaConfigCreate = GenerativeOllamaConfig; export type GenerativeOpenAIConfigCreate = GenerativeOpenAIConfigBaseCreate & { @@ -61,6 +64,7 @@ export type GenerativeConfigCreate = | GenerativeDatabricksConfigCreate | GenerativeFriendliAIConfigCreate | GenerativeMistralConfigCreate + | GenerativeNvidiaConfigCreate | GenerativeOllamaConfigCreate | GenerativeOpenAIConfigCreate | GenerativePaLMConfigCreate @@ -81,6 +85,8 @@ export type GenerativeConfigCreateType = G extends 'generative-anthropic' ? GenerativeFriendliAIConfigCreate : G extends 'generative-mistral' ? GenerativeMistralConfigCreate + : G extends 'generative-nvidia' + ? GenerativeNvidiaConfigCreate : G extends 'generative-ollama' ? GenerativeOllamaConfigCreate : G extends 'generative-openai' diff --git a/src/collections/configure/types/vectorizer.ts b/src/collections/configure/types/vectorizer.ts index 5318e8f4..5391c356 100644 --- a/src/collections/configure/types/vectorizer.ts +++ b/src/collections/configure/types/vectorizer.ts @@ -13,6 +13,7 @@ import { Text2VecHuggingFaceConfig, Text2VecJinaAIConfig, Text2VecMistralConfig, + Text2VecNvidiaConfig, Text2VecOllamaConfig, Text2VecOpenAIConfig, Text2VecTransformersConfig, @@ -198,6 +199,8 @@ export type Text2VecHuggingFaceConfigCreate = Text2VecHuggingFaceConfig; export type Text2VecJinaAIConfigCreate = Text2VecJinaAIConfig; +export type Text2VecNvidiaConfigCreate = Text2VecNvidiaConfig; + export type Text2VecMistralConfigCreate = Text2VecMistralConfig; export type Text2VecOllamaConfigCreate = Text2VecOllamaConfig; @@ -247,6 +250,8 @@ export type VectorizerConfigCreateType = V extends 'img2vec-neural' ? Text2VecHuggingFaceConfigCreate | undefined : V extends 'text2vec-jinaai' ? Text2VecJinaAIConfigCreate | undefined + : V extends 'text2vec-nvidia' + ? Text2VecNvidiaConfigCreate | undefined : V extends 'text2vec-mistral' ? Text2VecMistralConfigCreate | undefined : V extends 'text2vec-ollama' diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index 93c50f65..36a9ce6b 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -1161,6 +1161,47 @@ describe('Unit testing of the vectorizer factory class', () => { }); }); + it('should create the correct Text2VecNvidiaConfig type with defaults', () => { + const config = configure.vectorizer.text2VecNvidia(); + expect(config).toEqual>({ + name: undefined, + vectorIndex: { + name: 'hnsw', + config: undefined, + }, + vectorizer: { + name: 'text2vec-nvidia', + config: undefined, + }, + }); + }); + + it('should create the correct Text2VecNvidiaConfig type with all values', () => { + const config = configure.vectorizer.text2VecNvidia({ + name: 'test', + baseURL: 'base-url', + model: 'model', + truncate: true, + vectorizeCollectionName: true, + }); + expect(config).toEqual>({ + name: 'test', + vectorIndex: { + name: 'hnsw', + config: undefined, + }, + vectorizer: { + name: 'text2vec-nvidia', + config: { + baseURL: 'base-url', + model: 'model', + truncate: true, + vectorizeCollectionName: true, + }, + }, + }); + }); + it('should create the correct Text2VecMistralConfig type with defaults', () => { const config = configure.vectorizer.text2VecMistral(); expect(config).toEqual>({ diff --git a/src/collections/configure/vectorizer.ts b/src/collections/configure/vectorizer.ts index c0442658..298b1f01 100644 --- a/src/collections/configure/vectorizer.ts +++ b/src/collections/configure/vectorizer.ts @@ -518,6 +518,19 @@ export const vectorizer = { }, }); }, + text2VecNvidia: ( + opts?: ConfigureTextVectorizerOptions + ): VectorConfigCreate, N, I, 'text2vec-nvidia'> => { + const { name, sourceProperties, vectorIndexConfig, ...config } = opts || {}; + return makeVectorizer(name, { + sourceProperties, + vectorIndexConfig, + vectorizerConfig: { + name: 'text2vec-nvidia', + config: Object.keys(config).length === 0 ? undefined : config, + }, + }); + }, /** * Create a `VectorConfigCreate` object with the vectorizer set to `'text2vec-mistral'`. * diff --git a/src/collections/deserialize/index.ts b/src/collections/deserialize/index.ts index 588c2642..fdf37f3d 100644 --- a/src/collections/deserialize/index.ts +++ b/src/collections/deserialize/index.ts @@ -25,6 +25,8 @@ import { AggregateResult, AggregateText, AggregateType, + GenerativeConfigRuntime, + GenerativeMetadata, PropertiesMetrics, } from '../index.js'; import { referenceFromObjects } from '../references/utils.js'; @@ -207,11 +209,28 @@ export class Deserialize { }; } - public generate(reply: SearchReply): GenerativeReturn { + public generate( + reply: SearchReply + ): GenerativeReturn { return { objects: reply.results.map((result) => { return { - generated: result.metadata?.generativePresent ? result.metadata?.generative : undefined, + generated: result.metadata?.generativePresent + ? result.metadata?.generative + : result.generative + ? result.generative.values[0].result + : undefined, + generative: result.generative + ? { + text: result.generative.values[0].result, + debug: result.generative.values[0].debug, + metadata: result.generative.values[0].metadata as GenerativeMetadata, + } + : result.metadata?.generativePresent + ? { + text: result.metadata?.generative, + } + : undefined, metadata: Deserialize.metadata(result.metadata), properties: this.properties(result.properties), references: this.references(result.properties), @@ -219,7 +238,22 @@ export class Deserialize { vectors: Deserialize.vectors(result.metadata), } as any; }), - generated: reply.generativeGroupedResult, + generated: + reply.generativeGroupedResult !== '' + ? reply.generativeGroupedResult + : reply.generativeGroupedResults + ? reply.generativeGroupedResults.values[0].result + : undefined, + generative: reply.generativeGroupedResults + ? { + text: reply.generativeGroupedResults?.values[0].result, + metadata: reply.generativeGroupedResults?.values[0].metadata as GenerativeMetadata, + } + : reply.generativeGroupedResult !== '' + ? { + text: reply.generativeGroupedResult, + } + : undefined, }; } @@ -252,9 +286,9 @@ export class Deserialize { }; } - public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { + public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { const objects: GroupByObject[] = []; - const groups: Record> = {}; + const groups: Record> = {}; reply.groupByResults.forEach((result) => { const objs = result.objects.map((object) => { return { diff --git a/src/collections/generate/config.ts b/src/collections/generate/config.ts new file mode 100644 index 00000000..2f5eb714 --- /dev/null +++ b/src/collections/generate/config.ts @@ -0,0 +1,280 @@ +import { TextArray } from '../../proto/v1/base.js'; +import { ModuleConfig } from '../config/types/index.js'; +import { + GenerativeAWSConfigRuntime, + GenerativeAnthropicConfigRuntime, + GenerativeAnyscaleConfigRuntime, + GenerativeCohereConfigRuntime, + GenerativeConfigRuntimeType, + GenerativeDatabricksConfigRuntime, + GenerativeFriendliAIConfigRuntime, + GenerativeGoogleConfigRuntime, + GenerativeMistralConfigRuntime, + GenerativeNvidiaConfigRuntime, + GenerativeOllamaConfigRuntime, + GenerativeOpenAIConfigRuntime, +} from '../index.js'; + +export const generativeParameters = { + /** + * Create a `ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-anthropic` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/anthropic/generative) for detailed usage. + * + * @param {GenerativeAnthropicConfigCreateRuntime} [config] The configuration for the `generative-anthropic` module. + * @returns {ModuleConfig<'generative-anthropic', GenerativeAnthropicConfigCreateRuntime | undefined>} The configuration object. + */ + anthropic( + config?: GenerativeAnthropicConfigRuntime + ): ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> { + const { baseURL, stopSequences, ...rest } = config || {}; + return { + name: 'generative-anthropic', + config: config + ? { + ...rest, + baseUrl: baseURL, + stopSequences: TextArray.fromPartial({ values: stopSequences }), + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-anyscale` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/anyscale/generative) for detailed usage. + * + * @param {GenerativeAnyscaleConfigRuntime} [config] The configuration for the `generative-aws` module. + * @returns {ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined>} The configuration object. + */ + anyscale( + config?: GenerativeAnyscaleConfigRuntime + ): ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-anyscale', + config: config + ? { + ...rest, + baseUrl: baseURL, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-aws` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/aws/generative) for detailed usage. + * + * @param {GenerativeAWSConfigRuntime} [config] The configuration for the `generative-aws` module. + * @returns {ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined>} The configuration object. + */ + aws( + config?: GenerativeAWSConfigRuntime + ): ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> { + return { + name: 'generative-aws', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>>` object for use when performing runtime-specific AI generation using the `generative-openai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/openai/generative) for detailed usage. + * + * @param {GenerativeAzureOpenAIConfigRuntime} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>>} The configuration object. + */ + azureOpenAI: ( + config?: GenerativeOpenAIConfigRuntime + ): ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> => { + const { baseURL, stop, ...rest } = config || {}; + return { + name: 'generative-azure-openai', + config: config + ? { + ...rest, + baseUrl: baseURL, + isAzure: true, + stop: TextArray.fromPartial({ values: stop }), + } + : { isAzure: true }, + }; + }, + /** + * Create a `ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-cohere` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/cohere/generative) for detailed usage. + * + * @param {GenerativeCohereConfigRuntime} [config] The configuration for the `generative-cohere` module. + * @returns {ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined>} The configuration object. + */ + cohere: ( + config?: GenerativeCohereConfigRuntime + ): ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> => { + const { baseURL, stopSequences, ...rest } = config || {}; + return { + name: 'generative-cohere', + config: config + ? { + ...rest, + baseUrl: baseURL, + stopSequences: TextArray.fromPartial({ values: stopSequences }), + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-databricks` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/databricks/generative) for detailed usage. + * + * @param {GenerativeDatabricksConfigRuntime} [config] The configuration for the `generative-databricks` module. + * @returns {ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'> | undefined>} The configuration object. + */ + databricks: ( + config?: GenerativeDatabricksConfigRuntime + ): ModuleConfig< + 'generative-databricks', + GenerativeConfigRuntimeType<'generative-databricks'> | undefined + > => { + const { stop, ...rest } = config || {}; + return { + name: 'generative-databricks', + config: config + ? { + ...rest, + stop: TextArray.fromPartial({ values: stop }), + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-friendliai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/friendliai/generative) for detailed usage. + * + * @param {GenerativeFriendliAIConfigRuntime} [config] The configuration for the `generative-friendliai` module. + * @returns {ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined>} The configuration object. + */ + friendliai( + config?: GenerativeFriendliAIConfigRuntime + ): ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-friendliai', + config: config + ? { + ...rest, + baseUrl: baseURL, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/mistral/generative) for detailed usage. + * + * @param {GenerativeMistralConfigRuntime} [config] The configuration for the `generative-mistral` module. + * @returns {ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined>} The configuration object. + */ + mistral( + config?: GenerativeMistralConfigRuntime + ): ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-mistral', + config: config + ? { + baseUrl: baseURL, + ...rest, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-mistral` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/generative) for detailed usage. + * + * @param {GenerativeNvidiaConfigCreate} [config] The configuration for the `generative-nvidia` module. + * @returns {ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined>} The configuration object. + */ + nvidia( + config?: GenerativeNvidiaConfigRuntime + ): ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> { + const { baseURL, ...rest } = config || {}; + return { + name: 'generative-nvidia', + config: config + ? { + ...rest, + baseUrl: baseURL, + } + : undefined, + }; + }, + /** + * Create a `ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-ollama` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/ollama/generative) for detailed usage. + * + * @param {GenerativeOllamaConfigRuntime} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined>} The configuration object. + */ + ollama( + config?: GenerativeOllamaConfigRuntime + ): ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> { + return { + name: 'generative-ollama', + config, + }; + }, + /** + * Create a `ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>` object for use when performing runtime-specific AI generation using the `generative-openai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/openai/generative) for detailed usage. + * + * @param {GenerativeOpenAIConfigRuntime} [config] The configuration for the `generative-openai` module. + * @returns {ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>} The configuration object. + */ + openAI: ( + config?: GenerativeOpenAIConfigRuntime + ): ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> => { + const { baseURL, stop, ...rest } = config || {}; + return { + name: 'generative-openai', + config: config + ? { + ...rest, + baseUrl: baseURL, + isAzure: false, + stop: TextArray.fromPartial({ values: stop }), + } + : { isAzure: false }, + }; + }, + /** + * Create a `ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-openai'> | undefined>` object for use when performing runtime-specific AI generation using the `generative-google` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/google/generative) for detailed usage. + * + * @param {GenerativeGoogleConfigRuntime} [config] The configuration for the `generative-palm` module. + * @returns {ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'> | undefined>} The configuration object. + */ + google: ( + config?: GenerativeGoogleConfigRuntime + ): ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'> | undefined> => { + const { stopSequences, ...rest } = config || {}; + return { + name: 'generative-google', + config: config + ? { + ...rest, + stopSequences: TextArray.fromPartial({ values: stopSequences }), + } + : undefined, + }; + }, +}; diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 3af6fef1..3a5fb1b3 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -5,6 +5,7 @@ import { DbVersionSupport } from '../../utils/dbVersion.js'; import { WeaviateInvalidInputError } from '../../errors.js'; import { toBase64FromMedia } from '../../index.js'; +import { GenerativeSearch } from '../../proto/v1/generative.js'; import { SearchReply } from '../../proto/v1/search_get.js'; import { Deserialize } from '../deserialize/index.js'; import { Check } from '../query/check.js'; @@ -28,6 +29,7 @@ import { Serialize } from '../serialize/index.js'; import { GenerateOptions, GenerateReturn, + GenerativeConfigRuntime, GenerativeGroupByReturn, GenerativeReturn, GroupByOptions, @@ -51,107 +53,134 @@ class GenerateManager implements Generate { return new GenerateManager(new Check(connection, name, dbVersionSupport, consistencyLevel, tenant)); } - private async parseReply(reply: SearchReply) { + private async parseReply(reply: SearchReply) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); - return deserialize.generate(reply); + return deserialize.generate(reply); } - private async parseGroupByReply( + private async parseGroupByReply( opts: SearchOptions | GroupByOptions | undefined, reply: SearchReply ) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); return Serialize.search.isGroupBy(opts) ? deserialize.generateGroupBy(reply) - : deserialize.generate(reply); + : deserialize.generate(reply); } - public fetchObjects( - generate: GenerateOptions, + public fetchObjects( + generate: GenerateOptions, opts?: FetchObjectsOptions - ): Promise> { - return this.check - .fetchObjects(opts) - .then(({ search }) => + ): Promise> { + return Promise.all([ + this.check.fetchObjects(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then(async ([{ search }, supportsSingleGrouped]) => search.withFetch({ ...Serialize.search.fetchObjects(opts), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseReply(reply)); } - public bm25( + public bm25( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseBm25Options - ): Promise>; - public bm25( + ): Promise>; + public bm25( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByBm25Options - ): Promise>; - public bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { - return this.check - .bm25(opts) - .then(({ search }) => + ): Promise>; + public bm25( + query: string, + generate: GenerateOptions, + opts?: Bm25Options + ): GenerateReturn { + return Promise.all([ + this.check.bm25(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then(async ([{ search }, supportsSingleGrouped]) => search.withBm25({ ...Serialize.search.bm25(query, opts), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public hybrid( + public hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseHybridOptions - ): Promise>; - public hybrid( + ): Promise>; + public hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByHybridOptions - ): Promise>; - public hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn { - return this.check - .hybridSearch(opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withHybrid({ - ...Serialize.search.hybrid( - { - query, - supportsTargets, - supportsVectorsForTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), - }) + ): Promise>; + public hybrid( + query: string, + generate: GenerateOptions, + opts?: HybridOptions + ): GenerateReturn { + return Promise.all([ + this.check.hybridSearch(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then( + async ([ + { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, + supportsSingleGrouped, + ]) => + search.withHybrid({ + ...Serialize.search.hybrid( + { + query, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + }, + opts + ), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), + }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearImage( + public nearImage( image: string | Buffer, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearImage( + ): Promise>; + public nearImage( image: string | Buffer, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearImage( + ): Promise>; + public nearImage( image: string | Buffer, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - toBase64FromMedia(image).then((image) => + ): GenerateReturn { + return Promise.all([ + this.check.nearSearch(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => + Promise.all([ + toBase64FromMedia(image), + Serialize.generative({ supportsSingleGrouped }, generate), + ]).then(([image, generative]) => search.withNearImage({ ...Serialize.search.nearImage( { @@ -161,27 +190,34 @@ class GenerateManager implements Generate { }, opts ), - generative: Serialize.generative(generate), + generative, }) ) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearObject( + public nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearObject( + ): Promise>; + public nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => + ): Promise>; + public nearObject( + id: string, + generate: GenerateOptions, + opts?: NearOptions + ): GenerateReturn { + return Promise.all([ + this.check.nearSearch(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => search.withNearObject({ ...Serialize.search.nearObject( { @@ -191,30 +227,33 @@ class GenerateManager implements Generate { }, opts ), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearText( + public nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearTextOptions - ): Promise>; - public nearText( + ): Promise>; + public nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearTextOptions - ): Promise>; - public nearText( + ): Promise>; + public nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => + ): GenerateReturn { + return Promise.all([ + this.check.nearSearch(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => search.withNearText({ ...Serialize.search.nearText( { @@ -224,114 +263,132 @@ class GenerateManager implements Generate { }, opts ), - generative: Serialize.generative(generate), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearVector( + public nearVector( vector: number[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearVector( + ): Promise>; + public nearVector( vector: number[], - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearVector( + ): Promise>; + public nearVector( vector: number[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearVector(vector, opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withNearVector({ - ...Serialize.search.nearVector( - { - vector, - supportsTargets, - supportsVectorsForTargets, - supportsWeightsForTargets, - }, - opts - ), - generative: Serialize.generative(generate), - }) + ): GenerateReturn { + return Promise.all([ + this.check.nearVector(vector, opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then( + async ([ + { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, + supportsSingleGrouped, + ]) => + search.withNearVector({ + ...Serialize.search.nearVector( + { + vector, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + }, + opts + ), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), + }) ) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearMedia( + public nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; - public nearMedia( + ): Promise>; + public nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; - public nearMedia( + ): Promise>; + public nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn { - return this.check - .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => { + ): GenerateReturn { + return Promise.all([ + this.check.nearSearch(opts), + this.check.supportForSingleGroupedGenerative(), + this.check.supportForGenerativeConfigRuntime(generate.config), + ]) + .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => { const args = { supportsTargets, supportsWeightsForTargets, }; - const generative = Serialize.generative(generate); - let send: (media: string) => Promise; + let send: (media: string, generative: GenerativeSearch) => Promise; switch (type) { case 'audio': - send = (media) => + send = (media, generative) => search.withNearAudio({ ...Serialize.search.nearAudio({ audio: media, ...args }, opts), generative, }); break; case 'depth': - send = (media) => + send = (media, generative) => search.withNearDepth({ ...Serialize.search.nearDepth({ depth: media, ...args }, opts), generative, }); break; case 'image': - send = (media) => + send = (media, generative) => search.withNearImage({ ...Serialize.search.nearImage({ image: media, ...args }, opts), generative, }); break; case 'imu': - send = (media) => - search.withNearIMU({ ...Serialize.search.nearIMU({ imu: media, ...args }, opts), generative }); + send = (media, generative) => + search.withNearIMU({ + ...Serialize.search.nearIMU({ imu: media, ...args }, opts), + generative, + }); break; case 'thermal': - send = (media) => + send = (media, generative) => search.withNearThermal({ ...Serialize.search.nearThermal({ thermal: media, ...args }, opts), generative, }); break; case 'video': - send = (media) => - search.withNearVideo({ ...Serialize.search.nearVideo({ video: media, ...args }), generative }); + send = (media, generative) => + search.withNearVideo({ + ...Serialize.search.nearVideo({ video: media, ...args }), + generative, + }); break; default: throw new WeaviateInvalidInputError(`Invalid media type: ${type}`); } - return toBase64FromMedia(media).then(send); + return Promise.all([ + toBase64FromMedia(media), + Serialize.generative({ supportsSingleGrouped }, generate), + ]).then(([media, generative]) => send(media, generative)); }) .then((reply) => this.parseGroupByReply(opts, reply)); } @@ -339,4 +396,5 @@ class GenerateManager implements Generate { export default GenerateManager.use; +export { generativeParameters } from './config.js'; export { Generate } from './types.js'; diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 41846f90..39e8e351 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -4,6 +4,7 @@ import { WeaviateUnsupportedFeatureError } from '../../errors.js'; import weaviate, { WeaviateClient } from '../../index.js'; import { Collection } from '../collection/index.js'; import { GenerateOptions, GroupByOptions } from '../types/index.js'; +import { generativeParameters } from './config.js'; const maybe = process.env.OPENAI_APIKEY ? describe : describe.skip; @@ -27,10 +28,10 @@ maybe('Testing of the collection.generate methods with a simple collection', () testProp: string; }; - const generateOpts: GenerateOptions = { + const generateOpts = { singlePrompt: 'Write a haiku about ducks for {testProp}', groupedTask: 'What is the value of testProp here?', - groupedProperties: ['testProp'], + groupedProperties: ['testProp'] as 'testProp'[], }; afterAll(() => { @@ -148,6 +149,26 @@ maybe('Testing of the collection.generate methods with a simple collection', () expect(ret.objects[0].uuid).toEqual(id); expect(ret.objects[0].generated).toBeDefined(); }); + + it('should generate in a BC-compatible way', async () => { + const query = () => collection.generate.fetchObjects(generateOpts); + + const res = await query(); + expect(res.objects.length).toEqual(1); + expect(res.generated).toBeDefined(); + expect(res.generated).not.toEqual(''); + expect(res.generative?.text).toBeDefined(); + expect(res.generative?.text).not.toEqual(''); + expect(res.generative?.metadata).toBeUndefined(); + res.objects.forEach((obj) => { + expect(obj.generated).toBeDefined(); + expect(obj.generated).not.toEqual(''); + expect(obj.generative?.text).toBeDefined(); + expect(obj.generative?.text).not.toEqual(''); + expect(obj.generative?.metadata).toBeUndefined(); + expect(obj.generative?.debug).toBeUndefined(); + }); + }); }); }); @@ -162,7 +183,7 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti testProp: string; }; - const generateOpts: GenerateOptions = { + const generateOpts: GenerateOptions = { singlePrompt: 'Write a haiku about ducks for {testProp}', groupedTask: 'What is the value of testProp here?', groupedProperties: ['testProp'], @@ -421,3 +442,116 @@ maybe('Testing of the collection.generate methods with a multi vector collection expect(ret.objects[1].generated).toBeDefined(); }); }); + +maybe('Testing of the collection.generate methods with runtime generative config', () => { + let client: WeaviateClient; + let collection: Collection; + const collectionName = 'TestCollectionGenerateConfigRuntime'; + + type TestCollectionGenerateConfigRuntime = { + testProp: string; + }; + + afterAll(() => { + return client.collections.delete(collectionName).catch((err) => { + console.error(err); + throw err; + }); + }); + + beforeAll(async () => { + client = await makeOpenAIClient(); + collection = client.collections.get(collectionName); + return client.collections + .create({ + name: collectionName, + properties: [ + { + name: 'testProp', + dataType: 'text', + }, + ], + }) + .then(() => { + return collection.data.insert({ + properties: { + testProp: 'test', + }, + }); + }); + }); + + it('should generate using a runtime config without search and with extras', async () => { + const query = () => + collection.generate.fetchObjects({ + singlePrompt: { + prompt: 'Write a haiku about ducks for {testProp}', + debug: true, + metadata: true, + }, + groupedTask: { + prompt: 'What is the value of testProp here?', + nonBlobProperties: ['testProp'], + metadata: true, + }, + config: generativeParameters.openAI({ + stop: ['\n'], + }), + }); + + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 30, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + + const res = await query(); + expect(res.objects.length).toEqual(1); + expect(res.generated).toBeDefined(); + expect(res.generated).not.toEqual(''); + expect(res.generative?.text).toBeDefined(); + expect(res.generative?.text).not.toEqual(''); + expect(res.generative?.metadata).toBeDefined(); + res.objects.forEach((obj) => { + expect(obj.generated).toBeDefined(); + expect(obj.generative?.text).toBeDefined(); + expect(obj.generative?.metadata).toBeDefined(); + expect(obj.generative?.debug).toBeDefined(); + }); + }); + + it('should generate using a runtime config without search nor extras', async () => { + const query = () => + collection.generate.fetchObjects({ + singlePrompt: 'Write a haiku about ducks for {testProp}', + groupedTask: 'What is the value of testProp here?', + config: { + name: 'generative-openai', + config: { + model: 'gpt-4o-mini', + stop: { values: ['\n'] }, + }, + }, + }); + + if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 30, 0))) { + await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); + return; + } + + const res = await query(); + expect(res.objects.length).toEqual(1); + expect(res.generated).toBeDefined(); + expect(res.generated).not.toEqual(''); + expect(res.generative?.text).toBeDefined(); + expect(res.generative?.text).not.toEqual(''); + expect(res.generative?.metadata).toBeUndefined(); + res.objects.forEach((obj) => { + expect(obj.generated).toBeDefined(); + expect(obj.generated).not.toEqual(''); + expect(obj.generative?.text).toBeDefined(); + expect(obj.generative?.text).not.toEqual(''); + expect(obj.generative?.metadata).toBeUndefined(); + expect(obj.generative?.debug).toBeUndefined(); + }); + }); +}); diff --git a/src/collections/generate/mock.test.ts b/src/collections/generate/mock.test.ts new file mode 100644 index 00000000..d4a0c025 --- /dev/null +++ b/src/collections/generate/mock.test.ts @@ -0,0 +1,169 @@ +import express from 'express'; +import { Server as HttpServer } from 'http'; +import { Server as GrpcServer, createServer } from 'nice-grpc'; +import weaviate, { Collection, GenerativeConfigRuntime, WeaviateClient } from '../..'; +import { + HealthCheckRequest, + HealthCheckResponse, + HealthCheckResponse_ServingStatus, + HealthDefinition, + HealthServiceImplementation, +} from '../../proto/google/health/v1/health'; +import { GenerativeResult } from '../../proto/v1/generative'; +import { SearchReply, SearchRequest, SearchResult } from '../../proto/v1/search_get'; +import { WeaviateDefinition, WeaviateServiceImplementation } from '../../proto/v1/weaviate'; +import { generativeParameters } from './config'; + +const mockedSingleGenerative = 'Mocked single response'; +const mockedGroupedGenerative = 'Mocked group response'; + +class GenerateMock { + private grpc: GrpcServer; + private http: HttpServer; + + constructor(grpc: GrpcServer, http: HttpServer) { + this.grpc = grpc; + this.http = http; + } + + public static use = async (version: string, httpPort: number, grpcPort: number) => { + const httpApp = express(); + // Meta endpoint required for client instantiation + httpApp.get('/v1/meta', (req, res) => res.send({ version })); + + // gRPC health check required for client instantiation + const healthMockImpl: HealthServiceImplementation = { + check: (request: HealthCheckRequest): Promise => + Promise.resolve(HealthCheckResponse.create({ status: HealthCheckResponse_ServingStatus.SERVING })), + watch: jest.fn(), + }; + + const grpc = createServer(); + grpc.add(HealthDefinition, healthMockImpl); + + // Search endpoint returning generative mock data + const weaviateMockImpl: WeaviateServiceImplementation = { + aggregate: jest.fn(), + tenantsGet: jest.fn(), + search: (req: SearchRequest): Promise => { + expect(req.generative?.grouped?.queries.length).toBeGreaterThan(0); + expect(req.generative?.single?.queries.length).toBeGreaterThan(0); + return Promise.resolve( + SearchReply.fromPartial({ + results: [ + SearchResult.fromPartial({ + properties: { + nonRefProps: { fields: { name: { textValue: 'thing' } } }, + }, + generative: GenerativeResult.fromPartial({ + values: [ + { + result: mockedSingleGenerative, + }, + ], + }), + metadata: { + id: 'b602a271-d5a9-4324-921d-5abe4748d6b5', + }, + }), + ], + generativeGroupedResults: GenerativeResult.fromPartial({ + values: [ + { + result: mockedGroupedGenerative, + }, + ], + }), + }) + ); + }, + batchDelete: jest.fn(), + batchObjects: jest.fn(), + }; + grpc.add(WeaviateDefinition, weaviateMockImpl); + + await grpc.listen(`localhost:${grpcPort}`); + const http = await httpApp.listen(httpPort); + return new GenerateMock(grpc, http); + }; + + public close = () => Promise.all([this.http.close(), this.grpc.shutdown()]); +} + +describe('Mock testing of generate with runtime config', () => { + let client: WeaviateClient; + let collection: Collection; + let mock: GenerateMock; + + beforeAll(async () => { + mock = await GenerateMock.use('1.30.0-rc.1', 8958, 8959); + client = await weaviate.connectToLocal({ port: 8958, grpcPort: 8959 }); + collection = client.collections.use('Whatever'); + }); + + afterAll(() => mock.close()); + + const stringTest = (config: GenerativeConfigRuntime) => + collection.generate + .fetchObjects({ + singlePrompt: 'What is the meaning of life?', + groupedTask: 'What is the meaning of life?', + config: config, + }) + .then((res) => { + expect(res.generative?.text).toEqual(mockedGroupedGenerative); + expect(res.objects[0].generative?.text).toEqual(mockedSingleGenerative); + }); + + const objectTest = (config: GenerativeConfigRuntime) => + collection.generate + .fetchObjects({ + singlePrompt: { + prompt: 'What is the meaning of life?', + }, + groupedTask: { + prompt: 'What is the meaning of life?', + }, + config: config, + }) + .then((res) => { + expect(res.generative?.text).toEqual(mockedGroupedGenerative); + expect(res.objects[0].generative?.text).toEqual(mockedSingleGenerative); + }); + + const model = { model: 'llama-2' }; + + const tests: GenerativeConfigRuntime[] = [ + generativeParameters.anthropic(), + generativeParameters.anthropic(model), + generativeParameters.anyscale(), + generativeParameters.anyscale(model), + generativeParameters.aws(), + generativeParameters.aws(model), + generativeParameters.azureOpenAI(), + generativeParameters.azureOpenAI(model), + generativeParameters.cohere(), + generativeParameters.cohere(model), + generativeParameters.databricks(), + generativeParameters.databricks(model), + generativeParameters.friendliai(), + generativeParameters.friendliai(model), + generativeParameters.google(), + generativeParameters.google(model), + generativeParameters.mistral(), + generativeParameters.mistral(model), + generativeParameters.nvidia(), + generativeParameters.nvidia(model), + generativeParameters.ollama(), + generativeParameters.ollama(model), + generativeParameters.openAI(), + generativeParameters.openAI(model), + ]; + + tests.forEach((conf) => { + it(`should get the mocked response for ${conf.name} with config: ${conf.config}`, async () => { + await stringTest(conf); + await objectTest(conf); + }); + }); +}); diff --git a/src/collections/generate/types.ts b/src/collections/generate/types.ts index b211a46a..27548bfb 100644 --- a/src/collections/generate/types.ts +++ b/src/collections/generate/types.ts @@ -18,6 +18,7 @@ import { import { GenerateOptions, GenerateReturn, + GenerativeConfigRuntime, GenerativeGroupByReturn, GenerativeReturn, } from '../types/index.js'; @@ -31,11 +32,15 @@ interface Bm25 { * This overload is for performing a search without the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - bm25(query: string, generate: GenerateOptions, opts?: BaseBm25Options): Promise>; + bm25( + query: string, + generate: GenerateOptions, + opts?: BaseBm25Options + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -44,15 +49,15 @@ interface Bm25 { * This overload is for performing a search with the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - bm25( + bm25( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByBm25Options - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -61,11 +66,15 @@ interface Bm25 { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {Bm25Options} [opts] - The available options for performing the BM25 search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - bm25(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn; + bm25( + query: string, + generate: GenerateOptions, + opts?: Bm25Options + ): GenerateReturn; } interface Hybrid { @@ -77,15 +86,15 @@ interface Hybrid { * This overload is for performing a search without the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - hybrid( + hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseHybridOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -94,15 +103,15 @@ interface Hybrid { * This overload is for performing a search with the `groupBy` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - hybrid( + hybrid( query: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByHybridOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -111,11 +120,15 @@ interface Hybrid { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {HybridOptions} [opts] - The available options for performing the hybrid search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - hybrid(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn; + hybrid( + query: string, + generate: GenerateOptions, + opts?: HybridOptions + ): GenerateReturn; } interface NearMedia { @@ -130,16 +143,16 @@ interface NearMedia { * * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearMedia( + nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -151,16 +164,16 @@ interface NearMedia { * * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearOptions} opts - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearMedia( + nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -172,16 +185,16 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearOptions} [opts] - The available options for performing the near-media search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearMedia( + nearMedia( media: string | Buffer, type: NearMediaType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn; + ): GenerateReturn; } interface NearObject { @@ -193,15 +206,15 @@ interface NearObject { * This overload is for performing a search without the `groupBy` param. * * @param {string} id - The ID of the object to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearObject( + nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -210,15 +223,15 @@ interface NearObject { * This overload is for performing a search with the `groupBy` param. * * @param {string} id - The ID of the object to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearOptions} opts - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearObject( + nearObject( id: string, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -227,11 +240,15 @@ interface NearObject { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} id - The ID of the object to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearOptions} [opts] - The available options for performing the near-object search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearObject(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn; + nearObject( + id: string, + generate: GenerateOptions, + opts?: NearOptions + ): GenerateReturn; } interface NearText { @@ -245,15 +262,15 @@ interface NearText { * This overload is for performing a search without the `groupBy` param. * * @param {string | string[]} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearText( + nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearTextOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -264,15 +281,15 @@ interface NearText { * This overload is for performing a search with the `groupBy` param. * * @param {string | string[]} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearText( + nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearTextOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -283,15 +300,15 @@ interface NearText { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string | string[]} query - The query to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearTextOptions} [opts] - The available options for performing the near-text search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearText( + nearText( query: string | string[], - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearTextOptions - ): GenerateReturn; + ): GenerateReturn; } interface NearVector { @@ -303,15 +320,15 @@ interface NearVector { * This overload is for performing a search without the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data. + * @return {Promise>} - The results of the search including the generated data. */ - nearVector( + nearVector( vector: NearVectorInputType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: BaseNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -320,15 +337,15 @@ interface NearVector { * This overload is for performing a search with the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearVector( + nearVector( vector: NearVectorInputType, - generate: GenerateOptions, + generate: GenerateOptions, opts: GroupByNearOptions - ): Promise>; + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -337,15 +354,15 @@ interface NearVector { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GenerateOptions} generate - The available options for performing the generation. + * @param {GenerateOptions} generate - The available options for performing the generation. * @param {NearOptions} [opts] - The available options for performing the near-vector search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearVector( + nearVector( vector: NearVectorInputType, - generate: GenerateOptions, + generate: GenerateOptions, opts?: NearOptions - ): GenerateReturn; + ): GenerateReturn; } export interface Generate @@ -355,5 +372,8 @@ export interface Generate NearObject, NearText, NearVector { - fetchObjects: (generate: GenerateOptions, opts?: FetchObjectsOptions) => Promise>; + fetchObjects: ( + generate: GenerateOptions, + opts?: FetchObjectsOptions + ) => Promise>; } diff --git a/src/collections/generate/unit.test.ts b/src/collections/generate/unit.test.ts new file mode 100644 index 00000000..54ecb335 --- /dev/null +++ b/src/collections/generate/unit.test.ts @@ -0,0 +1,280 @@ +import { GenerativeConfigRuntimeType, ModuleConfig } from '../types'; +import { generativeParameters } from './config'; + +// only tests fields that must be mapped from some public name to a gRPC name, e.g. baseURL -> baseUrl and stop: string[] -> stop: TextArray +describe('Unit testing of the generativeParameters factory methods', () => { + describe('anthropic', () => { + it('with defaults', () => { + const config = generativeParameters.anthropic(); + expect(config).toEqual< + ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> + >({ + name: 'generative-anthropic', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.anthropic({ + baseURL: 'http://localhost:8080', + stopSequences: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> + >({ + name: 'generative-anthropic', + config: { + baseUrl: 'http://localhost:8080', + stopSequences: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); + + describe('anyscale', () => { + it('with defaults', () => { + const config = generativeParameters.anyscale(); + expect(config).toEqual< + ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> + >({ + name: 'generative-anyscale', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.anyscale({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> + >({ + name: 'generative-anyscale', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('aws', () => { + it('with defaults', () => { + const config = generativeParameters.aws(); + expect(config).toEqual< + ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> + >({ + name: 'generative-aws', + config: undefined, + }); + }); + }); + + describe('azure-openai', () => { + it('with defaults', () => { + const config = generativeParameters.azureOpenAI(); + expect(config).toEqual< + ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> + >({ + name: 'generative-azure-openai', + config: { isAzure: true }, + }); + }); + it('with values', () => { + const config = generativeParameters.azureOpenAI({ + baseURL: 'http://localhost:8080', + model: 'model', + stop: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> + >({ + name: 'generative-azure-openai', + config: { + baseUrl: 'http://localhost:8080', + stop: { values: ['a', 'b', 'c'] }, + model: 'model', + isAzure: true, + }, + }); + }); + }); + + describe('cohere', () => { + it('with defaults', () => { + const config = generativeParameters.cohere(); + expect(config).toEqual< + ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> + >({ + name: 'generative-cohere', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.cohere({ + baseURL: 'http://localhost:8080', + stopSequences: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> + >({ + name: 'generative-cohere', + config: { + baseUrl: 'http://localhost:8080', + stopSequences: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); + + describe('databricks', () => { + it('with defaults', () => { + const config = generativeParameters.databricks(); + expect(config).toEqual< + ModuleConfig< + 'generative-databricks', + GenerativeConfigRuntimeType<'generative-databricks'> | undefined + > + >({ + name: 'generative-databricks', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.databricks({ + stop: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig< + 'generative-databricks', + GenerativeConfigRuntimeType<'generative-databricks'> | undefined + > + >({ + name: 'generative-databricks', + config: { + stop: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); + + describe('friendliai', () => { + it('with defaults', () => { + const config = generativeParameters.friendliai(); + expect(config).toEqual< + ModuleConfig< + 'generative-friendliai', + GenerativeConfigRuntimeType<'generative-friendliai'> | undefined + > + >({ + name: 'generative-friendliai', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.friendliai({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig< + 'generative-friendliai', + GenerativeConfigRuntimeType<'generative-friendliai'> | undefined + > + >({ + name: 'generative-friendliai', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('mistral', () => { + it('with defaults', () => { + const config = generativeParameters.mistral(); + expect(config).toEqual< + ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> + >({ + name: 'generative-mistral', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.mistral({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> + >({ + name: 'generative-mistral', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('nvidia', () => { + it('with defaults', () => { + const config = generativeParameters.nvidia(); + expect(config).toEqual< + ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> + >({ + name: 'generative-nvidia', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.nvidia({ + baseURL: 'http://localhost:8080', + }); + expect(config).toEqual< + ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> + >({ + name: 'generative-nvidia', + config: { + baseUrl: 'http://localhost:8080', + }, + }); + }); + }); + + describe('ollama', () => { + it('with defaults', () => { + const config = generativeParameters.ollama(); + expect(config).toEqual< + ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> + >({ + name: 'generative-ollama', + config: undefined, + }); + }); + }); + + describe('openai', () => { + it('with defaults', () => { + const config = generativeParameters.openAI(); + expect(config).toEqual< + ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> + >({ + name: 'generative-openai', + config: { isAzure: false }, + }); + }); + it('with values', () => { + const config = generativeParameters.openAI({ + baseURL: 'http://localhost:8080', + model: 'model', + stop: ['a', 'b', 'c'], + }); + expect(config).toEqual< + ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> + >({ + name: 'generative-openai', + config: { + baseUrl: 'http://localhost:8080', + isAzure: false, + model: 'model', + stop: { values: ['a', 'b', 'c'] }, + }, + }); + }); + }); +}); diff --git a/src/collections/query/check.ts b/src/collections/query/check.ts index 291738de..2e562a2b 100644 --- a/src/collections/query/check.ts +++ b/src/collections/query/check.ts @@ -2,7 +2,7 @@ import Connection from '../../connection/grpc.js'; import { WeaviateUnsupportedFeatureError } from '../../errors.js'; import { ConsistencyLevel } from '../../index.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; -import { GroupByOptions } from '../index.js'; +import { GenerativeConfigRuntime, GroupByOptions } from '../index.js'; import { Serialize } from '../serialize/index.js'; import { BaseBm25Options, @@ -98,6 +98,19 @@ export class Check { return check.supports; }; + public supportForSingleGroupedGenerative = async () => { + const check = await this.dbVersionSupport.supportsSingleGrouped(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + return check.supports; + }; + + public supportForGenerativeConfigRuntime = async (generativeConfig?: GenerativeConfigRuntime) => { + if (generativeConfig === undefined) return true; + const check = await this.dbVersionSupport.supportsGenerativeConfigRuntime(); + if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); + return check.supports; + }; + public nearSearch = (opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index f7409b4e..e9a74973 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -25,7 +25,12 @@ import { BatchObject_Properties, BatchObject_SingleTargetRefProps, } from '../../proto/v1/batch.js'; -import { GenerativeSearch } from '../../proto/v1/generative.js'; +import { + GenerativeProvider, + GenerativeSearch, + GenerativeSearch_Grouped, + GenerativeSearch_Single, +} from '../../proto/v1/generative.js'; import { GroupBy, MetadataRequest, @@ -63,6 +68,7 @@ import { SearchNearVectorArgs, SearchNearVideoArgs, } from '../../grpc/searcher.js'; +import { toBase64FromMedia } from '../../index.js'; import { AggregateRequest_Aggregation, AggregateRequest_Aggregation_Boolean, @@ -82,6 +88,7 @@ import { ObjectArrayProperties, ObjectProperties, ObjectPropertiesValue, + TextArray, TextArrayProperties, Vectors as VectorsGrpc, } from '../../proto/v1/base.js'; @@ -97,10 +104,13 @@ import { AggregateBaseOptions, AggregateHybridOptions, AggregateNearOptions, + GenerativeConfigRuntime, GroupByAggregate, + GroupedTask, MultiTargetVectorJoin, PrimitiveKeys, PropertiesMetrics, + SinglePrompt, } from '../index.js'; import { BaseHybridOptions, @@ -818,14 +828,126 @@ export class Serialize { return vec !== undefined && !Array.isArray(vec) && Object.values(vec).some(ArrayInputGuards.is2DArray); }; - public static generative = (generative?: GenerateOptions): GenerativeSearch => { - return GenerativeSearch.fromPartial({ - singleResponsePrompt: generative?.singlePrompt, - groupedResponseTask: generative?.groupedTask, - groupedProperties: generative?.groupedProperties as string[], - }); + private static generativeQuery = async ( + generative: GenerativeConfigRuntime, + opts?: { metadata?: boolean; images?: (string | Buffer)[]; imageProperties?: string[] } + ): Promise => { + const withImages = async >( + config: T, + imgs?: (string | Buffer)[], + imgProps?: string[] + ): Promise => { + if (imgs == undefined && imgProps == undefined) { + return config; + } + return { + ...config, + images: TextArray.fromPartial({ + values: imgs ? await Promise.all(imgs.map(toBase64FromMedia)) : undefined, + }), + imageProperties: TextArray.fromPartial({ values: imgProps }), + }; + }; + + const provider = GenerativeProvider.fromPartial({ returnMetadata: opts?.metadata }); + switch (generative.name) { + case 'generative-anthropic': + provider.anthropic = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); + break; + case 'generative-anyscale': + provider.anyscale = generative.config || {}; + break; + case 'generative-aws': + provider.aws = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); + break; + case 'generative-cohere': + provider.cohere = generative.config || {}; + break; + case 'generative-databricks': + provider.databricks = generative.config || {}; + break; + case 'generative-dummy': + provider.dummy = generative.config || {}; + break; + case 'generative-friendliai': + provider.friendliai = generative.config || {}; + break; + case 'generative-google': + provider.google = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); + break; + case 'generative-mistral': + provider.mistral = generative.config || {}; + break; + case 'generative-nvidia': + provider.nvidia = generative.config || {}; + break; + case 'generative-ollama': + provider.ollama = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); + break; + case 'generative-openai': + provider.openai = await withImages(generative.config || {}, opts?.images, opts?.imageProperties); + break; + } + return provider; + }; + + public static generative = async ( + args: { supportsSingleGrouped: boolean }, + opts?: GenerateOptions + ): Promise => { + const singlePrompt = Serialize.isSinglePrompt(opts?.singlePrompt) + ? opts.singlePrompt.prompt + : opts?.singlePrompt; + const singlePromptDebug = Serialize.isSinglePrompt(opts?.singlePrompt) + ? opts.singlePrompt.debug + : undefined; + + const groupedTask = Serialize.isGroupedTask(opts?.groupedTask) + ? opts.groupedTask.prompt + : opts?.groupedTask; + const groupedProperties = Serialize.isGroupedTask(opts?.groupedTask) + ? opts.groupedTask.nonBlobProperties + : opts?.groupedProperties; + + const singleOpts = Serialize.isSinglePrompt(opts?.singlePrompt) ? opts.singlePrompt : undefined; + const groupedOpts = Serialize.isGroupedTask(opts?.groupedTask) ? opts.groupedTask : undefined; + + return args.supportsSingleGrouped + ? GenerativeSearch.fromPartial({ + single: opts?.singlePrompt + ? GenerativeSearch_Single.fromPartial({ + prompt: singlePrompt, + debug: singlePromptDebug, + queries: opts.config ? [await Serialize.generativeQuery(opts.config, singleOpts)] : undefined, + }) + : undefined, + grouped: opts?.groupedTask + ? GenerativeSearch_Grouped.fromPartial({ + task: groupedTask, + queries: opts.config + ? [await Serialize.generativeQuery(opts.config, groupedOpts)] + : undefined, + properties: groupedProperties + ? TextArray.fromPartial({ values: groupedProperties as string[] }) + : undefined, + }) + : undefined, + }) + : GenerativeSearch.fromPartial({ + singleResponsePrompt: singlePrompt, + groupedResponseTask: groupedTask, + groupedProperties: groupedProperties as string[], + }); }; + public static isSinglePrompt(arg?: string | SinglePrompt): arg is SinglePrompt { + return typeof arg !== 'string' && arg !== undefined && arg.prompt !== undefined; + } + + public static isGroupedTask(arg?: string | GroupedTask): arg is GroupedTask { + return typeof arg !== 'string' && arg !== undefined && arg.prompt !== undefined; + } + private static bm25QueryProperties = ( properties?: (PrimitiveKeys | Bm25QueryProperty)[] ): string[] | undefined => { diff --git a/src/collections/serialize/unit.test.ts b/src/collections/serialize/unit.test.ts index 721d1e46..6d9f9612 100644 --- a/src/collections/serialize/unit.test.ts +++ b/src/collections/serialize/unit.test.ts @@ -441,12 +441,15 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for generative', () => { - const args = Serialize.generative({ - singlePrompt: 'test', - groupedProperties: ['name'], - groupedTask: 'testing', - }); + it('should parse args for generative', async () => { + const args = await Serialize.generative( + { supportsSingleGrouped: false }, + { + singlePrompt: 'test', + groupedProperties: ['name'], + groupedTask: 'testing', + } + ); expect(args).toEqual({ singleResponsePrompt: 'test', groupedProperties: ['name'], diff --git a/src/collections/tenants/index.ts b/src/collections/tenants/index.ts index f72c6959..320f166d 100644 --- a/src/collections/tenants/index.ts +++ b/src/collections/tenants/index.ts @@ -1,5 +1,5 @@ import { ConnectionGRPC } from '../../connection/index.js'; -import { WeaviateUnsupportedFeatureError } from '../../errors.js'; +import { WeaviateUnexpectedStatusCodeError, WeaviateUnsupportedFeatureError } from '../../errors.js'; import { Tenant as TenantREST } from '../../openapi/types.js'; import { TenantsCreator, TenantsDeleter, TenantsGetter, TenantsUpdater } from '../../schema/index.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; @@ -17,12 +17,10 @@ const parseValueOrValueArray = (value: V | V[]) => (Array.isArray(value) ? va const parseStringOrTenant = (tenant: string | T) => typeof tenant === 'string' ? tenant : tenant.name; -const parseTenantREST = (tenant: TenantREST): Tenant => { - return { - name: tenant.name!, - activityStatus: Deserialize.activityStatusREST(tenant.activityStatus), - }; -}; +const parseTenantREST = (tenant: TenantREST): Tenant => ({ + name: tenant.name!, + activityStatus: Deserialize.activityStatusREST(tenant.activityStatus), +}); const tenants = ( connection: ConnectionGRPC, @@ -53,9 +51,20 @@ const tenants = ( return check.supports ? getGRPC() : getREST(); }, getByNames: (tenants: (string | T)[]) => getGRPC(tenants.map(parseStringOrTenant)), - getByName: (tenant: string | T) => { + getByName: async (tenant: string | T) => { const tenantName = parseStringOrTenant(tenant); - return getGRPC([tenantName]).then((tenants) => tenants[tenantName] || null); + if (await dbVersionSupport.supportsTenantGetRESTMethod().then((check) => !check.supports)) { + return getGRPC([tenantName]).then((tenants) => tenants[tenantName] ?? null); + } + return connection + .get(`/schema/${collection}/tenants/${tenantName}`) + .then(parseTenantREST) + .catch((err) => { + if (err instanceof WeaviateUnexpectedStatusCodeError && err.code === 404) { + return null; + } + throw err; + }); }, remove: (tenants: string | T | (string | T)[]) => new TenantsDeleter( diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index b3f6bac2..edd16e71 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -1,54 +1,314 @@ +import { + GenerativeAWS as GenerativeAWSGRPC, + GenerativeAWSMetadata, + GenerativeAnthropic as GenerativeAnthropicGRPC, + GenerativeAnthropicMetadata, + GenerativeAnyscale as GenerativeAnyscaleGRPC, + GenerativeAnyscaleMetadata, + GenerativeCohere as GenerativeCohereGRPC, + GenerativeCohereMetadata, + GenerativeDatabricks as GenerativeDatabricksGRPC, + GenerativeDatabricksMetadata, + GenerativeDebug, + GenerativeDummy as GenerativeDummyGRPC, + GenerativeDummyMetadata, + GenerativeFriendliAI as GenerativeFriendliAIGRPC, + GenerativeFriendliAIMetadata, + GenerativeGoogle as GenerativeGoogleGRPC, + GenerativeGoogleMetadata, + GenerativeMistral as GenerativeMistralGRPC, + GenerativeMistralMetadata, + GenerativeNvidia as GenerativeNvidiaGRPC, + GenerativeNvidiaMetadata, + GenerativeOllama as GenerativeOllamaGRPC, + GenerativeOllamaMetadata, + GenerativeOpenAI as GenerativeOpenAIGRPC, + GenerativeOpenAIMetadata, +} from '../../proto/v1/generative.js'; +import { ModuleConfig } from '../index.js'; import { GroupByObject, GroupByResult, WeaviateGenericObject, WeaviateNonGenericObject } from './query.js'; -export type GenerativeGenericObject = WeaviateGenericObject & { - /** The LLM-generated output applicable to this single object. */ +export type GenerativeGenericObject< + T, + C extends GenerativeConfigRuntime | undefined +> = WeaviateGenericObject & { + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this single object. */ generated?: string; + /** Generative data returned from the LLM inference on this object. */ + generative?: GenerativeSingle; }; -export type GenerativeNonGenericObject = WeaviateNonGenericObject & { - /** The LLM-generated output applicable to this single object. */ - generated?: string; -}; +export type GenerativeNonGenericObject = + WeaviateNonGenericObject & { + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this single object. */ + generated?: string; + /** Generative data returned from the LLM inference on this object. */ + generative?: GenerativeSingle; + }; /** An object belonging to a collection as returned by the methods in the `collection.generate` namespace. * * Depending on the generic type `T`, the object will have subfields that map from `T`'s specific type definition. * If not, then the object will be non-generic and have a `properties` field that maps from a generic string to a `WeaviateField`. */ -export type GenerativeObject = T extends undefined - ? GenerativeNonGenericObject - : GenerativeGenericObject; +export type GenerativeObject = T extends undefined + ? GenerativeNonGenericObject + : GenerativeGenericObject; + +export type GenerativeSingle = { + debug?: GenerativeDebug; + metadata?: GenerativeMetadata; + text?: string; +}; + +export type GenerativeGrouped = { + metadata?: GenerativeMetadata; + text?: string; +}; /** The return of a query method in the `collection.generate` namespace. */ -export type GenerativeReturn = { +export type GenerativeReturn = { /** The objects that were found by the query. */ - objects: GenerativeObject[]; - /** The LLM-generated output applicable to this query as a whole. */ + objects: GenerativeObject[]; + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; + generative?: GenerativeGrouped; }; -export type GenerativeGroupByResult = GroupByResult & { +export type GenerativeGroupByResult = GroupByResult & { + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; + generative?: GenerativeSingle; }; /** The return of a query method in the `collection.generate` namespace where the `groupBy` argument was specified. */ -export type GenerativeGroupByReturn = { +export type GenerativeGroupByReturn = { /** The objects that were found by the query. */ objects: GroupByObject[]; /** The groups that were created by the query. */ - groups: Record>; - /** The LLM-generated output applicable to this query as a whole. */ + groups: Record>; + /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; + generative?: GenerativeGrouped; }; /** Options available when defining queries using methods in the `collection.generate` namespace. */ -export type GenerateOptions = { +export type GenerateOptions = { /** The prompt to use when generating content relevant to each object of the collection individually. */ - singlePrompt?: string; + singlePrompt?: string | SinglePrompt; /** The prompt to use when generating content relevant to objects returned by the query as a whole. */ - groupedTask?: string; + groupedTask?: string | GroupedTask; /** The properties to use as context to be injected into the `groupedTask` prompt when performing the grouped generation. */ groupedProperties?: T extends undefined ? string[] : (keyof T)[]; + config?: C; +}; + +export type SinglePrompt = { + prompt: string; + debug?: boolean; + metadata?: boolean; + images?: (string | Buffer)[]; + imageProperties?: string[]; +}; + +export type GroupedTask = { + prompt: string; + metadata?: boolean; + nonBlobProperties?: T extends undefined ? string[] : (keyof T)[]; + images?: (string | Buffer)[]; + imageProperties?: string[]; +}; + +type omitFields = 'images' | 'imageProperties'; + +export type GenerativeConfigRuntime = + | ModuleConfig<'generative-anthropic', GenerativeConfigRuntimeType<'generative-anthropic'> | undefined> + | ModuleConfig<'generative-anyscale', GenerativeConfigRuntimeType<'generative-anyscale'> | undefined> + | ModuleConfig<'generative-aws', GenerativeConfigRuntimeType<'generative-aws'> | undefined> + | ModuleConfig<'generative-azure-openai', GenerativeConfigRuntimeType<'generative-azure-openai'>> + | ModuleConfig<'generative-cohere', GenerativeConfigRuntimeType<'generative-cohere'> | undefined> + | ModuleConfig<'generative-databricks', GenerativeConfigRuntimeType<'generative-databricks'> | undefined> + | ModuleConfig<'generative-dummy', GenerativeConfigRuntimeType<'generative-dummy'> | undefined> + | ModuleConfig<'generative-friendliai', GenerativeConfigRuntimeType<'generative-friendliai'> | undefined> + | ModuleConfig<'generative-google', GenerativeConfigRuntimeType<'generative-google'> | undefined> + | ModuleConfig<'generative-mistral', GenerativeConfigRuntimeType<'generative-mistral'> | undefined> + | ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> + | ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> + | ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>>; + +export type GenerativeConfigRuntimeType = G extends 'generative-anthropic' + ? Omit + : G extends 'generative-anyscale' + ? Omit + : G extends 'generative-aws' + ? Omit + : G extends 'generative-azure-openai' + ? Omit & { isAzure: true } + : G extends 'generative-cohere' + ? Omit + : G extends 'generative-databricks' + ? Omit + : G extends 'generative-google' + ? Omit + : G extends 'generative-friendliai' + ? Omit + : G extends 'generative-mistral' + ? Omit + : G extends 'generative-nvidia' + ? Omit + : G extends 'generative-ollama' + ? Omit + : G extends 'generative-openai' + ? Omit & { isAzure?: false } + : G extends 'none' + ? undefined + : Record | undefined; + +export type GenerativeMetadata = C extends undefined + ? never + : C extends infer R extends GenerativeConfigRuntime + ? R['name'] extends 'generative-anthropic' + ? GenerativeAnthropicMetadata + : R['name'] extends 'generative-anyscale' + ? GenerativeAnyscaleMetadata + : R['name'] extends 'generative-aws' + ? GenerativeAWSMetadata + : R['name'] extends 'generative-cohere' + ? GenerativeCohereMetadata + : R['name'] extends 'generative-databricks' + ? GenerativeDatabricksMetadata + : R['name'] extends 'generative-dummy' + ? GenerativeDummyMetadata + : R['name'] extends 'generative-friendliai' + ? GenerativeFriendliAIMetadata + : R['name'] extends 'generative-google' + ? GenerativeGoogleMetadata + : R['name'] extends 'generative-mistral' + ? GenerativeMistralMetadata + : R['name'] extends 'generative-nvidia' + ? GenerativeNvidiaMetadata + : R['name'] extends 'generative-ollama' + ? GenerativeOllamaMetadata + : R['name'] extends 'generative-openai' + ? GenerativeOpenAIMetadata + : never + : never; + +export type GenerateReturn = + | Promise> + | Promise>; + +export type GenerativeAnthropicConfigRuntime = { + baseURL?: string | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + temperature?: number | undefined; + topK?: number | undefined; + topP?: number | undefined; + stopSequences?: string[] | undefined; +}; + +export type GenerativeAnyscaleConfigRuntime = { + baseURL?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; }; -export type GenerateReturn = Promise> | Promise>; +export type GenerativeAWSConfigRuntime = { + model?: string | undefined; + temperature?: number | undefined; + service?: string | undefined; + region?: string | undefined; + endpoint?: string | undefined; + targetModel?: string | undefined; + targetVariant?: string | undefined; +}; + +export type GenerativeCohereConfigRuntime = { + baseURL?: string | undefined; + frequencyPenalty?: number | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + k?: number | undefined; + p?: number | undefined; + presencePenalty?: number | undefined; + stopSequences?: string[] | undefined; + temperature?: number | undefined; +}; + +export type GenerativeDatabricksConfigRuntime = { + endpoint?: string | undefined; + model?: string | undefined; + frequencyPenalty?: number | undefined; + logProbs?: boolean | undefined; + topLogProbs?: number | undefined; + maxTokens?: number | undefined; + n?: number | undefined; + presencePenalty?: number | undefined; + stop?: string[] | undefined; + temperature?: number | undefined; + topP?: number | undefined; +}; + +export type GenerativeDummyConfigRuntime = GenerativeDummyGRPC; + +export type GenerativeFriendliAIConfigRuntime = { + baseURL?: string | undefined; + model?: string | undefined; + maxTokens?: number | undefined; + temperature?: number | undefined; + n?: number | undefined; + topP?: number | undefined; +}; + +export type GenerativeGoogleConfigRuntime = { + frequencyPenalty?: number | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + presencePenalty?: number | undefined; + temperature?: number | undefined; + topK?: number | undefined; + topP?: number | undefined; + stopSequences?: string[] | undefined; + apiEndpoint?: string | undefined; + projectId?: string | undefined; + endpointId?: string | undefined; + region?: string | undefined; +}; + +export type GenerativeMistralConfigRuntime = { + baseURL?: string | undefined; + maxTokens?: number | undefined; + model?: string | undefined; + temperature?: number | undefined; + topP?: number | undefined; +}; + +export type GenerativeNvidiaConfigRuntime = { + baseURL?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; + topP?: number | undefined; + maxTokens?: number | undefined; +}; + +export type GenerativeOllamaConfigRuntime = { + apiEndpoint?: string | undefined; + model?: string | undefined; + temperature?: number | undefined; +}; + +export type GenerativeOpenAIConfigRuntime = { + frequencyPenalty?: number | undefined; + maxTokens?: number | undefined; + model?: string; + n?: number | undefined; + presencePenalty?: number | undefined; + stop?: string[] | undefined; + temperature?: number | undefined; + topP?: number | undefined; + baseURL?: string | undefined; + apiVersion?: string | undefined; + resourceName?: string | undefined; + deploymentId?: string | undefined; +}; diff --git a/src/connection/http.ts b/src/connection/http.ts index 5250d930..1af076fd 100644 --- a/src/connection/http.ts +++ b/src/connection/http.ts @@ -115,11 +115,9 @@ export default class ConnectionREST { postReturn = (path: string, payload: B): Promise => { if (this.authEnabled) { - return this.login().then((token) => - this.http.post(path, payload, true, token).then((res) => res as T) - ); + return this.login().then((token) => this.http.post(path, payload, true, token) as T); } - return this.http.post(path, payload, true, '').then((res) => res as T); + return this.http.post(path, payload, true, '') as Promise; }; postEmpty = (path: string, payload: B): Promise => { diff --git a/src/grpc/searcher.ts b/src/grpc/searcher.ts index 6dc6103e..20496c6f 100644 --- a/src/grpc/searcher.ts +++ b/src/grpc/searcher.ts @@ -171,6 +171,7 @@ export default class Searcher extends Base implements Search { tenant: this.tenant, uses123Api: true, uses125Api: true, + uses127Api: true, }, { metadata: this.metadata, diff --git a/src/openapi/schema.ts b/src/openapi/schema.ts index 598ea51b..b5c11aa5 100644 --- a/src/openapi/schema.ts +++ b/src/openapi/schema.ts @@ -40,9 +40,29 @@ export interface paths { }; }; }; + '/replication/replicate': { + post: operations['replicate']; + }; '/users/own-info': { get: operations['getOwnInfo']; }; + '/users/db': { + get: operations['listAllUsers']; + }; + '/users/db/{user_id}': { + get: operations['getUserInfo']; + post: operations['createUser']; + delete: operations['deleteUser']; + }; + '/users/db/{user_id}/rotate-key': { + post: operations['rotateUserApiKey']; + }; + '/users/db/{user_id}/activate': { + post: operations['activateUser']; + }; + '/users/db/{user_id}/deactivate': { + post: operations['deactivateUser']; + }; '/authz/roles': { get: operations['getRoles']; post: operations['createRole']; @@ -61,9 +81,15 @@ export interface paths { post: operations['hasPermission']; }; '/authz/roles/{id}/users': { + get: operations['getUsersForRoleDeprecated']; + }; + '/authz/roles/{id}/user-assignments': { get: operations['getUsersForRole']; }; '/authz/users/{id}/roles': { + get: operations['getRolesForUserDeprecated']; + }; + '/authz/users/{id}/roles/{userType}': { get: operations['getRolesForUser']; }; '/authz/users/{id}/assign': { @@ -231,6 +257,16 @@ export interface paths { } export interface definitions { + /** + * @description the type of user + * @enum {string} + */ + UserTypeInput: 'db' | 'oidc'; + /** + * @description the type of user + * @enum {string} + */ + UserTypeOutput: 'db_user' | 'db_env_user' | 'oidc'; UserOwnInfo: { /** @description The groups associated to the user */ groups?: string[]; @@ -238,6 +274,23 @@ export interface definitions { /** @description The username associated with the provided key */ username: string; }; + DBUserInfo: { + /** @description The role names associated to the user */ + roles: string[]; + /** @description The user id of the given user */ + userId: string; + /** + * @description type of the returned user + * @enum {string} + */ + dbUserType: 'db_user' | 'db_env_user'; + /** @description activity status of the returned user */ + active: boolean; + }; + UserApiKey: { + /** @description The apikey */ + apikey: string; + }; Role: { /** @description role name */ name: string; @@ -349,7 +402,10 @@ export interface definitions { | 'update_collections' | 'delete_collections' | 'assign_and_revoke_users' + | 'create_users' | 'read_users' + | 'update_users' + | 'delete_users' | 'create_tenants' | 'read_tenants' | 'update_tenants' @@ -371,6 +427,7 @@ export interface definitions { /** @description The username that was extracted either from the authentication information */ username?: string; groups?: string[]; + userType?: definitions['UserTypeInput']; }; /** @description An array of available words and contexts. */ C11yWordsResponse: { @@ -603,6 +660,35 @@ export interface definitions { value?: { [key: string]: unknown }; merge?: definitions['Object']; }; + /** @description Request body to add a replica of given shard of a given collection */ + ReplicationReplicateReplicaRequest: { + /** @description The node containing the replica */ + sourceNodeName: string; + /** @description The node to add a copy of the replica on */ + destinationNodeName: string; + /** @description The collection name holding the shard */ + collectionId: string; + /** @description The shard id holding the replica to be copied */ + shardId: string; + }; + /** @description Request body to disable (soft-delete) a replica of given shard of a given collection */ + ReplicationDisableReplicaRequest: { + /** @description The node containing the replica to be disabled */ + nodeName: string; + /** @description The collection name holding the replica to be disabled */ + collectionId: string; + /** @description The shard id holding the replica to be disabled */ + shardId: string; + }; + /** @description Request body to delete a replica of given shard of a given collection */ + ReplicationDeleteReplicaRequest: { + /** @description The node containing the replica to be deleted */ + nodeName: string; + /** @description The collection name holding the replica to be delete */ + collectionId: string; + /** @description The shard id holding the replica to be deleted */ + shardId: string; + }; /** @description A single peer in the network. */ PeerUpdate: { /** @@ -708,7 +794,8 @@ export interface definitions { | 'trigram' | 'gse' | 'kagome_kr' - | 'kagome_ja'; + | 'kagome_ja' + | 'gse_ch'; /** @description The properties of the nested object(s). Applies to object and object[] data types. */ nestedProperties?: definitions['NestedProperty'][]; }; @@ -736,7 +823,8 @@ export interface definitions { | 'trigram' | 'gse' | 'kagome_kr' - | 'kagome_ja'; + | 'kagome_ja' + | 'gse_ch'; /** @description The properties of the nested object(s). Applies to object and object[] data types. */ nestedProperties?: definitions['NestedProperty'][]; }; @@ -1611,6 +1699,35 @@ export interface operations { 503: unknown; }; }; + replicate: { + parameters: { + body: { + body: definitions['ReplicationReplicateReplicaRequest']; + }; + }; + responses: { + /** Replication operation registered successfully */ + 200: unknown; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; getOwnInfo: { responses: { /** Info about the user */ @@ -1625,6 +1742,229 @@ export interface operations { }; }; }; + listAllUsers: { + responses: { + /** Info about the user */ + 200: { + schema: definitions['DBUserInfo'][]; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + getUserInfo: { + parameters: { + path: { + /** user id */ + user_id: string; + }; + }; + responses: { + /** Info about the user */ + 200: { + schema: definitions['DBUserInfo']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** user not found */ + 404: unknown; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + createUser: { + parameters: { + path: { + /** user id */ + user_id: string; + }; + }; + responses: { + /** User created successfully */ + 201: { + schema: definitions['UserApiKey']; + }; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** User already exists */ + 409: { + schema: definitions['ErrorResponse']; + }; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + deleteUser: { + parameters: { + path: { + /** user name */ + user_id: string; + }; + }; + responses: { + /** Successfully deleted. */ + 204: never; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** user not found */ + 404: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + rotateUserApiKey: { + parameters: { + path: { + /** user id */ + user_id: string; + }; + }; + responses: { + /** ApiKey successfully changed */ + 200: { + schema: definitions['UserApiKey']; + }; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** user not found */ + 404: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + activateUser: { + parameters: { + path: { + /** user id */ + user_id: string; + }; + }; + responses: { + /** User successfully activated */ + 200: unknown; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** user not found */ + 404: unknown; + /** user already activated */ + 409: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + deactivateUser: { + parameters: { + path: { + /** user id */ + user_id: string; + }; + body: { + body?: { + /** + * @description if the key should be revoked when deactivating the user + * @default false + */ + revoke_key?: boolean; + }; + }; + }; + responses: { + /** users successfully deactivated */ + 200: unknown; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** user not found */ + 404: unknown; + /** user already deactivated */ + 409: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. Are you sure the class is defined in the configuration file? */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; getRoles: { responses: { /** Successful response. */ @@ -1849,7 +2189,7 @@ export interface operations { }; }; }; - getUsersForRole: { + getUsersForRoleDeprecated: { parameters: { path: { /** role name */ @@ -1879,11 +2219,86 @@ export interface operations { }; }; }; + getUsersForRole: { + parameters: { + path: { + /** role name */ + id: string; + }; + }; + responses: { + /** Users assigned to this role */ + 200: { + schema: ({ + userId?: string; + userType: definitions['UserTypeOutput']; + } & { + name: unknown; + })[]; + }; + /** Bad request */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** no role found */ + 404: unknown; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + getRolesForUserDeprecated: { + parameters: { + path: { + /** user name */ + id: string; + }; + }; + responses: { + /** Role assigned users */ + 200: { + schema: definitions['RolesListResponse']; + }; + /** Bad request */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** no role found for user */ + 404: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. Are you sure the class is defined in the configuration file? */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; getRolesForUser: { parameters: { path: { /** user name */ id: string; + /** The type of user */ + userType: 'oidc' | 'db'; + }; + query: { + /** Whether to include detailed role information needed the roles permission */ + includeFullRoles?: boolean; }; }; responses: { @@ -1903,6 +2318,10 @@ export interface operations { }; /** no role found for user */ 404: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. Are you sure the class is defined in the configuration file? */ + 422: { + schema: definitions['ErrorResponse']; + }; /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ 500: { schema: definitions['ErrorResponse']; @@ -1919,6 +2338,7 @@ export interface operations { body: { /** @description the roles that assigned to user */ roles?: string[]; + userType?: definitions['UserTypeInput']; }; }; }; @@ -1955,6 +2375,7 @@ export interface operations { body: { /** @description the roles that revoked from the key or user */ roles?: string[]; + userType?: definitions['UserTypeInput']; }; }; }; diff --git a/src/openapi/types.ts b/src/openapi/types.ts index 054ab10d..59e6e7d1 100644 --- a/src/openapi/types.ts +++ b/src/openapi/types.ts @@ -1,4 +1,4 @@ -import { definitions } from './schema.js'; +import { definitions, operations } from './schema.js'; type Override = Omit & T2; type DefaultProperties = { [key: string]: unknown }; @@ -54,7 +54,6 @@ export type WeaviateMultiTenancyConfig = WeaviateClass['multiTenancyConfig']; export type WeaviateReplicationConfig = WeaviateClass['replicationConfig']; export type WeaviateShardingConfig = WeaviateClass['shardingConfig']; export type WeaviateShardStatus = definitions['ShardStatusGetResponse']; -export type WeaviateUser = definitions['UserOwnInfo']; export type WeaviateVectorIndexConfig = WeaviateClass['vectorIndexConfig']; export type WeaviateVectorsConfig = WeaviateClass['vectorConfig']; export type WeaviateVectorConfig = definitions['VectorConfig']; @@ -69,3 +68,9 @@ export type Meta = definitions['Meta']; export type Role = definitions['Role']; export type Permission = definitions['Permission']; export type Action = definitions['Permission']['action']; +export type WeaviateUser = definitions['UserOwnInfo']; +export type WeaviateDBUser = definitions['DBUserInfo']; +export type WeaviateUserType = definitions['UserTypeOutput']; +export type WeaviateUserTypeInternal = definitions['UserTypeInput']; +export type WeaviateUserTypeDB = definitions['DBUserInfo']['dbUserType']; +export type WeaviateAssignedUser = operations['getUsersForRole']['responses']['200']['schema'][0]; diff --git a/src/proto/v1/generative.ts b/src/proto/v1/generative.ts index 2abae4ba..fe0805fa 100644 --- a/src/proto/v1/generative.ts +++ b/src/proto/v1/generative.ts @@ -118,7 +118,7 @@ export interface GenerativeOllama { export interface GenerativeOpenAI { frequencyPenalty?: number | undefined; maxTokens?: number | undefined; - model: string; + model?: string | undefined; n?: number | undefined; presencePenalty?: number | undefined; stop?: TextArray | undefined; @@ -1884,7 +1884,7 @@ function createBaseGenerativeOpenAI(): GenerativeOpenAI { return { frequencyPenalty: undefined, maxTokens: undefined, - model: "", + model: undefined, n: undefined, presencePenalty: undefined, stop: undefined, @@ -1908,7 +1908,7 @@ export const GenerativeOpenAI = { if (message.maxTokens !== undefined) { writer.uint32(16).int64(message.maxTokens); } - if (message.model !== "") { + if (message.model !== undefined) { writer.uint32(26).string(message.model); } if (message.n !== undefined) { @@ -2075,7 +2075,7 @@ export const GenerativeOpenAI = { return { frequencyPenalty: isSet(object.frequencyPenalty) ? globalThis.Number(object.frequencyPenalty) : undefined, maxTokens: isSet(object.maxTokens) ? globalThis.Number(object.maxTokens) : undefined, - model: isSet(object.model) ? globalThis.String(object.model) : "", + model: isSet(object.model) ? globalThis.String(object.model) : undefined, n: isSet(object.n) ? globalThis.Number(object.n) : undefined, presencePenalty: isSet(object.presencePenalty) ? globalThis.Number(object.presencePenalty) : undefined, stop: isSet(object.stop) ? TextArray.fromJSON(object.stop) : undefined, @@ -2099,7 +2099,7 @@ export const GenerativeOpenAI = { if (message.maxTokens !== undefined) { obj.maxTokens = Math.round(message.maxTokens); } - if (message.model !== "") { + if (message.model !== undefined) { obj.model = message.model; } if (message.n !== undefined) { @@ -2148,7 +2148,7 @@ export const GenerativeOpenAI = { const message = createBaseGenerativeOpenAI(); message.frequencyPenalty = object.frequencyPenalty ?? undefined; message.maxTokens = object.maxTokens ?? undefined; - message.model = object.model ?? ""; + message.model = object.model ?? undefined; message.n = object.n ?? undefined; message.presencePenalty = object.presencePenalty ?? undefined; message.stop = (object.stop !== undefined && object.stop !== null) ? TextArray.fromPartial(object.stop) : undefined; diff --git a/src/roles/index.ts b/src/roles/index.ts index 2fbe723b..6971cf4f 100644 --- a/src/roles/index.ts +++ b/src/roles/index.ts @@ -1,5 +1,9 @@ import { ConnectionREST } from '../index.js'; -import { Permission as WeaviatePermission, Role as WeaviateRole } from '../openapi/types.js'; +import { + WeaviateAssignedUser, + Permission as WeaviatePermission, + Role as WeaviateRole, +} from '../openapi/types.js'; import { BackupsPermission, ClusterPermission, @@ -11,6 +15,7 @@ import { Role, RolesPermission, TenantsPermission, + UserAssignment, UsersPermission, } from './types.js'; import { Map } from './util.js'; @@ -30,12 +35,13 @@ export interface Roles { */ byName: (roleName: string) => Promise; /** - * Retrieve the user IDs assigned to a role. + * Retrieve the user IDs assigned to a role. Each user has a qualifying user type, + * e.g. `'db_user' | 'db_env_user' | 'oidc'`. * * @param {string} roleName The name of the role to retrieve the assigned user IDs for. * @returns {Promise} The user IDs assigned to the role. */ - assignedUserIds: (roleName: string) => Promise; + userAssignments: (roleName: string) => Promise; /** * Delete a role by its name. * @@ -89,15 +95,20 @@ const roles = (connection: ConnectionREST): Roles => { listAll: () => connection.get('/authz/roles').then(Map.roles), byName: (roleName: string) => connection.get(`/authz/roles/${roleName}`).then(Map.roleFromWeaviate), - assignedUserIds: (roleName: string) => connection.get(`/authz/roles/${roleName}/users`), - create: (roleName: string, permissions: PermissionsInput) => { - const perms = Map.flattenPermissions(permissions).flatMap(Map.permissionToWeaviate); + userAssignments: (roleName: string) => + connection + .get(`/authz/roles/${roleName}/user-assignments`, true) + .then(Map.assignedUsers), + create: (roleName: string, permissions?: PermissionsInput) => { + const perms = permissions + ? Map.flattenPermissions(permissions).flatMap(Map.permissionToWeaviate) + : undefined; return connection - .postEmpty('/authz/roles', { + .postEmpty('/authz/roles', { name: roleName, permissions: perms, }) - .then(() => Map.roleFromWeaviate({ name: roleName, permissions: perms })); + .then(() => Map.roleFromWeaviate({ name: roleName, permissions: perms || [] })); }, delete: (roleName: string) => connection.delete(`/authz/roles/${roleName}`, null), exists: (roleName: string) => @@ -106,9 +117,13 @@ const roles = (connection: ConnectionREST): Roles => { .then(() => true) .catch(() => false), addPermissions: (roleName: string, permissions: PermissionsInput) => - connection.postEmpty(`/authz/roles/${roleName}/add-permissions`, { permissions }), + connection.postEmpty(`/authz/roles/${roleName}/add-permissions`, { + permissions: Map.flattenPermissions(permissions).flatMap(Map.permissionToWeaviate), + }), removePermissions: (roleName: string, permissions: PermissionsInput) => - connection.postEmpty(`/authz/roles/${roleName}/remove-permissions`, { permissions }), + connection.postEmpty(`/authz/roles/${roleName}/remove-permissions`, { + permissions: Map.flattenPermissions(permissions).flatMap(Map.permissionToWeaviate), + }), hasPermissions: (roleName: string, permission: Permission | Permission[]) => Promise.all( (Array.isArray(permission) ? permission : [permission]) @@ -121,6 +136,15 @@ const roles = (connection: ConnectionREST): Roles => { }; export const permissions = { + /** + * Create a set of permissions specific to Weaviate's backup functionality. + * + * For all collections, provide the `collection` argument as `'*'`. + * + * @param {string | string[]} args.collection The collection or collections to create permissions for. + * @param {boolean} [args.manage] Whether to allow managing backups. Defaults to `false`. + * @returns {BackupsPermission[]} The permissions for the specified collections. + */ backup: (args: { collection: string | string[]; manage?: boolean }): BackupsPermission[] => { const collections = Array.isArray(args.collection) ? args.collection : [args.collection]; return collections.flatMap((collection) => { @@ -129,11 +153,27 @@ export const permissions = { return out; }); }, + /** + * Create a set of permissions specific to Weaviate's cluster endpoints. + * + * @param {boolean} [args.read] Whether to allow reading cluster information. Defaults to `false`. + */ cluster: (args: { read?: boolean }): ClusterPermission[] => { const out: ClusterPermission = { actions: [] }; if (args.read) out.actions.push('read_cluster'); return [out]; }, + /** + * Create a set of permissions specific to any operations involving collections. + * + * For all collections, provide the `collection` argument as `'*'`. + * + * @param {string | string[]} args.collection The collection or collections to create permissions for. + * @param {boolean} [args.create_collection] Whether to allow creating collections. Defaults to `false`. + * @param {boolean} [args.read_config] Whether to allow reading collection configurations. Defaults to `false`. + * @param {boolean} [args.update_config] Whether to allow updating collection configurations. Defaults to `false`. + * @param {boolean} [args.delete_collection] Whether to allow deleting collections. Defaults to `false`. + */ collections: (args: { collection: string | string[]; create_collection?: boolean; @@ -151,16 +191,37 @@ export const permissions = { return out; }); }, + /** + * Create a set of permissions specific to any operations involving objects within collections and tenants. + * + * For all collections, provide the `collection` argument as `'*'`. + * For all tenants, provide the `tenant` argument as `'*'`. + * + * Providing arrays of collections and tenants will create permissions for each combination of collection and tenant. + * E.g., `data({ collection: ['A', 'B'], tenant: ['X', 'Y'] })` will create permissions for tenants `X` and `Y` in both collections `A` and `B`. + * + * @param {string | string[]} args.collection The collection or collections to create permissions for. + * @param {string | string[]} [args.tenant] The tenant or tenants to create permissions for. Defaults to `'*'`. + * @param {boolean} [args.create] Whether to allow creating objects. Defaults to `false`. + * @param {boolean} [args.read] Whether to allow reading objects. Defaults to `false`. + * @param {boolean} [args.update] Whether to allow updating objects. Defaults to `false`. + * @param {boolean} [args.delete] Whether to allow deleting objects. Defaults to `false`. + */ data: (args: { collection: string | string[]; + tenant?: string | string[]; create?: boolean; read?: boolean; update?: boolean; delete?: boolean; }): DataPermission[] => { const collections = Array.isArray(args.collection) ? args.collection : [args.collection]; - return collections.flatMap((collection) => { - const out: DataPermission = { collection, actions: [] }; + const tenants = Array.isArray(args.tenant) ? args.tenant : [args.tenant ?? '*']; + const combinations = collections.flatMap((collection) => + tenants.map((tenant) => ({ collection, tenant })) + ); + return combinations.flatMap(({ collection, tenant }) => { + const out: DataPermission = { collection, tenant, actions: [] }; if (args.create) out.actions.push('create_data'); if (args.read) out.actions.push('read_data'); if (args.update) out.actions.push('update_data'); @@ -168,7 +229,16 @@ export const permissions = { return out; }); }, + /** + * This namespace contains methods to create permissions specific to nodes. + */ nodes: { + /** + * Create a set of permissions specific to reading nodes with verbosity set to `minimal`. + * + * @param {boolean} [args.read] Whether to allow reading nodes. Defaults to `false`. + * @returns {NodesPermission[]} The permissions for reading nodes. + */ minimal: (args: { read?: boolean }): NodesPermission[] => { const out: NodesPermission = { collection: '*', @@ -178,6 +248,13 @@ export const permissions = { if (args.read) out.actions.push('read_nodes'); return [out]; }, + /** + * Create a set of permissions specific to reading nodes with verbosity set to `verbose`. + * + * @param {string | string[]} args.collection The collection or collections to create permissions for. + * @param {boolean} [args.read] Whether to allow reading nodes. Defaults to `false`. + * @returns {NodesPermission[]} The permissions for reading nodes. + */ verbose: (args: { collection: string | string[]; read?: boolean }): NodesPermission[] => { const collections = Array.isArray(args.collection) ? args.collection : [args.collection]; return collections.flatMap((collection) => { @@ -191,6 +268,16 @@ export const permissions = { }); }, }, + /** + * Create a set of permissions specific to any operations involving roles. + * + * @param {string | string[]} args.role The role or roles to create permissions for. + * @param {boolean} [args.create] Whether to allow creating roles. Defaults to `false`. + * @param {boolean} [args.read] Whether to allow reading roles. Defaults to `false`. + * @param {boolean} [args.update] Whether to allow updating roles. Defaults to `false`. + * @param {boolean} [args.delete] Whether to allow deleting roles. Defaults to `false`. + * @returns {RolesPermission[]} The permissions for the specified roles. + */ roles: (args: { role: string | string[]; create?: boolean; @@ -208,16 +295,38 @@ export const permissions = { return out; }); }, + /** + * Create a set of permissions specific to any operations involving tenants. + * + * For all collections, provide the `collection` argument as `'*'`. + * For all tenants, provide the `tenant` argument as `'*'`. + * + * Providing arrays of collections and tenants will create permissions for each combination of collection and tenant. + * E.g., `tenants({ collection: ['A', 'B'], tenant: ['X', 'Y'] })` will create permissions for tenants `X` and `Y` in both collections `A` and `B`. + * + * @param {string | string[] | Record} args.collection The collection or collections to create permissions for. + * @param {string | string[]} [args.tenant] The tenant or tenants to create permissions for. Defaults to `'*'`. + * @param {boolean} [args.create] Whether to allow creating tenants. Defaults to `false`. + * @param {boolean} [args.read] Whether to allow reading tenants. Defaults to `false`. + * @param {boolean} [args.update] Whether to allow updating tenants. Defaults to `false`. + * @param {boolean} [args.delete] Whether to allow deleting tenants. Defaults to `false`. + * @returns {TenantsPermission[]} The permissions for the specified tenants. + */ tenants: (args: { collection: string | string[]; + tenant?: string | string[]; create?: boolean; read?: boolean; update?: boolean; delete?: boolean; }): TenantsPermission[] => { const collections = Array.isArray(args.collection) ? args.collection : [args.collection]; - return collections.flatMap((collection) => { - const out: TenantsPermission = { collection, actions: [] }; + const tenants = Array.isArray(args.tenant) ? args.tenant : [args.tenant ?? '*']; + const combinations = collections.flatMap((collection) => + tenants.map((tenant) => ({ collection, tenant })) + ); + return combinations.flatMap(({ collection, tenant }) => { + const out: TenantsPermission = { collection, tenant, actions: [] }; if (args.create) out.actions.push('create_tenants'); if (args.read) out.actions.push('read_tenants'); if (args.update) out.actions.push('update_tenants'); @@ -225,15 +334,23 @@ export const permissions = { return out; }); }, + /** + * Create a set of permissions specific to any operations involving users. + * + * @param {string | string[]} args.user The user or users to create permissions for. + * @param {boolean} [args.assignAndRevoke] Whether to allow assigning and revoking users. Defaults to `false`. + * @param {boolean} [args.read] Whether to allow reading users. Defaults to `false`. + * @returns {UsersPermission[]} The permissions for the specified users. + */ users: (args: { user: string | string[]; - assign_and_revoke?: boolean; + assignAndRevoke?: boolean; read?: boolean; }): UsersPermission[] => { const users = Array.isArray(args.user) ? args.user : [args.user]; return users.flatMap((user) => { const out: UsersPermission = { users: user, actions: [] }; - if (args.assign_and_revoke) out.actions.push('assign_and_revoke_users'); + if (args.assignAndRevoke) out.actions.push('assign_and_revoke_users'); if (args.read) out.actions.push('read_users'); return out; }); diff --git a/src/roles/integration.test.ts b/src/roles/integration.test.ts index 7d083209..1d335c9f 100644 --- a/src/roles/integration.test.ts +++ b/src/roles/integration.test.ts @@ -1,12 +1,289 @@ -import weaviate, { ApiKey, Permission, Role, WeaviateClient } from '..'; +import weaviate, { + ApiKey, + CollectionsAction, + DataAction, + Permission, + Role, + RolesAction, + TenantsAction, + UserAssignment, + WeaviateClient, +} from '..'; +import { requireAtLeast } from '../../test/version'; import { WeaviateStartUpError, WeaviateUnexpectedStatusCodeError } from '../errors'; -import { DbVersion } from '../utils/dbVersion'; -const only = DbVersion.fromString(`v${process.env.WEAVIATE_VERSION!}`).isAtLeast(1, 29, 0) - ? describe - : describe.skip; +type TestCase = { + roleName: string; + permissions: Permission[]; + expected: Role; +}; -only('Integration testing of the roles namespace', () => { +const emptyPermissions = { + backupsPermissions: [], + clusterPermissions: [], + collectionsPermissions: [], + dataPermissions: [], + nodesPermissions: [], + rolesPermissions: [], + tenantsPermissions: [], + usersPermissions: [], +}; +const crud = { + create: true, + read: true, + update: true, + delete: true, +}; +const collectionsActions: CollectionsAction[] = [ + 'create_collections', + 'read_collections', + 'update_collections', + 'delete_collections', +]; +const dataActions: DataAction[] = ['create_data', 'read_data', 'update_data', 'delete_data']; +const tenantsActions: TenantsAction[] = [ + 'create_tenants', + 'read_tenants', + 'update_tenants', + 'delete_tenants', +]; +const rolesActions: RolesAction[] = ['create_roles', 'read_roles', 'update_roles', 'delete_roles']; +const testCases: TestCase[] = [ + { + roleName: 'backups', + permissions: weaviate.permissions.backup({ collection: 'Some-collection', manage: true }), + expected: { + name: 'backups', + ...emptyPermissions, + backupsPermissions: [{ collection: 'Some-collection', actions: ['manage_backups'] }], + }, + }, + { + roleName: 'cluster', + permissions: weaviate.permissions.cluster({ read: true }), + expected: { + name: 'cluster', + ...emptyPermissions, + clusterPermissions: [{ actions: ['read_cluster'] }], + }, + }, + { + roleName: 'collections', + permissions: weaviate.permissions.collections({ + collection: 'Some-collection', + create_collection: true, + read_config: true, + update_config: true, + delete_collection: true, + }), + expected: { + name: 'collections', + ...emptyPermissions, + collectionsPermissions: [ + { + collection: 'Some-collection', + actions: collectionsActions, + }, + ], + }, + }, + { + roleName: 'data-st', + permissions: weaviate.permissions.data({ + collection: 'Some-collection', + ...crud, + }), + expected: { + name: 'data-st', + ...emptyPermissions, + dataPermissions: [ + { + collection: 'Some-collection', + tenant: '*', + actions: dataActions, + }, + ], + }, + }, + { + roleName: 'data-mt', + permissions: weaviate.permissions.data({ + collection: 'Some-collection', + tenant: 'some-tenant', + ...crud, + }), + expected: { + name: 'data-mt', + ...emptyPermissions, + dataPermissions: [ + { + collection: 'Some-collection', + tenant: 'some-tenant', + actions: dataActions, + }, + ], + }, + }, + { + roleName: 'data-mt-mixed', + permissions: weaviate.permissions.data({ + collection: ['Some-collection', 'Another-collection'], + tenant: ['some-tenant', 'another-tenant'], + ...crud, + }), + expected: { + name: 'data-mt-mixed', + ...emptyPermissions, + dataPermissions: [ + { + collection: 'Some-collection', + tenant: 'some-tenant', + actions: dataActions, + }, + { + collection: 'Some-collection', + tenant: 'another-tenant', + actions: dataActions, + }, + { + collection: 'Another-collection', + tenant: 'some-tenant', + actions: dataActions, + }, + { + collection: 'Another-collection', + tenant: 'another-tenant', + actions: dataActions, + }, + ], + }, + }, + { + roleName: 'nodes-verbose', + permissions: weaviate.permissions.nodes.verbose({ + collection: 'Some-collection', + read: true, + }), + expected: { + name: 'nodes-verbose', + ...emptyPermissions, + nodesPermissions: [{ collection: 'Some-collection', verbosity: 'verbose', actions: ['read_nodes'] }], + }, + }, + { + roleName: 'nodes-minimal', + permissions: weaviate.permissions.nodes.minimal({ + read: true, + }), + expected: { + name: 'nodes-minimal', + ...emptyPermissions, + nodesPermissions: [{ collection: '*', verbosity: 'minimal', actions: ['read_nodes'] }], + }, + }, + { + roleName: 'roles', + permissions: weaviate.permissions.roles({ + role: 'some-role', + ...crud, + }), + expected: { + name: 'roles', + ...emptyPermissions, + rolesPermissions: [{ role: 'some-role', actions: rolesActions }], + }, + }, + { + roleName: 'tenants-st', + permissions: weaviate.permissions.tenants({ + collection: 'some-collection', + ...crud, + }), + expected: { + name: 'tenants-st', + ...emptyPermissions, + tenantsPermissions: [ + { + collection: 'Some-collection', + tenant: '*', + actions: tenantsActions, + }, + ], + }, + }, + { + roleName: 'tenants-mt', + permissions: weaviate.permissions.tenants({ + collection: 'some-collection', + tenant: 'some-tenant', + ...crud, + }), + expected: { + name: 'tenants-mt', + ...emptyPermissions, + tenantsPermissions: [ + { + collection: 'Some-collection', + tenant: 'some-tenant', + actions: tenantsActions, + }, + ], + }, + }, + { + roleName: 'tenants-mt-mixed', + permissions: weaviate.permissions.tenants({ + collection: ['some-collection', 'another-collection'], + tenant: ['some-tenant', 'another-tenant'], + ...crud, + }), + expected: { + name: 'tenants-mt-mixed', + ...emptyPermissions, + tenantsPermissions: [ + { + collection: 'Some-collection', + tenant: 'some-tenant', + actions: tenantsActions, + }, + { + collection: 'Some-collection', + tenant: 'another-tenant', + actions: tenantsActions, + }, + { + collection: 'Another-collection', + tenant: 'some-tenant', + actions: tenantsActions, + }, + { + collection: 'Another-collection', + tenant: 'another-tenant', + actions: tenantsActions, + }, + ], + }, + }, + { + roleName: 'users', + permissions: weaviate.permissions.users({ + user: 'some-user', + assignAndRevoke: true, + read: true, + }), + expected: { + name: 'users', + ...emptyPermissions, + usersPermissions: [{ users: 'some-user', actions: ['assign_and_revoke_users', 'read_users'] }], + }, + }, +]; + +requireAtLeast( + 1, + 29, + 0 +)('Integration testing of the roles namespace', () => { let client: WeaviateClient; beforeAll(async () => { @@ -40,201 +317,37 @@ only('Integration testing of the roles namespace', () => { expect(exists).toBeFalsy(); }); + requireAtLeast( + 1, + 30, + 0 + )('namespaced users', () => { + it('retrieves assigned users with namespace', async () => { + await client.roles.create('landlord', { + collection: 'Buildings', + tenant: 'john-doe', + actions: ['create_tenants', 'delete_tenants'], + }); + + await client.users.db.create('Innkeeper').catch((res) => expect(res.code).toEqual(409)); + + await client.users.db.assignRoles('landlord', 'custom-user'); + await client.users.db.assignRoles('landlord', 'Innkeeper'); + + const assignments = await client.roles.userAssignments('landlord'); + expect(assignments).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: 'custom-user', userType: 'db_env_user' }), + expect.objectContaining({ id: 'Innkeeper', userType: 'db_user' }), + ]) + ); + + await client.users.db.delete('Innkeeper'); + await client.roles.delete('landlord'); + }); + }); + describe('should be able to create roles using the permissions factory', () => { - type TestCase = { - roleName: string; - permissions: Permission[]; - expected: Role; - }; - const testCases: TestCase[] = [ - { - roleName: 'backups', - permissions: weaviate.permissions.backup({ collection: 'Some-collection', manage: true }), - expected: { - name: 'backups', - backupsPermissions: [{ collection: 'Some-collection', actions: ['manage_backups'] }], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'cluster', - permissions: weaviate.permissions.cluster({ read: true }), - expected: { - name: 'cluster', - backupsPermissions: [], - clusterPermissions: [{ actions: ['read_cluster'] }], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'collections', - permissions: weaviate.permissions.collections({ - collection: 'Some-collection', - create_collection: true, - read_config: true, - update_config: true, - delete_collection: true, - }), - expected: { - name: 'collections', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [ - { - collection: 'Some-collection', - actions: ['create_collections', 'read_collections', 'update_collections', 'delete_collections'], - }, - ], - dataPermissions: [], - nodesPermissions: [], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'data', - permissions: weaviate.permissions.data({ - collection: 'Some-collection', - create: true, - read: true, - update: true, - delete: true, - }), - expected: { - name: 'data', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [ - { - collection: 'Some-collection', - actions: ['create_data', 'read_data', 'update_data', 'delete_data'], - }, - ], - nodesPermissions: [], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'nodes-verbose', - permissions: weaviate.permissions.nodes.verbose({ - collection: 'Some-collection', - read: true, - }), - expected: { - name: 'nodes-verbose', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [ - { collection: 'Some-collection', verbosity: 'verbose', actions: ['read_nodes'] }, - ], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'nodes-minimal', - permissions: weaviate.permissions.nodes.minimal({ - read: true, - }), - expected: { - name: 'nodes-minimal', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [{ collection: '*', verbosity: 'minimal', actions: ['read_nodes'] }], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'roles', - permissions: weaviate.permissions.roles({ - role: 'some-role', - create: true, - read: true, - update: true, - delete: true, - }), - expected: { - name: 'roles', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [], - rolesPermissions: [ - { role: 'some-role', actions: ['create_roles', 'read_roles', 'update_roles', 'delete_roles'] }, - ], - tenantsPermissions: [], - usersPermissions: [], - }, - }, - { - roleName: 'tenants', - permissions: weaviate.permissions.tenants({ - collection: 'some-collection', - create: true, - read: true, - update: true, - delete: true, - }), - expected: { - name: 'tenants', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [], - rolesPermissions: [], - tenantsPermissions: [ - { - collection: 'Some-collection', - actions: ['create_tenants', 'read_tenants', 'update_tenants', 'delete_tenants'], - }, - ], - usersPermissions: [], - }, - }, - { - roleName: 'users', - permissions: weaviate.permissions.users({ - user: 'some-user', - assign_and_revoke: true, - read: true, - }), - expected: { - name: 'users', - backupsPermissions: [], - clusterPermissions: [], - collectionsPermissions: [], - dataPermissions: [], - nodesPermissions: [], - rolesPermissions: [], - tenantsPermissions: [], - usersPermissions: [{ users: 'some-user', actions: ['assign_and_revoke_users', 'read_users'] }], - }, - }, - ]; testCases.forEach((testCase) => { it(`with ${testCase.roleName} permissions`, async () => { await client.roles.create(testCase.roleName, testCase.permissions); @@ -244,25 +357,40 @@ only('Integration testing of the roles namespace', () => { }); }); + it('should be able to add permissions to one of the created roles', async () => { + await client.roles.addPermissions( + 'backups', + weaviate.permissions.backup({ collection: 'Another-collection', manage: true }) + ); + const role = await client.roles.byName('backups'); + expect(role).toEqual({ + name: 'backups', + ...emptyPermissions, + backupsPermissions: [ + { collection: 'Some-collection', actions: ['manage_backups'] }, + { collection: 'Another-collection', actions: ['manage_backups'] }, + ], + }); + }); + + it('should be able to remove permissions from one of the created roles', async () => { + await client.roles.removePermissions( + 'backups', + weaviate.permissions.backup({ collection: 'Another-collection', manage: true }) + ); + const role = await client.roles.byName('backups'); + expect(role).toEqual({ + name: 'backups', + ...emptyPermissions, + backupsPermissions: [{ collection: 'Some-collection', actions: ['manage_backups'] }], + }); + }); + it('should delete one of the created roles', async () => { await client.roles.delete('backups'); await expect(client.roles.byName('backups')).rejects.toThrowError(WeaviateUnexpectedStatusCodeError); await expect(client.roles.exists('backups')).resolves.toBeFalsy(); }); - afterAll(() => - Promise.all( - [ - 'backups', - 'cluster', - 'collections', - 'data', - 'nodes-verbose', - 'nodes-minimal', - 'roles', - 'tenants', - 'users', - ].map((n) => client.roles.delete(n)) - ) - ); + afterAll(() => Promise.all(testCases.map((t) => client.roles.delete(t.roleName)))); }); diff --git a/src/roles/types.ts b/src/roles/types.ts index 6f314245..93273521 100644 --- a/src/roles/types.ts +++ b/src/roles/types.ts @@ -1,4 +1,4 @@ -import { Action } from '../openapi/types.js'; +import { Action, WeaviateUserType } from '../openapi/types.js'; export type BackupsAction = Extract; export type ClusterAction = Extract; @@ -22,6 +22,11 @@ export type TenantsAction = Extract< >; export type UsersAction = Extract; +export type UserAssignment = { + id: string; + userType: WeaviateUserType; +}; + export type BackupsPermission = { collection: string; actions: BackupsAction[]; @@ -38,6 +43,7 @@ export type CollectionsPermission = { export type DataPermission = { collection: string; + tenant: string; actions: DataAction[]; }; @@ -54,6 +60,7 @@ export type RolesPermission = { export type TenantsPermission = { collection: string; + tenant: string; actions: TenantsAction[]; }; diff --git a/src/roles/util.ts b/src/roles/util.ts index b4a8a9bf..09f82bb0 100644 --- a/src/roles/util.ts +++ b/src/roles/util.ts @@ -1,8 +1,15 @@ -import { Permission as WeaviatePermission, Role as WeaviateRole, WeaviateUser } from '../openapi/types.js'; -import { User } from '../users/types.js'; +import { + WeaviateAssignedUser, + WeaviateDBUser, + Permission as WeaviatePermission, + Role as WeaviateRole, + WeaviateUser, +} from '../openapi/types.js'; +import { User, UserDB } from '../users/types.js'; import { BackupsAction, BackupsPermission, + ClusterAction, ClusterPermission, CollectionsAction, CollectionsPermission, @@ -17,41 +24,46 @@ import { RolesPermission, TenantsAction, TenantsPermission, + UserAssignment, UsersAction, UsersPermission, } from './types.js'; export class PermissionGuards { - private static includes = (permission: Permission, ...actions: string[]): boolean => + private static includes = (permission: Permission, ...actions: A[]): boolean => actions.filter((a) => Array.from(permission.actions).includes(a)).length > 0; static isBackups = (permission: Permission): permission is BackupsPermission => - PermissionGuards.includes(permission, 'manage_backups'); + PermissionGuards.includes(permission, 'manage_backups'); static isCluster = (permission: Permission): permission is ClusterPermission => - PermissionGuards.includes(permission, 'read_cluster'); + PermissionGuards.includes(permission, 'read_cluster'); static isCollections = (permission: Permission): permission is CollectionsPermission => - PermissionGuards.includes( + PermissionGuards.includes( permission, 'create_collections', 'delete_collections', 'read_collections', - 'update_collections', - 'manage_collections' + 'update_collections' ); static isData = (permission: Permission): permission is DataPermission => - PermissionGuards.includes( + PermissionGuards.includes( permission, 'create_data', 'delete_data', 'read_data', - 'update_data', - 'manage_data' + 'update_data' ); static isNodes = (permission: Permission): permission is NodesPermission => - PermissionGuards.includes(permission, 'read_nodes'); + PermissionGuards.includes(permission, 'read_nodes'); static isRoles = (permission: Permission): permission is RolesPermission => - PermissionGuards.includes(permission, 'create_role', 'read_roles', 'update_roles', 'delete_roles'); + PermissionGuards.includes( + permission, + 'create_roles', + 'read_roles', + 'update_roles', + 'delete_roles' + ); static isTenants = (permission: Permission): permission is TenantsPermission => - PermissionGuards.includes( + PermissionGuards.includes( permission, 'create_tenants', 'delete_tenants', @@ -59,7 +71,7 @@ export class PermissionGuards { 'update_tenants' ); static isUsers = (permission: Permission): permission is UsersPermission => - PermissionGuards.includes(permission, 'read_users', 'assign_and_revoke_users'); + PermissionGuards.includes(permission, 'read_users', 'assign_and_revoke_users'); static isPermission = (permissions: PermissionsInput): permissions is Permission => !Array.isArray(permissions); static isPermissionArray = (permissions: PermissionsInput): permissions is Permission[] => @@ -81,35 +93,35 @@ export class Map { static permissionToWeaviate = (permission: Permission): WeaviatePermission[] => { if (PermissionGuards.isBackups(permission)) { return Array.from(permission.actions).map((action) => ({ - backups: { collection: permission.collection }, + backups: permission, action, })); } else if (PermissionGuards.isCluster(permission)) { return Array.from(permission.actions).map((action) => ({ action })); } else if (PermissionGuards.isCollections(permission)) { return Array.from(permission.actions).map((action) => ({ - collections: { collection: permission.collection }, + collections: permission, action, })); } else if (PermissionGuards.isData(permission)) { return Array.from(permission.actions).map((action) => ({ - data: { collection: permission.collection }, + data: permission, action, })); } else if (PermissionGuards.isNodes(permission)) { return Array.from(permission.actions).map((action) => ({ - nodes: { collection: permission.collection, verbosity: permission.verbosity }, + nodes: permission, action, })); } else if (PermissionGuards.isRoles(permission)) { - return Array.from(permission.actions).map((action) => ({ roles: { role: permission.role }, action })); + return Array.from(permission.actions).map((action) => ({ roles: permission, action })); } else if (PermissionGuards.isTenants(permission)) { return Array.from(permission.actions).map((action) => ({ - tenants: { collection: permission.collection }, + tenants: permission, action, })); } else if (PermissionGuards.isUsers(permission)) { - return Array.from(permission.actions).map((action) => ({ users: { users: permission.users }, action })); + return Array.from(permission.actions).map((action) => ({ users: permission, action })); } else { throw new Error(`Unknown permission type: ${JSON.stringify(permission, null, 2)}`); } @@ -118,20 +130,38 @@ export class Map { static roleFromWeaviate = (role: WeaviateRole): Role => PermissionsMapping.use(role).map(); static roles = (roles: WeaviateRole[]): Record => - roles.reduce((acc, role) => { - acc[role.name] = Map.roleFromWeaviate(role); - return acc; - }, {} as Record); + roles.reduce( + (acc, role) => ({ + ...acc, + [role.name]: Map.roleFromWeaviate(role), + }), + {} as Record + ); static users = (users: string[]): Record => - users.reduce((acc, user) => { - acc[user] = { id: user }; - return acc; - }, {} as Record); + users.reduce( + (acc, user) => ({ + ...acc, + [user]: { id: user }, + }), + {} as Record + ); static user = (user: WeaviateUser): User => ({ id: user.username, roles: user.roles?.map(Map.roleFromWeaviate), }); + static dbUser = (user: WeaviateDBUser): UserDB => ({ + userType: user.dbUserType, + id: user.userId, + roleNames: user.roles, + active: user.active, + }); + static dbUsers = (users: WeaviateDBUser[]): UserDB[] => users.map(Map.dbUser); + static assignedUsers = (users: WeaviateAssignedUser[]): UserAssignment[] => + users.map((user) => ({ + id: user.userId || '', + userType: user.userType, + })); } class PermissionsMapping { @@ -155,7 +185,11 @@ class PermissionsMapping { public static use = (role: WeaviateRole) => new PermissionsMapping(role); public map = (): Role => { - this.role.permissions.forEach(this.permissionFromWeaviate); + // If truncated roles are requested (?includeFullRoles=false), + // role.permissions are not present. + if (this.role.permissions !== null) { + this.role.permissions.forEach(this.permissionFromWeaviate); + } return { name: this.role.name, backupsPermissions: Object.values(this.mappings.backups), @@ -198,9 +232,11 @@ class PermissionsMapping { private data = (permission: WeaviatePermission) => { if (permission.data !== undefined) { - const key = permission.data.collection; - if (key === undefined) throw new Error('Data permission missing collection'); - if (this.mappings.data[key] === undefined) this.mappings.data[key] = { collection: key, actions: [] }; + const { collection, tenant } = permission.data; + if (collection === undefined) throw new Error('Data permission missing collection'); + const key = tenant === undefined ? collection : `${collection}#${tenant}`; + if (this.mappings.data[key] === undefined) + this.mappings.data[key] = { collection, tenant: tenant || '*', actions: [] }; this.mappings.data[key].actions.push(permission.action as DataAction); } }; @@ -232,10 +268,11 @@ class PermissionsMapping { private tenants = (permission: WeaviatePermission) => { if (permission.tenants !== undefined) { - const key = permission.tenants.collection; - if (key === undefined) throw new Error('Tenants permission missing collection'); + const { collection, tenant } = permission.tenants; + if (collection === undefined) throw new Error('Tenants permission missing collection'); + const key = tenant === undefined ? collection : `${collection}#${tenant}`; if (this.mappings.tenants[key] === undefined) - this.mappings.tenants[key] = { collection: key, actions: [] }; + this.mappings.tenants[key] = { collection, tenant: tenant || '*', actions: [] }; this.mappings.tenants[key].actions.push(permission.action as TenantsAction); } }; diff --git a/src/users/index.ts b/src/users/index.ts index 353909fc..7b2c608a 100644 --- a/src/users/index.ts +++ b/src/users/index.ts @@ -1,10 +1,39 @@ +import { WeaviateUnexpectedStatusCodeError } from '../errors.js'; import { ConnectionREST } from '../index.js'; -import { Role as WeaviateRole, WeaviateUser } from '../openapi/types.js'; +import { + WeaviateUserTypeInternal as UserTypeInternal, + WeaviateDBUser, + Role as WeaviateRole, + WeaviateUser, +} from '../openapi/types.js'; import { Role } from '../roles/types.js'; import { Map } from '../roles/util.js'; -import { User } from './types.js'; +import { AssignRevokeOptions, DeactivateOptions, GetAssignedRolesOptions, User, UserDB } from './types.js'; -export interface Users { +/** + * Operations supported for 'db', 'oidc', and legacy (non-namespaced) users. + * Use respective implementations in `users.db` and `users.oidc`, and `users`. + */ +interface UsersBase { + /** + * Assign roles to a user. + * + * @param {string | string[]} roleNames The name or names of the roles to assign. + * @param {string} userId The ID of the user to assign the roles to. + * @returns {Promise} A promise that resolves when the roles are assigned. + */ + assignRoles: (roleNames: string | string[], userId: string) => Promise; + /** + * Revoke roles from a user. + * + * @param {string | string[]} roleNames The name or names of the roles to revoke. + * @param {string} userId The ID of the user to revoke the roles from. + * @returns {Promise} A promise that resolves when the roles are revoked. + */ + revokeRoles: (roleNames: string | string[], userId: string) => Promise; +} + +export interface Users extends UsersBase { /** * Retrieve the information relevant to the currently authenticated user. * @@ -18,35 +47,202 @@ export interface Users { * @returns {Promise>} A map of role names to their respective roles. */ getAssignedRoles: (userId: string) => Promise>; + + db: DBUsers; + oidc: OIDCUsers; +} + +/** Operations supported for namespaced 'db' users.*/ +export interface DBUsers extends UsersBase { /** - * Assign roles to a user. + * Retrieve the roles assigned to a 'db_user' user. * - * @param {string | string[]} roleNames The name or names of the roles to assign. - * @param {string} userId The ID of the user to assign the roles to. - * @returns {Promise} A promise that resolves when the roles are assigned. + * @param {string} userId The ID of the user to retrieve the assigned roles for. + * @returns {Promise>} A map of role names to their respective roles. */ - assignRoles: (roleNames: string | string[], userId: string) => Promise; + getAssignedRoles: (userId: string, opts?: GetAssignedRolesOptions) => Promise>; + + /** Create a new 'db_user' user. + * + * @param {string} userId The ID of the user to create. Must consist of valid URL characters only. + * @returns {Promise} API key for the newly created user. + */ + create: (userId: string) => Promise; + /** - * Revoke roles from a user. + * Delete a 'db_user' user. It is not possible to delete 'db_env_user' users programmatically. * - * @param {string | string[]} roleNames The name or names of the roles to revoke. - * @param {string} userId The ID of the user to revoke the roles from. - * @returns {Promise} A promise that resolves when the roles are revoked. + * @param {string} userId The ID of the user to delete. + * @returns {Promise} `true` if the user has been successfully deleted. */ - revokeRoles: (roleNames: string | string[], userId: string) => Promise; + delete: (userId: string) => Promise; + + /** + * Rotate the API key of a 'db_user' user. The old API key becomes invalid. + * API keys of 'db_env_user' users are defined in the server's environment + * and cannot be modified programmatically. + * + * @param {string} userId The ID of the user to create a new API key for. + * @returns {Promise} New API key for the user. + */ + rotateKey: (userId: string) => Promise; + + /** + * Activate 'db_user' user. + * + * @param {string} userId The ID of the user to activate. + * @returns {Promise} `true` if the user has been successfully activated. + */ + activate: (userId: string) => Promise; + + /** + * Deactivate 'db_user' user. + * + * @param {string} userId The ID of the user to deactivate. + * @returns {Promise} `true` if the user has been successfully deactivated. + */ + deactivate: (userId: string, opts?: DeactivateOptions) => Promise; + + /** + * Retrieve information about the 'db_user' / 'db_env_user' user. + * + * @param {string} userId The ID of the user to get. + * @returns {Promise} ID, status, and assigned roles of a 'db_*' user. + */ + byName: (userId: string) => Promise; + + /** + * List all 'db_user' / 'db_env_user' users. + * + * @returns {Promise} ID, status, and assigned roles for each 'db_*' user. + */ + listAll: () => Promise; +} + +/** Operations supported for namespaced 'oidc' users.*/ +export interface OIDCUsers extends UsersBase { + /** + * Retrieve the roles assigned to an 'oidc' user. + * + * @param {string} userId The ID of the user to retrieve the assigned roles for. + * @returns {Promise>} A map of role names to their respective roles. + */ + getAssignedRoles: (userId: string, opts?: GetAssignedRolesOptions) => Promise>; } const users = (connection: ConnectionREST): Users => { + const base = baseUsers(connection); + return { getMyUser: () => connection.get('/users/own-info').then(Map.user), getAssignedRoles: (userId: string) => connection.get(`/authz/users/${userId}/roles`).then(Map.roles), + assignRoles: (roleNames: string | string[], userId: string) => base.assignRoles(roleNames, userId), + revokeRoles: (roleNames: string | string[], userId: string) => base.revokeRoles(roleNames, userId), + db: db(connection), + oidc: oidc(connection), + }; +}; + +const db = (connection: ConnectionREST): DBUsers => { + const ns = namespacedUsers(connection); + + /** expectCode returns true if the error contained an expected status code. */ + const expectCode = (code: number): ((_: any) => boolean) => { + return (error) => { + if (error instanceof WeaviateUnexpectedStatusCodeError) { + return error.code === code; + } + throw error; + }; + }; + + type APIKeyResponse = { apikey: string }; + return { + getAssignedRoles: (userId: string, opts?: GetAssignedRolesOptions) => + ns.getAssignedRoles('db', userId, opts), assignRoles: (roleNames: string | string[], userId: string) => + ns.assignRoles(roleNames, userId, { userType: 'db' }), + revokeRoles: (roleNames: string | string[], userId: string) => + ns.revokeRoles(roleNames, userId, { userType: 'db' }), + + create: (userId: string) => + connection.postReturn(`/users/db/${userId}`, null).then((resp) => resp.apikey), + delete: (userId: string) => + connection + .delete(`/users/db/${userId}`, null) + .then(() => true) + .catch(() => false), + rotateKey: (userId: string) => + connection + .postReturn(`/users/db/${userId}/rotate-key`, null) + .then((resp) => resp.apikey), + activate: (userId: string) => + connection + .postEmpty(`/users/db/${userId}/activate`, null) + .then(() => true) + .catch(expectCode(409)), + deactivate: (userId: string, opts?: DeactivateOptions) => + connection + .postEmpty(`/users/db/${userId}/deactivate`, opts || null) + .then(() => true) + .catch(expectCode(409)), + byName: (userId: string) => connection.get(`/users/db/${userId}`, true).then(Map.dbUser), + listAll: () => connection.get('/users/db', true).then(Map.dbUsers), + }; +}; + +const oidc = (connection: ConnectionREST): OIDCUsers => { + const ns = namespacedUsers(connection); + return { + getAssignedRoles: (userId: string, opts?: GetAssignedRolesOptions) => + ns.getAssignedRoles('oidc', userId, opts), + assignRoles: (roleNames: string | string[], userId: string) => + ns.assignRoles(roleNames, userId, { userType: 'oidc' }), + revokeRoles: (roleNames: string | string[], userId: string) => + ns.revokeRoles(roleNames, userId, { userType: 'oidc' }), + }; +}; + +/** Internal interface for operations that MAY accept a 'db'/'oidc' namespace. */ +interface NamespacedUsers { + getAssignedRoles: ( + userType: UserTypeInternal, + userId: string, + opts?: GetAssignedRolesOptions + ) => Promise>; + assignRoles: (roleNames: string | string[], userId: string, opts?: AssignRevokeOptions) => Promise; + revokeRoles: (roleNames: string | string[], userId: string, opts?: AssignRevokeOptions) => Promise; +} + +/** Implementation of the operations common to 'db', 'oidc', and legacy users. */ +const baseUsers = (connection: ConnectionREST): UsersBase => { + const ns = namespacedUsers(connection); + return { + assignRoles: (roleNames: string | string[], userId: string) => ns.assignRoles(roleNames, userId), + revokeRoles: (roleNames: string | string[], userId: string) => ns.revokeRoles(roleNames, userId), + }; +}; + +/** Implementation of the operations common to 'db' and 'oidc' users. */ +const namespacedUsers = (connection: ConnectionREST): NamespacedUsers => { + return { + getAssignedRoles: (userType: UserTypeInternal, userId: string, opts?: GetAssignedRolesOptions) => + connection + .get( + `/authz/users/${userId}/roles/${userType}${ + opts?.includePermissions ? '?&includeFullRoles=true' : '' + }` + ) + .then(Map.roles), + assignRoles: (roleNames: string | string[], userId: string, opts?: AssignRevokeOptions) => connection.postEmpty(`/authz/users/${userId}/assign`, { + ...opts, roles: Array.isArray(roleNames) ? roleNames : [roleNames], }), - revokeRoles: (roleNames: string | string[], userId: string) => + revokeRoles: (roleNames: string | string[], userId: string, opts?: AssignRevokeOptions) => connection.postEmpty(`/authz/users/${userId}/revoke`, { + ...opts, roles: Array.isArray(roleNames) ? roleNames : [roleNames], }), }; diff --git a/src/users/integration.test.ts b/src/users/integration.test.ts index 83d2ec4a..0c442bfe 100644 --- a/src/users/integration.test.ts +++ b/src/users/integration.test.ts @@ -1,11 +1,13 @@ import weaviate, { ApiKey } from '..'; -import { DbVersion } from '../utils/dbVersion'; +import { requireAtLeast } from '../../test/version.js'; +import { WeaviateUserTypeDB } from '../v2'; +import { UserDB } from './types.js'; -const only = DbVersion.fromString(`v${process.env.WEAVIATE_VERSION!}`).isAtLeast(1, 29, 0) - ? describe - : describe.skip; - -only('Integration testing of the users namespace', () => { +requireAtLeast( + 1, + 29, + 0 +)('Integration testing of the users namespace', () => { const makeClient = (key: string) => weaviate.connectToLocal({ port: 8091, @@ -59,5 +61,109 @@ only('Integration testing of the users namespace', () => { expect(roles.test).toBeUndefined(); }); + requireAtLeast( + 1, + 30, + 0 + )('dynamic user management', () => { + it('should be able to manage "db" user lifecycle', async () => { + const client = await makeClient('admin-key'); + + /** Pass false to expect a rejected promise, chain assertions about dynamic-dave otherwise. */ + const expectDave = (ok: boolean = true) => { + const promise = expect(client.users.db.byName('dynamic-dave')); + return ok ? promise.resolves : promise.rejects; + }; + + await client.users.db.create('dynamic-dave'); + await expectDave().toHaveProperty('active', true); + + // Second activation is a no-op + await expect(client.users.db.activate('dynamic-dave')).resolves.toEqual(true); + + await client.users.db.deactivate('dynamic-dave'); + await expectDave().toHaveProperty('active', false); + + // Second deactivation is a no-op + await expect(client.users.db.deactivate('dynamic-dave', { revokeKey: true })).resolves.toEqual(true); + + await client.users.db.delete('dynamic-dave'); + await expectDave(false).toHaveProperty('code', 404); + }); + + it('should be able to obtain and rotate api keys', async () => { + const admin = await makeClient('admin-key'); + const apiKey = await admin.users.db.create('api-ashley'); + + let userAshley = await makeClient(apiKey).then((client) => client.users.getMyUser()); + expect(userAshley.id).toEqual('api-ashley'); + + const newKey = await admin.users.db.rotateKey('api-ashley'); + userAshley = await makeClient(newKey).then((client) => client.users.getMyUser()); + expect(userAshley.id).toEqual('api-ashley'); + }); + + it('should be able to list all dynamic users', async () => { + const admin = await makeClient('admin-key'); + + await Promise.all(['jim', 'pam', 'dwight'].map((user) => admin.users.db.create(user))); + + const all = await admin.users.db.listAll(); + expect(all.length).toBeGreaterThanOrEqual(3); + + const pam = await admin.users.db.byName('pam'); + expect(all).toEqual(expect.arrayContaining([pam])); + }); + + it('should be able to fetch static users', async () => { + const custom = await makeClient('admin-key').then((client) => client.users.db.byName('custom-user')); + expect(custom.userType).toEqual('db_env_user'); + }); + + it.each<'db' | 'oidc'>(['db', 'oidc'])('should be able to assign roles to "%s" users', async (kind) => { + const admin = await makeClient('admin-key'); + + if (kind === 'db') { + await admin.users.db.create('role-rick'); + } + + await admin.users[kind].assignRoles('test', 'role-rick'); + await expect(admin.users[kind].getAssignedRoles('role-rick')).resolves.toEqual( + expect.objectContaining({ test: expect.any(Object) }) + ); + + await admin.users[kind].revokeRoles('test', 'role-rick'); + await expect(admin.users[kind].getAssignedRoles('role-rick')).resolves.toEqual({}); + }); + + it('should be able to fetch assigned roles with all permissions', async () => { + const admin = await makeClient('admin-key'); + + await admin.roles.delete('test'); + await admin.roles.create('test', [ + { collection: 'Things', actions: ['manage_backups'] }, + { collection: 'Things', tenant: 'data-tenant', actions: ['create_data'] }, + { collection: 'Things', verbosity: 'minimal', actions: ['read_nodes'] }, + ]); + await admin.users.db.create('permission-peter'); + await admin.users.db.assignRoles('test', 'permission-peter'); + + const roles = await admin.users.db.getAssignedRoles('permission-peter', { includePermissions: true }); + expect(roles.test.backupsPermissions).toHaveLength(1); + expect(roles.test.dataPermissions).toHaveLength(1); + expect(roles.test.nodesPermissions).toHaveLength(1); + }); + + afterAll(() => + makeClient('admin-key').then(async (c) => { + await Promise.all( + ['jim', 'pam', 'dwight', 'dynamic-dave', 'api-ashley', 'role-rick', 'permission-peter'].map((n) => + c.users.db.delete(n) + ) + ); + }) + ); + }); + afterAll(() => makeClient('admin-key').then((c) => c.roles.delete('test'))); }); diff --git a/src/users/types.ts b/src/users/types.ts index 097b1c57..b4a9d59d 100644 --- a/src/users/types.ts +++ b/src/users/types.ts @@ -1,6 +1,25 @@ +import { WeaviateUserTypeDB as UserTypeDB, WeaviateUserTypeInternal } from '../openapi/types.js'; import { Role } from '../roles/types.js'; export type User = { id: string; roles?: Role[]; }; + +export type UserDB = { + userType: UserTypeDB; + id: string; + roleNames: string[]; + active: boolean; +}; + +/** Optional arguments to /user/{type}/{username} enpoint. */ +export type GetAssignedRolesOptions = { + includePermissions?: boolean; +}; + +/** Optional arguments to /assign and /revoke endpoints. */ +export type AssignRevokeOptions = { userType?: WeaviateUserTypeInternal }; + +/** Optional arguments to /deactivate endpoint. */ +export type DeactivateOptions = { revokeKey?: boolean }; diff --git a/src/utils/dbVersion.ts b/src/utils/dbVersion.ts index 279537e2..a945e414 100644 --- a/src/utils/dbVersion.ts +++ b/src/utils/dbVersion.ts @@ -151,11 +151,18 @@ export class DbVersionSupport { return { version: version, supports: version.isAtLeast(1, 25, 0), - message: this.errorMessage('Tenants get method', version.show(), '1.25.0'), + message: this.errorMessage('Tenants get method over gRPC', version.show(), '1.25.0'), }; }); }; + supportsTenantGetRESTMethod = () => + this.dbVersionProvider.getVersion().then((version) => ({ + version: version, + supports: version.isAtLeast(1, 28, 0), + message: this.errorMessage('Tenant get method over REST', version.show(), '1.28.0'), + })); + supportsDynamicVectorIndex = () => { return this.dbVersionProvider.getVersion().then((version) => { return { @@ -219,6 +226,24 @@ export class DbVersionSupport { }; }); }; + + supportsSingleGrouped = () => + this.dbVersionProvider.getVersion().then((version) => ({ + version, + supports: + (version.isAtLeast(1, 27, 14) && version.isLowerThan(1, 28, 0)) || + (version.isAtLeast(1, 28, 8) && version.isLowerThan(1, 29, 0)) || + (version.isAtLeast(1, 29, 0) && version.isLowerThan(1, 30, 0)) || + version.isAtLeast(1, 30, 0), + message: this.errorMessage('Single/Grouped fields in gRPC', version.show(), '1.30.0'), + })); + + supportsGenerativeConfigRuntime = () => + this.dbVersionProvider.getVersion().then((version) => ({ + version, + supports: version.isAtLeast(1, 30, 0), + message: this.errorMessage('Generative config runtime', version.show(), '1.30.0'), + })); } const EMPTY_VERSION = ''; diff --git a/test/version.ts b/test/version.ts new file mode 100644 index 00000000..b34118ef --- /dev/null +++ b/test/version.ts @@ -0,0 +1,7 @@ +import { DbVersion } from '../src/utils/dbVersion'; + +const version = DbVersion.fromString(`v${process.env.WEAVIATE_VERSION!}`); + +/** Run the suite / test only for Weaviate version above this. */ +export const requireAtLeast = (...semver: [...Parameters]) => + version.isAtLeast(...semver) ? describe : describe.skip;