diff --git a/eslint.config.js b/eslint.config.js index b42518a5..e6dd1af0 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -49,6 +49,7 @@ export default defineConfig([ "global.d.ts", "eslint.config.js", "jest.config.ts", + "src/types/*.d.ts", ]), eslintPluginPrettierRecommended, ]); diff --git a/package-lock.json b/package-lock.json index 3f3004f1..9d01e564 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,6 +15,7 @@ "bson": "^6.10.3", "lru-cache": "^11.1.0", "mongodb": "^6.15.0", + "mongodb-connection-string-url": "^3.0.2", "mongodb-log-writer": "^2.4.1", "mongodb-redact": "^1.1.6", "mongodb-schema": "^12.6.2", diff --git a/package.json b/package.json index 6e77412f..d8ce1f40 100644 --- a/package.json +++ b/package.json @@ -66,6 +66,7 @@ "bson": "^6.10.3", "lru-cache": "^11.1.0", "mongodb": "^6.15.0", + "mongodb-connection-string-url": "^3.0.2", "mongodb-log-writer": "^2.4.1", "mongodb-redact": "^1.1.6", "mongodb-schema": "^12.6.2", diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 7f74f578..13272127 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -4,7 +4,7 @@ import { AccessToken, ClientCredentials } from "simple-oauth2"; import { ApiClientError } from "./apiClientError.js"; import { paths, operations } from "./openapi.js"; import { CommonProperties, TelemetryEvent } from "../../telemetry/types.js"; -import { packageInfo } from "../../packageInfo.js"; +import { packageInfo } from "../../helpers/packageInfo.js"; const ATLAS_API_VERSION = "2025-03-12"; diff --git a/src/helpers/connectionOptions.ts b/src/helpers/connectionOptions.ts new file mode 100644 index 00000000..10b1ecc8 --- /dev/null +++ b/src/helpers/connectionOptions.ts @@ -0,0 +1,20 @@ +import { MongoClientOptions } from "mongodb"; +import ConnectionString from "mongodb-connection-string-url"; + +export function setAppNameParamIfMissing({ + connectionString, + defaultAppName, +}: { + connectionString: string; + defaultAppName?: string; +}): string { + const connectionStringUrl = new ConnectionString(connectionString); + + const searchParams = connectionStringUrl.typedSearchParams(); + + if (!searchParams.has("appName") && defaultAppName !== undefined) { + searchParams.set("appName", defaultAppName); + } + + return connectionStringUrl.toString(); +} diff --git a/src/deferred-promise.ts b/src/helpers/deferred-promise.ts similarity index 100% rename from src/deferred-promise.ts rename to src/helpers/deferred-promise.ts diff --git a/src/packageInfo.ts b/src/helpers/packageInfo.ts similarity index 61% rename from src/packageInfo.ts rename to src/helpers/packageInfo.ts index dea9214b..6c075dc0 100644 --- a/src/packageInfo.ts +++ b/src/helpers/packageInfo.ts @@ -1,4 +1,4 @@ -import packageJson from "../package.json" with { type: "json" }; +import packageJson from "../../package.json" with { type: "json" }; export const packageInfo = { version: packageJson.version, diff --git a/src/index.ts b/src/index.ts index 20a60e53..f91db447 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,7 +6,7 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { config } from "./config.js"; import { Session } from "./session.js"; import { Server } from "./server.js"; -import { packageInfo } from "./packageInfo.js"; +import { packageInfo } from "./helpers/packageInfo.js"; import { Telemetry } from "./telemetry/telemetry.js"; try { diff --git a/src/session.ts b/src/session.ts index 6f219c41..a7acabb1 100644 --- a/src/session.ts +++ b/src/session.ts @@ -4,6 +4,8 @@ import { Implementation } from "@modelcontextprotocol/sdk/types.js"; import logger, { LogId } from "./logger.js"; import EventEmitter from "events"; import { ConnectOptions } from "./config.js"; +import { setAppNameParamIfMissing } from "./helpers/connectionOptions.js"; +import { packageInfo } from "./helpers/packageInfo.js"; export interface SessionOptions { apiBaseUrl: string; @@ -98,6 +100,10 @@ export class Session extends EventEmitter<{ } async connectToMongoDB(connectionString: string, connectOptions: ConnectOptions): Promise { + connectionString = setAppNameParamIfMissing({ + connectionString, + defaultAppName: `${packageInfo.mcpServerName} ${packageInfo.version}`, + }); const provider = await NodeDriverServiceProvider.connect(connectionString, { productDocsLink: "https://docs.mongodb.com/todo-mcp", productName: "MongoDB MCP", diff --git a/src/telemetry/constants.ts b/src/telemetry/constants.ts index 998f6e24..9dd1cc76 100644 --- a/src/telemetry/constants.ts +++ b/src/telemetry/constants.ts @@ -1,4 +1,4 @@ -import { packageInfo } from "../packageInfo.js"; +import { packageInfo } from "../helpers/packageInfo.js"; import { type CommonStaticProperties } from "./types.js"; /** diff --git a/src/telemetry/telemetry.ts b/src/telemetry/telemetry.ts index 30a0363b..5f8554e6 100644 --- a/src/telemetry/telemetry.ts +++ b/src/telemetry/telemetry.ts @@ -7,7 +7,7 @@ import { MACHINE_METADATA } from "./constants.js"; import { EventCache } from "./eventCache.js"; import { createHmac } from "crypto"; import nodeMachineId from "node-machine-id"; -import { DeferredPromise } from "../deferred-promise.js"; +import { DeferredPromise } from "../helpers/deferred-promise.js"; type EventResult = { success: boolean; @@ -40,7 +40,6 @@ export class Telemetry { commonProperties = { ...MACHINE_METADATA }, eventCache = EventCache.getInstance(), - // eslint-disable-next-line @typescript-eslint/no-unsafe-return, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access getRawMachineId = () => nodeMachineId.machineId(true), }: { eventCache?: EventCache; diff --git a/src/types/mongodb-connection-string-url.d.ts b/src/types/mongodb-connection-string-url.d.ts new file mode 100644 index 00000000..01a0cff2 --- /dev/null +++ b/src/types/mongodb-connection-string-url.d.ts @@ -0,0 +1,69 @@ +declare module "mongodb-connection-string-url" { + import { URL } from "whatwg-url"; + import { redactConnectionString, ConnectionStringRedactionOptions } from "./redact"; + export { redactConnectionString, ConnectionStringRedactionOptions }; + declare class CaseInsensitiveMap extends Map { + delete(name: K): boolean; + get(name: K): string | undefined; + has(name: K): boolean; + set(name: K, value: any): this; + _normalizeKey(name: any): K; + } + declare abstract class URLWithoutHost extends URL { + abstract get host(): never; + abstract set host(value: never); + abstract get hostname(): never; + abstract set hostname(value: never); + abstract get port(): never; + abstract set port(value: never); + abstract get href(): string; + abstract set href(value: string); + } + export interface ConnectionStringParsingOptions { + looseValidation?: boolean; + } + export declare class ConnectionString extends URLWithoutHost { + _hosts: string[]; + constructor(uri: string, options?: ConnectionStringParsingOptions); + get host(): never; + set host(_ignored: never); + get hostname(): never; + set hostname(_ignored: never); + get port(): never; + set port(_ignored: never); + get href(): string; + set href(_ignored: string); + get isSRV(): boolean; + get hosts(): string[]; + set hosts(list: string[]); + toString(): string; + clone(): ConnectionString; + redact(options?: ConnectionStringRedactionOptions): ConnectionString; + typedSearchParams(): { + append(name: keyof T & string, value: any): void; + delete(name: keyof T & string): void; + get(name: keyof T & string): string | null; + getAll(name: keyof T & string): string[]; + has(name: keyof T & string): boolean; + set(name: keyof T & string, value: any): void; + keys(): IterableIterator; + values(): IterableIterator; + entries(): IterableIterator<[keyof T & string, string]>; + _normalizeKey(name: keyof T & string): string; + [Symbol.iterator](): IterableIterator<[keyof T & string, string]>; + sort(): void; + forEach( + callback: (this: THIS_ARG, value: string, name: string, searchParams: any) => void, + thisArg?: THIS_ARG | undefined + ): void; + readonly [Symbol.toStringTag]: "URLSearchParams"; + }; + } + export declare class CommaAndColonSeparatedRecord< + K extends {} = Record, + > extends CaseInsensitiveMap { + constructor(from?: string | null); + toString(): string; + } + export default ConnectionString; +} diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index fe8e51ff..522c1154 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -6,7 +6,6 @@ import nodeMachineId from "node-machine-id"; describe("Telemetry", () => { it("should resolve the actual machine ID", async () => { - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access const actualId: string = await nodeMachineId.machineId(true); const actualHashedId = createHmac("sha256", actualId.toUpperCase()).update("atlascli").digest("hex"); diff --git a/tests/unit/deferred-promise.test.ts b/tests/unit/deferred-promise.test.ts index c6011af1..5fdaba7d 100644 --- a/tests/unit/deferred-promise.test.ts +++ b/tests/unit/deferred-promise.test.ts @@ -1,4 +1,4 @@ -import { DeferredPromise } from "../../src/deferred-promise.js"; +import { DeferredPromise } from "../../src/helpers/deferred-promise.js"; import { jest } from "@jest/globals"; describe("DeferredPromise", () => { diff --git a/tests/unit/session.test.ts b/tests/unit/session.test.ts new file mode 100644 index 00000000..f60feca1 --- /dev/null +++ b/tests/unit/session.test.ts @@ -0,0 +1,65 @@ +import { jest } from "@jest/globals"; +import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import { Session } from "../../src/session.js"; +import { config } from "../../src/config.js"; + +jest.mock("@mongosh/service-provider-node-driver"); +const MockNodeDriverServiceProvider = NodeDriverServiceProvider as jest.MockedClass; + +describe("Session", () => { + let session: Session; + beforeEach(() => { + session = new Session({ + apiClientId: "test-client-id", + apiBaseUrl: "https://api.test.com", + }); + + MockNodeDriverServiceProvider.connect = jest.fn(() => + Promise.resolve({} as unknown as NodeDriverServiceProvider) + ); + }); + + describe("connectToMongoDB", () => { + const testCases: { + connectionString: string; + expectAppName: boolean; + name: string; + }[] = [ + { + connectionString: "mongodb://localhost:27017", + expectAppName: true, + name: "db without appName", + }, + { + connectionString: "mongodb://localhost:27017?appName=CustomAppName", + expectAppName: false, + name: "db with custom appName", + }, + { + connectionString: + "mongodb+srv://test.mongodb.net/test?retryWrites=true&w=majority&appName=CustomAppName", + expectAppName: false, + name: "atlas db with custom appName", + }, + ]; + + for (const testCase of testCases) { + it(`should update connection string for ${testCase.name}`, async () => { + await session.connectToMongoDB(testCase.connectionString, config.connectOptions); + expect(session.serviceProvider).toBeDefined(); + + // eslint-disable-next-line @typescript-eslint/unbound-method + const connectMock = MockNodeDriverServiceProvider.connect as jest.Mock< + typeof NodeDriverServiceProvider.connect + >; + expect(connectMock).toHaveBeenCalledOnce(); + const connectionString = connectMock.mock.calls[0][0]; + if (testCase.expectAppName) { + expect(connectionString).toContain("appName=MongoDB+MCP+Server"); + } else { + expect(connectionString).not.toContain("appName=MongoDB+MCP+Server"); + } + }); + } + }); +}); diff --git a/tests/unit/telemetry.test.ts b/tests/unit/telemetry.test.ts index 969a4ee8..c1ae28ea 100644 --- a/tests/unit/telemetry.test.ts +++ b/tests/unit/telemetry.test.ts @@ -303,7 +303,11 @@ describe("Telemetry", () => { }); afterEach(() => { - process.env.DO_NOT_TRACK = originalEnv; + if (originalEnv) { + process.env.DO_NOT_TRACK = originalEnv; + } else { + delete process.env.DO_NOT_TRACK; + } }); it("should not send events", async () => {