diff --git a/eslint.config.js b/eslint.config.js index 18f564c7..da617263 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -22,6 +22,7 @@ export default defineConfig([ files, rules: { "@typescript-eslint/switch-exhaustiveness-check": "error", + "@typescript-eslint/no-non-null-assertion": "error", }, }, // Ignore features specific to TypeScript resolved rules diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 39957b7d..6bb8e4fd 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -55,30 +55,31 @@ export class ApiClient { return this.accessToken?.token.access_token as string | undefined; }; - private authMiddleware = (apiClient: ApiClient): Middleware => ({ - async onRequest({ request, schemaPath }) { + private authMiddleware: Middleware = { + onRequest: async ({ request, schemaPath }) => { if (schemaPath.startsWith("/api/private/unauth") || schemaPath.startsWith("/api/oauth")) { return undefined; } try { - const accessToken = await apiClient.getAccessToken(); + const accessToken = await this.getAccessToken(); request.headers.set("Authorization", `Bearer ${accessToken}`); return request; } catch { // ignore not availble tokens, API will return 401 } }, - }); - private errorMiddleware = (): Middleware => ({ + }; + + private readonly errorMiddleware: Middleware = { async onResponse({ response }) { if (!response.ok) { throw await ApiClientError.fromResponse(response); } }, - }); + }; - constructor(options?: ApiClientOptions) { + constructor(options: ApiClientOptions) { const defaultOptions = { baseUrl: "https://cloud.mongodb.com/", userAgent: `AtlasMCP/${config.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`, @@ -107,9 +108,9 @@ export class ApiClient { tokenPath: "/api/oauth/token", }, }); - this.client.use(this.authMiddleware(this)); + this.client.use(this.authMiddleware); } - this.client.use(this.errorMiddleware()); + this.client.use(this.errorMiddleware); } public async getIpInfo(): Promise<{ diff --git a/src/config.ts b/src/config.ts index ecdf32ad..cce25722 100644 --- a/src/config.ts +++ b/src/config.ts @@ -60,7 +60,10 @@ function getLogPath(): string { // to SNAKE_UPPER_CASE. function getEnvConfig(): Partial { function setValue(obj: Record, path: string[], value: string): void { - const currentField = path.shift()!; + const currentField = path.shift(); + if (!currentField) { + return; + } if (path.length === 0) { const numberValue = Number(value); if (!isNaN(numberValue)) { diff --git a/src/index.ts b/src/index.ts index 0f282e4d..944ee92a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,19 +1,30 @@ #!/usr/bin/env node import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { Server } from "./server.js"; import logger from "./logger.js"; import { mongoLogId } from "mongodb-log-writer"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import config from "./config.js"; +import { Session } from "./session.js"; +import { Server } from "./server.js"; -export async function runServer() { - const server = new Server(); +try { + const session = new Session(); + const mcpServer = new McpServer({ + name: "MongoDB Atlas", + version: config.version, + }); + + const server = new Server({ + mcpServer, + session, + }); const transport = new StdioServerTransport(); - await server.connect(transport); -} -runServer().catch((error) => { - logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error}`); + await server.connect(transport); +} catch (error: unknown) { + logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error as string}`); process.exit(1); -}); +} diff --git a/src/server.ts b/src/server.ts index 5530658c..72d9c4f9 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,39 +1,40 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import defaultState, { State } from "./state.js"; +import { Session } from "./session.js"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; -import { registerAtlasTools } from "./tools/atlas/tools.js"; -import { registerMongoDBTools } from "./tools/mongodb/index.js"; -import config from "./config.js"; +import { AtlasTools } from "./tools/atlas/tools.js"; +import { MongoDbTools } from "./tools/mongodb/tools.js"; import logger, { initializeLogger } from "./logger.js"; import { mongoLogId } from "mongodb-log-writer"; export class Server { - state: State = defaultState; - private server?: McpServer; + public readonly session: Session; + private readonly mcpServer: McpServer; + + constructor({ mcpServer, session }: { mcpServer: McpServer; session: Session }) { + this.mcpServer = mcpServer; + this.session = session; + } async connect(transport: Transport) { - this.server = new McpServer({ - name: "MongoDB Atlas", - version: config.version, - }); + this.mcpServer.server.registerCapabilities({ logging: {} }); - this.server.server.registerCapabilities({ logging: {} }); + this.registerTools(); - registerAtlasTools(this.server, this.state); - registerMongoDBTools(this.server, this.state); + await initializeLogger(this.mcpServer); - await initializeLogger(this.server); - await this.server.connect(transport); + await this.mcpServer.connect(transport); logger.info(mongoLogId(1_000_004), "server", `Server started with transport ${transport.constructor.name}`); } async close(): Promise { - try { - await this.state.serviceProvider?.close(true); - } catch { - // Ignore errors during service provider close + await this.session.close(); + await this.mcpServer.close(); + } + + private registerTools() { + for (const tool of [...AtlasTools, ...MongoDbTools]) { + new tool(this.session).register(this.mcpServer); } - await this.server?.close(); } } diff --git a/src/state.ts b/src/session.ts similarity index 67% rename from src/state.ts rename to src/session.ts index f4e694bd..0d5ac951 100644 --- a/src/state.ts +++ b/src/session.ts @@ -2,11 +2,11 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { ApiClient } from "./common/atlas/apiClient.js"; import config from "./config.js"; -export class State { +export class Session { serviceProvider?: NodeDriverServiceProvider; apiClient?: ApiClient; - ensureApiClient(): asserts this is { apiClient: ApiClient } { + ensureAuthenticated(): asserts this is { apiClient: ApiClient } { if (!this.apiClient) { if (!config.apiClientId || !config.apiClientSecret) { throw new Error( @@ -23,7 +23,15 @@ export class State { }); } } -} -const defaultState = new State(); -export default defaultState; + async close(): Promise { + if (this.serviceProvider) { + try { + await this.serviceProvider.close(true); + } catch (error) { + console.error("Error closing service provider:", error); + } + this.serviceProvider = undefined; + } + } +} diff --git a/src/tools/atlas/atlasTool.ts b/src/tools/atlas/atlasTool.ts index 4aef681c..7a1c00fe 100644 --- a/src/tools/atlas/atlasTool.ts +++ b/src/tools/atlas/atlasTool.ts @@ -1,8 +1,8 @@ import { ToolBase } from "../tool.js"; -import { State } from "../../state.js"; +import { Session } from "../../session.js"; export abstract class AtlasToolBase extends ToolBase { - constructor(state: State) { - super(state); + constructor(protected readonly session: Session) { + super(session); } } diff --git a/src/tools/atlas/createAccessList.ts b/src/tools/atlas/createAccessList.ts index 09a991c8..3ba12046 100644 --- a/src/tools/atlas/createAccessList.ts +++ b/src/tools/atlas/createAccessList.ts @@ -26,7 +26,7 @@ export class CreateAccessListTool extends AtlasToolBase { comment, currentIpAddress, }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); if (!ipAddresses?.length && !cidrBlocks?.length && !currentIpAddress) { throw new Error("One of ipAddresses, cidrBlocks, currentIpAddress must be provided."); @@ -39,7 +39,7 @@ export class CreateAccessListTool extends AtlasToolBase { })); if (currentIpAddress) { - const currentIp = await this.state.apiClient.getIpInfo(); + const currentIp = await this.session.apiClient.getIpInfo(); const input = { groupId: projectId, ipAddress: currentIp.currentIpv4Address, @@ -56,7 +56,7 @@ export class CreateAccessListTool extends AtlasToolBase { const inputs = [...ipInputs, ...cidrInputs]; - await this.state.apiClient.createProjectIpAccessList({ + await this.session.apiClient.createProjectIpAccessList({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/createDBUser.ts b/src/tools/atlas/createDBUser.ts index 2698f0d8..a388ef9a 100644 --- a/src/tools/atlas/createDBUser.ts +++ b/src/tools/atlas/createDBUser.ts @@ -33,7 +33,7 @@ export class CreateDBUserTool extends AtlasToolBase { roles, clusters, }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); const input = { groupId: projectId, @@ -53,7 +53,7 @@ export class CreateDBUserTool extends AtlasToolBase { : undefined, } as CloudDatabaseUser; - await this.state.apiClient.createDatabaseUser({ + await this.session.apiClient.createDatabaseUser({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/createFreeCluster.ts b/src/tools/atlas/createFreeCluster.ts index 8179883f..675d48cd 100644 --- a/src/tools/atlas/createFreeCluster.ts +++ b/src/tools/atlas/createFreeCluster.ts @@ -14,7 +14,7 @@ export class CreateFreeClusterTool extends AtlasToolBase { }; protected async execute({ projectId, name, region }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); const input = { groupId: projectId, @@ -38,7 +38,7 @@ export class CreateFreeClusterTool extends AtlasToolBase { terminationProtectionEnabled: false, } as unknown as ClusterDescription20240805; - await this.state.apiClient.createCluster({ + await this.session.apiClient.createCluster({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/inspectAccessList.ts b/src/tools/atlas/inspectAccessList.ts index c66cf5dc..8c25367b 100644 --- a/src/tools/atlas/inspectAccessList.ts +++ b/src/tools/atlas/inspectAccessList.ts @@ -11,9 +11,9 @@ export class InspectAccessListTool extends AtlasToolBase { }; protected async execute({ projectId }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); - const accessList = await this.state.apiClient.listProjectIpAccessLists({ + const accessList = await this.session.apiClient.listProjectIpAccessLists({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/inspectCluster.ts b/src/tools/atlas/inspectCluster.ts index 9ad35a46..c8aa3185 100644 --- a/src/tools/atlas/inspectCluster.ts +++ b/src/tools/atlas/inspectCluster.ts @@ -13,9 +13,9 @@ export class InspectClusterTool extends AtlasToolBase { }; protected async execute({ projectId, clusterName }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); - const cluster = await this.state.apiClient.getCluster({ + const cluster = await this.session.apiClient.getCluster({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/listClusters.ts b/src/tools/atlas/listClusters.ts index 16c85233..8a6a1b08 100644 --- a/src/tools/atlas/listClusters.ts +++ b/src/tools/atlas/listClusters.ts @@ -12,14 +12,14 @@ export class ListClustersTool extends AtlasToolBase { }; protected async execute({ projectId }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); if (!projectId) { - const data = await this.state.apiClient.listClustersForAllProjects(); + const data = await this.session.apiClient.listClustersForAllProjects(); return this.formatAllClustersTable(data); } else { - const project = await this.state.apiClient.getProject({ + const project = await this.session.apiClient.getProject({ params: { path: { groupId: projectId, @@ -31,7 +31,7 @@ export class ListClustersTool extends AtlasToolBase { throw new Error(`Project with ID "${projectId}" not found.`); } - const data = await this.state.apiClient.listClusters({ + const data = await this.session.apiClient.listClusters({ params: { path: { groupId: project.id || "", diff --git a/src/tools/atlas/listDBUsers.ts b/src/tools/atlas/listDBUsers.ts index d9712b1e..5e7a73a9 100644 --- a/src/tools/atlas/listDBUsers.ts +++ b/src/tools/atlas/listDBUsers.ts @@ -12,9 +12,9 @@ export class ListDBUsersTool extends AtlasToolBase { }; protected async execute({ projectId }: ToolArgs): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); - const data = await this.state.apiClient.listDatabaseUsers({ + const data = await this.session.apiClient.listDatabaseUsers({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/listProjects.ts b/src/tools/atlas/listProjects.ts index cdf6d6cc..bb9e0865 100644 --- a/src/tools/atlas/listProjects.ts +++ b/src/tools/atlas/listProjects.ts @@ -7,9 +7,9 @@ export class ListProjectsTool extends AtlasToolBase { protected argsShape = {}; protected async execute(): Promise { - this.state.ensureApiClient(); + this.session.ensureAuthenticated(); - const data = await this.state.apiClient.listProjects(); + const data = await this.session.apiClient.listProjects(); if (!data?.results?.length) { throw new Error("No projects found in your MongoDB Atlas account."); diff --git a/src/tools/atlas/tools.ts b/src/tools/atlas/tools.ts index 5e717306..4e7bd200 100644 --- a/src/tools/atlas/tools.ts +++ b/src/tools/atlas/tools.ts @@ -1,6 +1,3 @@ -import { ToolBase } from "../tool.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { State } from "../../state.js"; import { ListClustersTool } from "./listClusters.js"; import { ListProjectsTool } from "./listProjects.js"; import { InspectClusterTool } from "./inspectCluster.js"; @@ -10,19 +7,13 @@ import { InspectAccessListTool } from "./inspectAccessList.js"; import { ListDBUsersTool } from "./listDBUsers.js"; import { CreateDBUserTool } from "./createDBUser.js"; -export function registerAtlasTools(server: McpServer, state: State) { - const tools: ToolBase[] = [ - new ListClustersTool(state), - new ListProjectsTool(state), - new InspectClusterTool(state), - new CreateFreeClusterTool(state), - new CreateAccessListTool(state), - new InspectAccessListTool(state), - new ListDBUsersTool(state), - new CreateDBUserTool(state), - ]; - - for (const tool of tools) { - tool.register(server); - } -} +export const AtlasTools = [ + ListClustersTool, + ListProjectsTool, + InspectClusterTool, + CreateFreeClusterTool, + CreateAccessListTool, + InspectAccessListTool, + ListDBUsersTool, + CreateDBUserTool, +]; diff --git a/src/tools/mongodb/connect.ts b/src/tools/mongodb/connect.ts index dfba9926..66df62e3 100644 --- a/src/tools/mongodb/connect.ts +++ b/src/tools/mongodb/connect.ts @@ -57,7 +57,7 @@ export class ConnectTool extends MongoDBToolBase { throw new MongoDBError(ErrorCodes.InvalidParams, "Invalid connection options"); } - await this.connectToMongoDB(connectionString, this.state); + await this.connectToMongoDB(connectionString); return { content: [{ type: "text", text: `Successfully connected to ${connectionString}.` }], diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index 9c09caf0..fb1a0a32 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -1,6 +1,6 @@ import { z } from "zod"; import { ToolBase } from "../tool.js"; -import { State } from "../../state.js"; +import { Session } from "../../session.js"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { ErrorCodes, MongoDBError } from "../../errors.js"; @@ -14,16 +14,16 @@ export const DbOperationArgs = { export type DbOperationType = "metadata" | "read" | "create" | "update" | "delete"; export abstract class MongoDBToolBase extends ToolBase { - constructor(state: State) { - super(state); + constructor(session: Session) { + super(session); } protected abstract operationType: DbOperationType; protected async ensureConnected(): Promise { - const provider = this.state.serviceProvider; + const provider = this.session.serviceProvider; if (!provider && config.connectionString) { - await this.connectToMongoDB(config.connectionString, this.state); + await this.connectToMongoDB(config.connectionString); } if (!provider) { @@ -53,7 +53,7 @@ export abstract class MongoDBToolBase extends ToolBase { return super.handleError(error); } - protected async connectToMongoDB(connectionString: string, state: State): Promise { + protected async connectToMongoDB(connectionString: string): Promise { const provider = await NodeDriverServiceProvider.connect(connectionString, { productDocsLink: "https://docs.mongodb.com/todo-mcp", productName: "MongoDB MCP", @@ -67,6 +67,6 @@ export abstract class MongoDBToolBase extends ToolBase { timeoutMS: config.connectOptions.timeoutMS, }); - state.serviceProvider = provider; + this.session.serviceProvider = provider; } } diff --git a/src/tools/mongodb/index.ts b/src/tools/mongodb/tools.ts similarity index 58% rename from src/tools/mongodb/index.ts rename to src/tools/mongodb/tools.ts index be30c494..0f89335d 100644 --- a/src/tools/mongodb/index.ts +++ b/src/tools/mongodb/tools.ts @@ -1,5 +1,3 @@ -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { State } from "../../state.js"; import { ConnectTool } from "./connect.js"; import { ListCollectionsTool } from "./metadata/listCollections.js"; import { CollectionIndexesTool } from "./collectionIndexes.js"; @@ -21,32 +19,25 @@ import { RenameCollectionTool } from "./update/renameCollection.js"; import { DropDatabaseTool } from "./delete/dropDatabase.js"; import { DropCollectionTool } from "./delete/dropCollection.js"; -export function registerMongoDBTools(server: McpServer, state: State) { - const tools = [ - ConnectTool, - ListCollectionsTool, - ListDatabasesTool, - CollectionIndexesTool, - CreateIndexTool, - CollectionSchemaTool, - InsertOneTool, - FindTool, - InsertManyTool, - DeleteManyTool, - DeleteOneTool, - CollectionStorageSizeTool, - CountTool, - DbStatsTool, - AggregateTool, - UpdateOneTool, - UpdateManyTool, - RenameCollectionTool, - DropDatabaseTool, - DropCollectionTool, - ]; - - for (const tool of tools) { - const instance = new tool(state); - instance.register(server); - } -} +export const MongoDbTools = [ + ConnectTool, + ListCollectionsTool, + ListDatabasesTool, + CollectionIndexesTool, + CreateIndexTool, + CollectionSchemaTool, + InsertOneTool, + FindTool, + InsertManyTool, + DeleteManyTool, + DeleteOneTool, + CollectionStorageSizeTool, + CountTool, + DbStatsTool, + AggregateTool, + UpdateOneTool, + UpdateManyTool, + RenameCollectionTool, + DropDatabaseTool, + DropCollectionTool, +]; diff --git a/src/tools/tool.ts b/src/tools/tool.ts index d09be52d..1f5dd8c4 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -1,7 +1,7 @@ import { McpServer, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; import { z, ZodNever, ZodRawShape } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import { State } from "../state.js"; +import { Session } from "../session.js"; import logger from "../logger.js"; import { mongoLogId } from "mongodb-log-writer"; @@ -16,7 +16,7 @@ export abstract class ToolBase { protected abstract execute(...args: Parameters>): Promise; - protected constructor(protected state: State) {} + protected constructor(protected session: Session) {} public register(server: McpServer): void { const callback: ToolCallback = async (...args) => { diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index a08b3eea..207492da 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -4,6 +4,8 @@ import { Server } from "../../src/server.js"; import runner, { MongoCluster } from "mongodb-runner"; import path from "path"; import fs from "fs/promises"; +import { Session } from "../../src/session.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; export async function setupIntegrationTest(): Promise<{ client: Client; @@ -29,7 +31,13 @@ export async function setupIntegrationTest(): Promise<{ } ); - const server = new Server(); + const server = new Server({ + mcpServer: new McpServer({ + name: "test-server", + version: "1.2.3", + }), + session: new Session(), + }); await server.connect(serverTransport); await client.connect(clientTransport); diff --git a/tests/unit/index.test.ts b/tests/unit/index.test.ts deleted file mode 100644 index 2e307bfb..00000000 --- a/tests/unit/index.test.ts +++ /dev/null @@ -1,27 +0,0 @@ -import { describe, it } from "@jest/globals"; -import { runServer } from "../../src/index"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; - -// mock the StdioServerTransport -jest.mock("@modelcontextprotocol/sdk/server/stdio"); -// mock Server class and its methods -jest.mock("../../src/server.ts", () => { - return { - Server: jest.fn().mockImplementation(() => { - return { - connect: jest.fn().mockImplementation((transport) => { - return new Promise((resolve) => { - resolve(transport); - }); - }), - }; - }), - }; -}); - -describe("Server initialization", () => { - it("should create a server instance", async () => { - await runServer(); - expect(StdioServerTransport).toHaveBeenCalled(); - }); -}); diff --git a/tsconfig.json b/tsconfig.json index a195f859..1fe57f10 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -6,6 +6,7 @@ "rootDir": "./src", "outDir": "./dist", "strict": true, + "strictNullChecks": true, "esModuleInterop": true, "types": ["node"], "sourceMap": true,