Skip to content

Commit 08203de

Browse files
authored
feat: Add vector index create and update (#252)
1 parent fa62241 commit 08203de

File tree

4 files changed

+148
-1
lines changed

4 files changed

+148
-1
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
2+
import { buildVectorFields, DbOperationArgs, MongoDBToolBase, VectorIndexArgs } from "../mongodbTool.js";
3+
import { OperationType, ToolArgs } from "../../tool.js";
4+
5+
const VECTOR_INDEX_TYPE = "vectorSearch";
6+
export class CreateVectorIndexTool extends MongoDBToolBase {
7+
protected name = "create-vector-index";
8+
protected description = "Create an Atlas Vector Search Index for a collection.";
9+
protected argsShape = {
10+
...DbOperationArgs,
11+
name: VectorIndexArgs.name,
12+
vectorDefinition: VectorIndexArgs.vectorDefinition,
13+
filterFields: VectorIndexArgs.filterFields,
14+
};
15+
16+
protected operationType: OperationType = "create";
17+
18+
protected async execute({
19+
database,
20+
collection,
21+
name,
22+
vectorDefinition,
23+
filterFields,
24+
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
25+
const provider = await this.ensureConnected();
26+
27+
const indexes = await provider.createSearchIndexes(database, collection, [
28+
{
29+
name,
30+
type: VECTOR_INDEX_TYPE,
31+
definition: { fields: buildVectorFields(vectorDefinition, filterFields) },
32+
},
33+
]);
34+
35+
return {
36+
content: [
37+
{
38+
text: `Created the vector index ${indexes[0]} on collection "${collection}" in database "${database}"`,
39+
type: "text",
40+
},
41+
],
42+
};
43+
}
44+
}

src/tools/mongodb/mongodbTool.ts

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { z } from "zod";
2-
import { ToolArgs, ToolBase, ToolCategory, TelemetryToolMetadata } from "../tool.js";
2+
import { TelemetryToolMetadata, ToolArgs, ToolBase, ToolCategory } from "../tool.js";
33
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
44
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
55
import { ErrorCodes, MongoDBError } from "../../errors.js";
@@ -10,6 +10,64 @@ export const DbOperationArgs = {
1010
collection: z.string().describe("Collection name"),
1111
};
1212

13+
export enum VectorFieldType {
14+
VECTOR = "vector",
15+
FILTER = "filter",
16+
}
17+
export const VectorIndexArgs = {
18+
name: z.string().describe("The name of the index"),
19+
vectorDefinition: z
20+
.object({
21+
path: z
22+
.string()
23+
.min(1)
24+
.describe(
25+
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields."
26+
),
27+
numDimensions: z
28+
.number()
29+
.int()
30+
.min(1)
31+
.max(8192)
32+
.describe("Number of vector dimensions to enforce at index-time and query-time."),
33+
similarity: z
34+
.enum(["euclidean", "cosine", "dotProduct"])
35+
.describe("Vector similarity function to use to search for top K-nearest neighbors."),
36+
quantization: z
37+
.enum(["none", "scalar", "binary"])
38+
.default("none")
39+
.optional()
40+
.describe(
41+
"Automatic vector quantization. Use this setting only if your embeddings are float or double vectors."
42+
),
43+
})
44+
.describe("The vector index definition."),
45+
filterFields: z
46+
.array(
47+
z.object({
48+
path: z
49+
.string()
50+
.min(1)
51+
.describe(
52+
"Name of the field to filter by. For nested fields, use dot notation to specify path to embedded fields."
53+
),
54+
})
55+
)
56+
.optional()
57+
.describe("Additional indexed fields that pre-filter data."),
58+
};
59+
60+
type VectorDefinitionType = z.infer<typeof VectorIndexArgs.vectorDefinition>;
61+
type FilterFieldsType = z.infer<typeof VectorIndexArgs.filterFields>;
62+
export function buildVectorFields(vectorDefinition: VectorDefinitionType, filterFields: FilterFieldsType): object[] {
63+
const typedVectorField = { ...vectorDefinition, type: VectorFieldType.VECTOR };
64+
const typedFilterFields = (filterFields ?? []).map((f) => ({
65+
...f,
66+
type: VectorFieldType.FILTER,
67+
}));
68+
return [typedVectorField, ...typedFilterFields];
69+
}
70+
1371
export const SearchIndexOperationArgs = {
1472
database: z.string().describe("Database name"),
1573
collection: z.string().describe("Collection name"),

src/tools/mongodb/tools.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import { DropCollectionTool } from "./delete/dropCollection.js";
1818
import { ExplainTool } from "./metadata/explain.js";
1919
import { CreateCollectionTool } from "./create/createCollection.js";
2020
import { LogsTool } from "./metadata/logs.js";
21+
import { CreateVectorIndexTool } from "./create/createVectorIndex.js";
22+
import { UpdateVectorIndexTool } from "./update/updateVectorIndex.js";
2123
import { CollectionSearchIndexesTool } from "./read/collectionSearchIndexes.js";
2224
import { DropSearchIndexTool } from "./delete/dropSearchIndex.js";
2325

@@ -43,5 +45,7 @@ export const MongoDbTools = [
4345
ExplainTool,
4446
CreateCollectionTool,
4547
LogsTool,
48+
CreateVectorIndexTool,
49+
UpdateVectorIndexTool,
4650
DropSearchIndexTool,
4751
];
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
2+
import { buildVectorFields, DbOperationArgs, MongoDBToolBase, VectorIndexArgs } from "../mongodbTool.js";
3+
import { OperationType, ToolArgs } from "../../tool.js";
4+
5+
export class UpdateVectorIndexTool extends MongoDBToolBase {
6+
protected name = "update-vector-index";
7+
protected description = "Updates an Atlas Search vector for a collection";
8+
protected argsShape = {
9+
...DbOperationArgs,
10+
name: VectorIndexArgs.name,
11+
vectorDefinition: VectorIndexArgs.vectorDefinition,
12+
filterFields: VectorIndexArgs.filterFields,
13+
};
14+
15+
protected operationType: OperationType = "create";
16+
17+
protected async execute({
18+
database,
19+
collection,
20+
name,
21+
vectorDefinition,
22+
filterFields,
23+
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
24+
const provider = await this.ensureConnected();
25+
26+
// @ts-expect-error: Interface expects a SearchIndexDefinition {definition: {fields}}. However,
27+
// passing fields at the root level is necessary for the call to succeed.
28+
await provider.updateSearchIndex(database, collection, name, {
29+
fields: buildVectorFields(vectorDefinition, filterFields),
30+
});
31+
32+
return {
33+
content: [
34+
{
35+
text: `Successfully updated vector index "${name}" on collection "${collection}" in database "${database}"`,
36+
type: "text",
37+
},
38+
],
39+
};
40+
}
41+
}

0 commit comments

Comments
 (0)