diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 0287f72..78bd688 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -89,6 +89,11 @@ export class ApiClient { return !!(this.oauth2Client && this.accessToken); } + public async hasValidAccessToken(): Promise { + const accessToken = await this.getAccessToken(); + return accessToken !== undefined; + } + public async getIpInfo(): Promise<{ currentIpv4Address: string; }> { @@ -115,7 +120,6 @@ export class ApiClient { } async sendEvents(events: TelemetryEvent[]): Promise { - let endpoint = "api/private/unauth/telemetry/events"; const headers: Record = { Accept: "application/json", "Content-Type": "application/json", @@ -124,12 +128,41 @@ export class ApiClient { const accessToken = await this.getAccessToken(); if (accessToken) { - endpoint = "api/private/v1.0/telemetry/events"; + const authUrl = new URL("api/private/v1.0/telemetry/events", this.options.baseUrl); headers["Authorization"] = `Bearer ${accessToken}`; + + try { + const response = await fetch(authUrl, { + method: "POST", + headers, + body: JSON.stringify(events), + }); + + if (response.ok) { + return; + } + + // If anything other than 401, throw the error + if (response.status !== 401) { + throw await ApiClientError.fromResponse(response); + } + + // For 401, fall through to unauthenticated endpoint + delete headers["Authorization"]; + } catch (error) { + // If the error is not a 401, rethrow it + if (!(error instanceof ApiClientError) || error.response.status !== 401) { + throw error; + } + + // For 401 errors, fall through to unauthenticated endpoint + delete headers["Authorization"]; + } } - const url = new URL(endpoint, this.options.baseUrl); - const response = await fetch(url, { + // Send to unauthenticated endpoint (either as fallback from 401 or direct if no token) + const unauthUrl = new URL("api/private/unauth/telemetry/events", this.options.baseUrl); + const response = await fetch(unauthUrl, { method: "POST", headers, body: JSON.stringify(events), @@ -237,6 +270,7 @@ export class ApiClient { "/api/atlas/v2/groups/{groupId}/clusters/{clusterName}", options ); + if (error) { throw ApiClientError.fromError(response, error); } diff --git a/src/server.ts b/src/server.ts index 4d2df64..091ebd7 100644 --- a/src/server.ts +++ b/src/server.ts @@ -104,7 +104,7 @@ export class Server { * @param command - The server command (e.g., "start", "stop", "register", "deregister") * @param additionalProperties - Additional properties specific to the event */ - emitServerEvent(command: ServerCommand, commandDuration: number, error?: Error) { + private emitServerEvent(command: ServerCommand, commandDuration: number, error?: Error) { const event: ServerEvent = { timestamp: new Date().toISOString(), source: "mdbmcp", @@ -185,5 +185,22 @@ export class Server { throw new Error("Failed to connect to MongoDB instance using the connection string from the config"); } } + + if (this.userConfig.apiClientId && this.userConfig.apiClientSecret) { + try { + await this.session.apiClient.hasValidAccessToken(); + } catch (error) { + if (this.userConfig.connectionString === undefined) { + console.error("Failed to validate MongoDB Atlas the credentials from the config: ", error); + + throw new Error( + "Failed to connect to MongoDB Atlas instance using the credentials from the config" + ); + } + console.error( + "Failed to validate MongoDB Atlas the credentials from the config, but validated the connection string." + ); + } + } } } diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index fd79ecf..fb79d08 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -7,6 +7,7 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { Session } from "../../src/session.js"; import { Telemetry } from "../../src/telemetry/telemetry.js"; import { config } from "../../src/config.js"; +import { jest } from "@jest/globals"; interface ParameterInfo { name: string; @@ -57,6 +58,12 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati apiClientSecret: userConfig.apiClientSecret, }); + // Mock hasValidAccessToken for tests + if (userConfig.apiClientId && userConfig.apiClientSecret) { + const mockFn = jest.fn<() => Promise>().mockResolvedValue(true); + session.apiClient.hasValidAccessToken = mockFn; + } + userConfig.telemetry = "disabled"; const telemetry = Telemetry.create(session, userConfig); @@ -70,6 +77,7 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati version: "5.2.3", }), }); + await mcpServer.connect(serverTransport); await mcpClient.connect(clientTransport); }); diff --git a/tests/unit/apiClient.test.ts b/tests/unit/apiClient.test.ts new file mode 100644 index 0000000..a704e6b --- /dev/null +++ b/tests/unit/apiClient.test.ts @@ -0,0 +1,172 @@ +import { jest } from "@jest/globals"; +import { ApiClient } from "../../src/common/atlas/apiClient.js"; +import { CommonProperties, TelemetryEvent, TelemetryResult } from "../../src/telemetry/types.js"; + +describe("ApiClient", () => { + let apiClient: ApiClient; + + const mockEvents: TelemetryEvent[] = [ + { + timestamp: new Date().toISOString(), + source: "mdbmcp", + properties: { + mcp_client_version: "1.0.0", + mcp_client_name: "test-client", + mcp_server_version: "1.0.0", + mcp_server_name: "test-server", + platform: "test-platform", + arch: "test-arch", + os_type: "test-os", + component: "test-component", + duration_ms: 100, + result: "success" as TelemetryResult, + category: "test-category", + }, + }, + ]; + + beforeEach(() => { + apiClient = new ApiClient({ + baseUrl: "https://api.test.com", + credentials: { + clientId: "test-client-id", + clientSecret: "test-client-secret", + }, + userAgent: "test-user-agent", + }); + + // @ts-expect-error accessing private property for testing + apiClient.getAccessToken = jest.fn().mockResolvedValue("mockToken"); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe("constructor", () => { + it("should create a client with the correct configuration", () => { + expect(apiClient).toBeDefined(); + expect(apiClient.hasCredentials()).toBeDefined(); + }); + }); + + describe("listProjects", () => { + it("should return a list of projects", async () => { + const mockProjects = { + results: [ + { id: "1", name: "Project 1" }, + { id: "2", name: "Project 2" }, + ], + totalCount: 2, + }; + + const mockGet = jest.fn().mockImplementation(() => ({ + data: mockProjects, + error: null, + response: new Response(), + })); + + // @ts-expect-error accessing private property for testing + apiClient.client.GET = mockGet; + + const result = await apiClient.listProjects(); + + expect(mockGet).toHaveBeenCalledWith("/api/atlas/v2/groups", undefined); + expect(result).toEqual(mockProjects); + }); + + it("should throw an error when the API call fails", async () => { + const mockError = { + reason: "Test error", + detail: "Something went wrong", + }; + + const mockGet = jest.fn().mockImplementation(() => ({ + data: null, + error: mockError, + response: new Response(), + })); + + // @ts-expect-error accessing private property for testing + apiClient.client.GET = mockGet; + + await expect(apiClient.listProjects()).rejects.toThrow(); + }); + }); + + describe("sendEvents", () => { + it("should send events to authenticated endpoint when token is available", async () => { + const mockFetch = jest.spyOn(global, "fetch"); + mockFetch.mockResolvedValueOnce(new Response(null, { status: 200 })); + + await apiClient.sendEvents(mockEvents); + + const url = new URL("api/private/v1.0/telemetry/events", "https://api.test.com"); + expect(mockFetch).toHaveBeenCalledWith(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer mockToken", + Accept: "application/json", + "User-Agent": "test-user-agent", + }, + body: JSON.stringify(mockEvents), + }); + }); + + it("should fall back to unauthenticated endpoint when token is not available", async () => { + const mockFetch = jest.spyOn(global, "fetch"); + mockFetch.mockResolvedValueOnce(new Response(null, { status: 200 })); + + // @ts-expect-error accessing private property for testing + apiClient.getAccessToken = jest.fn().mockResolvedValue(undefined); + + await apiClient.sendEvents(mockEvents); + + const url = new URL("api/private/unauth/telemetry/events", "https://api.test.com"); + expect(mockFetch).toHaveBeenCalledWith(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + "User-Agent": "test-user-agent", + }, + body: JSON.stringify(mockEvents), + }); + }); + + it("should fall back to unauthenticated endpoint on 401 error", async () => { + const mockFetch = jest.spyOn(global, "fetch"); + mockFetch + .mockResolvedValueOnce(new Response(null, { status: 401 })) + .mockResolvedValueOnce(new Response(null, { status: 200 })); + + await apiClient.sendEvents(mockEvents); + + const url = new URL("api/private/unauth/telemetry/events", "https://api.test.com"); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockFetch).toHaveBeenLastCalledWith(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + "User-Agent": "test-user-agent", + }, + body: JSON.stringify(mockEvents), + }); + }); + + it("should throw error when both authenticated and unauthenticated requests fail", async () => { + const mockFetch = jest.spyOn(global, "fetch"); + mockFetch + .mockResolvedValueOnce(new Response(null, { status: 401 })) + .mockResolvedValueOnce(new Response(null, { status: 500 })); + + const mockToken = "test-token"; + // @ts-expect-error accessing private property for testing + apiClient.getAccessToken = jest.fn().mockResolvedValue(mockToken); + + await expect(apiClient.sendEvents(mockEvents)).rejects.toThrow(); + }); + }); +});