Skip to content

fix: validate creds #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions src/common/atlas/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ export class ApiClient {
return !!(this.oauth2Client && this.accessToken);
}

public async hasValidAccessToken(): Promise<boolean> {
const accessToken = await this.getAccessToken();
return accessToken !== undefined;
}

public async getIpInfo(): Promise<{
currentIpv4Address: string;
}> {
Expand All @@ -115,7 +120,6 @@ export class ApiClient {
}

async sendEvents(events: TelemetryEvent<CommonProperties>[]): Promise<void> {
let endpoint = "api/private/unauth/telemetry/events";
const headers: Record<string, string> = {
Accept: "application/json",
"Content-Type": "application/json",
Expand All @@ -124,12 +128,41 @@ export class ApiClient {

const accessToken = await this.getAccessToken();
if (accessToken) {
endpoint = "api/private/v1.0/telemetry/events";
const authUrl = new URL("api/private/v1.0/telemetry/events", this.options.baseUrl);
headers["Authorization"] = `Bearer ${accessToken}`;

try {
const response = await fetch(authUrl, {
method: "POST",
headers,
body: JSON.stringify(events),
});

if (response.ok) {
return;
}

// If anything other than 401, throw the error
if (response.status !== 401) {
throw await ApiClientError.fromResponse(response);
}

// For 401, fall through to unauthenticated endpoint
delete headers["Authorization"];
} catch (error) {
// If the error is not a 401, rethrow it
if (!(error instanceof ApiClientError) || error.response.status !== 401) {
throw error;
}

// For 401 errors, fall through to unauthenticated endpoint
delete headers["Authorization"];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we ever get here? Doesn't hurt to have it just in case, but also wondering if I'm not seeing something.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if someone sets connection string + invalid api keys, we let them start the server

Copy link
Collaborator

@gagik gagik May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if fetch errors, we'd get here right? actually not sure

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'd get here if we get a 401, if any other error, we do nothing

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, my point was that we need to get a 401 wrapped in a ApiClientError, which fetch shouldn't be doing, should it? The only place we can throw ApiClientError is on line 146, but that explicitly avoids throwing for 401s.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it! i think you're right, let me try to fix it

}
}

const url = new URL(endpoint, this.options.baseUrl);
const response = await fetch(url, {
// Send to unauthenticated endpoint (either as fallback from 401 or direct if no token)
const unauthUrl = new URL("api/private/unauth/telemetry/events", this.options.baseUrl);
const response = await fetch(unauthUrl, {
method: "POST",
headers,
body: JSON.stringify(events),
Expand Down Expand Up @@ -237,6 +270,7 @@ export class ApiClient {
"/api/atlas/v2/groups/{groupId}/clusters/{clusterName}",
options
);

if (error) {
throw ApiClientError.fromError(response, error);
}
Expand Down
19 changes: 18 additions & 1 deletion src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ export class Server {
* @param command - The server command (e.g., "start", "stop", "register", "deregister")
* @param additionalProperties - Additional properties specific to the event
*/
emitServerEvent(command: ServerCommand, commandDuration: number, error?: Error) {
private emitServerEvent(command: ServerCommand, commandDuration: number, error?: Error) {
const event: ServerEvent = {
timestamp: new Date().toISOString(),
source: "mdbmcp",
Expand Down Expand Up @@ -185,5 +185,22 @@ export class Server {
throw new Error("Failed to connect to MongoDB instance using the connection string from the config");
}
}

if (this.userConfig.apiClientId && this.userConfig.apiClientSecret) {
try {
await this.session.apiClient.hasValidAccessToken();
} catch (error) {
if (this.userConfig.connectionString === undefined) {
console.error("Failed to validate MongoDB Atlas the credentials from the config: ", error);

throw new Error(
"Failed to connect to MongoDB Atlas instance using the credentials from the config"
);
}
console.error(
"Failed to validate MongoDB Atlas the credentials from the config, but validated the connection string."
);
}
}
}
}
8 changes: 8 additions & 0 deletions tests/integration/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { Session } from "../../src/session.js";
import { Telemetry } from "../../src/telemetry/telemetry.js";
import { config } from "../../src/config.js";
import { jest } from "@jest/globals";

interface ParameterInfo {
name: string;
Expand Down Expand Up @@ -57,6 +58,12 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati
apiClientSecret: userConfig.apiClientSecret,
});

// Mock hasValidAccessToken for tests
if (userConfig.apiClientId && userConfig.apiClientSecret) {
const mockFn = jest.fn<() => Promise<boolean>>().mockResolvedValue(true);
session.apiClient.hasValidAccessToken = mockFn;
}

userConfig.telemetry = "disabled";

const telemetry = Telemetry.create(session, userConfig);
Expand All @@ -70,6 +77,7 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati
version: "5.2.3",
}),
});

await mcpServer.connect(serverTransport);
await mcpClient.connect(clientTransport);
});
Expand Down
172 changes: 172 additions & 0 deletions tests/unit/apiClient.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import { jest } from "@jest/globals";
import { ApiClient } from "../../src/common/atlas/apiClient.js";
import { CommonProperties, TelemetryEvent, TelemetryResult } from "../../src/telemetry/types.js";

describe("ApiClient", () => {
let apiClient: ApiClient;

const mockEvents: TelemetryEvent<CommonProperties>[] = [
{
timestamp: new Date().toISOString(),
source: "mdbmcp",
properties: {
mcp_client_version: "1.0.0",
mcp_client_name: "test-client",
mcp_server_version: "1.0.0",
mcp_server_name: "test-server",
platform: "test-platform",
arch: "test-arch",
os_type: "test-os",
component: "test-component",
duration_ms: 100,
result: "success" as TelemetryResult,
category: "test-category",
},
},
];

beforeEach(() => {
apiClient = new ApiClient({
baseUrl: "https://api.test.com",
credentials: {
clientId: "test-client-id",
clientSecret: "test-client-secret",
},
userAgent: "test-user-agent",
});

// @ts-expect-error accessing private property for testing
apiClient.getAccessToken = jest.fn().mockResolvedValue("mockToken");
});

afterEach(() => {
jest.clearAllMocks();
});

describe("constructor", () => {
it("should create a client with the correct configuration", () => {
expect(apiClient).toBeDefined();
expect(apiClient.hasCredentials()).toBeDefined();
});
});

describe("listProjects", () => {
it("should return a list of projects", async () => {
const mockProjects = {
results: [
{ id: "1", name: "Project 1" },
{ id: "2", name: "Project 2" },
],
totalCount: 2,
};

const mockGet = jest.fn().mockImplementation(() => ({
data: mockProjects,
error: null,
response: new Response(),
}));

// @ts-expect-error accessing private property for testing
apiClient.client.GET = mockGet;

const result = await apiClient.listProjects();

expect(mockGet).toHaveBeenCalledWith("/api/atlas/v2/groups", undefined);
expect(result).toEqual(mockProjects);
});

it("should throw an error when the API call fails", async () => {
const mockError = {
reason: "Test error",
detail: "Something went wrong",
};

const mockGet = jest.fn().mockImplementation(() => ({
data: null,
error: mockError,
response: new Response(),
}));

// @ts-expect-error accessing private property for testing
apiClient.client.GET = mockGet;

await expect(apiClient.listProjects()).rejects.toThrow();
});
});

describe("sendEvents", () => {
it("should send events to authenticated endpoint when token is available", async () => {
const mockFetch = jest.spyOn(global, "fetch");
mockFetch.mockResolvedValueOnce(new Response(null, { status: 200 }));

await apiClient.sendEvents(mockEvents);

const url = new URL("api/private/v1.0/telemetry/events", "https://api.test.com");
expect(mockFetch).toHaveBeenCalledWith(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: "Bearer mockToken",
Accept: "application/json",
"User-Agent": "test-user-agent",
},
body: JSON.stringify(mockEvents),
});
});

it("should fall back to unauthenticated endpoint when token is not available", async () => {
const mockFetch = jest.spyOn(global, "fetch");
mockFetch.mockResolvedValueOnce(new Response(null, { status: 200 }));

// @ts-expect-error accessing private property for testing
apiClient.getAccessToken = jest.fn().mockResolvedValue(undefined);

await apiClient.sendEvents(mockEvents);

const url = new URL("api/private/unauth/telemetry/events", "https://api.test.com");
expect(mockFetch).toHaveBeenCalledWith(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "application/json",
"User-Agent": "test-user-agent",
},
body: JSON.stringify(mockEvents),
});
});

it("should fall back to unauthenticated endpoint on 401 error", async () => {
const mockFetch = jest.spyOn(global, "fetch");
mockFetch
.mockResolvedValueOnce(new Response(null, { status: 401 }))
.mockResolvedValueOnce(new Response(null, { status: 200 }));

await apiClient.sendEvents(mockEvents);

const url = new URL("api/private/unauth/telemetry/events", "https://api.test.com");
expect(mockFetch).toHaveBeenCalledTimes(2);
expect(mockFetch).toHaveBeenLastCalledWith(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "application/json",
"User-Agent": "test-user-agent",
},
body: JSON.stringify(mockEvents),
});
});

it("should throw error when both authenticated and unauthenticated requests fail", async () => {
const mockFetch = jest.spyOn(global, "fetch");
mockFetch
.mockResolvedValueOnce(new Response(null, { status: 401 }))
.mockResolvedValueOnce(new Response(null, { status: 500 }));

const mockToken = "test-token";
// @ts-expect-error accessing private property for testing
apiClient.getAccessToken = jest.fn().mockResolvedValue(mockToken);

await expect(apiClient.sendEvents(mockEvents)).rejects.toThrow();
});
});
});
Loading