diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 75fafb448b..0c6a5ae63a 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -45,4 +45,5 @@ fi export TEST_AUTH_OIDC=1 export COVERAGE=1 export AUTH="auth" +export AWS_WEB_IDENTITY_TOKEN_FILE="$OIDC_TOKEN_DIR/test_user1" bash ./.evergreen/tox.sh -m test-eg diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 8a31a96a3c..d05bfdaa03 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -30,7 +30,7 @@ set -o xtrace AUTH=${AUTH:-noauth} SSL=${SSL:-nossl} -TEST_ARGS="$1" +TEST_ARGS="${*:1}" PYTHON=$(which python) export PIP_QUIET=1 # Quiet by default @@ -52,6 +52,7 @@ if [ "$AUTH" != "noauth" ]; then elif [ ! -z "$TEST_AUTH_OIDC" ]; then export DB_USER=$OIDC_ALTAS_USER export DB_PASSWORD=$OIDC_ATLAS_PASSWORD + export DB_IP="$MONGODB_URI" else export DB_USER="bob" export DB_PASSWORD="pwd123" diff --git a/pymongo/auth.py b/pymongo/auth.py index 1926a3ba92..99075e7f88 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -38,7 +38,12 @@ from bson.binary import Binary from bson.son import SON from pymongo.auth_aws import _authenticate_aws -from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties +from pymongo.auth_oidc import ( + _authenticate_oidc, + _get_authenticator, + _OIDCAWSCallback, + _OIDCProperties, +) from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -164,7 +169,8 @@ def _build_credentials_tuple( elif mech == "MONGODB-OIDC": properties = extra.get("authmechanismproperties", {}) request_token_callback = properties.get("request_token_callback") - provider_name = properties.get("PROVIDER_NAME", "") + custom_token_callback = properties.get("custom_token_callback") + provider_name = properties.get("PROVIDER_NAME") default_allowed = [ "*.mongodb.net", "*.mongodb-dev.net", @@ -175,12 +181,25 @@ def _build_credentials_tuple( "::1", ] allowed_hosts = properties.get("allowed_hosts", default_allowed) - if not request_token_callback and provider_name != "aws": - raise ConfigurationError( - "authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'" - ) + msg = "authentication with MONGODB-OIDC requires providing a request_token_callback, a provider_name, or a custom_token_callback" + if request_token_callback is not None: + if provider_name is not None or custom_token_callback is not None: + raise ConfigurationError(msg) + elif provider_name is not None: + if custom_token_callback is not None: + raise ConfigurationError(msg) + if provider_name == "aws": + custom_token_callback = _OIDCAWSCallback() + else: + raise ConfigurationError( + f"unrecognized provider_name for MONGODB-OIDC: {provider_name}" + ) + elif custom_token_callback is None: + raise ConfigurationError(msg) + oidc_props = _OIDCProperties( request_token_callback=request_token_callback, + custom_token_callback=custom_token_callback, provider_name=provider_name, allowed_hosts=allowed_hosts, ) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index ad9223809e..f796f7225f 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -15,13 +15,16 @@ """MONGODB-OIDC Authentication helpers.""" from __future__ import annotations +import abc +import os import threading from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union import bson from bson.binary import Binary from bson.son import SON +from pymongo._csot import remaining from pymongo.errors import ConfigurationError, OperationFailure if TYPE_CHECKING: @@ -29,18 +32,76 @@ from pymongo.pool import Connection +@dataclass +class OIDCIdPInfo: + issuer: str + clientId: str + requestScopes: Optional[list[str]] = field(default=None) + + +@dataclass +class OIDCHumanCallbackContext: + timeout_seconds: float + version: int + refresh_token: Optional[str] = field(default=None) + + +@dataclass +class OIDCMachineCallbackContext: + timeout_seconds: float + version: int + + +@dataclass +class OIDCHumanCallbackResult: + access_token: str + expires_in_seconds: Optional[float] = field(default=None) + refresh_token: Optional[str] = field(default=None) + + +@dataclass +class OIDCMachineCallbackResult: + access_token: str + expires_in_seconds: Optional[float] = field(default=None) + + +class OIDCMachineCallback(abc.ABC): + """A base class for defining OIDC machine (workload federation) + callbacks. + """ + + @abc.abstractmethod + def fetch(self, context: OIDCMachineCallbackContext) -> OIDCMachineCallbackResult: + """Convert the given BSON value into our own type.""" + + +class OIDCHumanCallback(abc.ABC): + """A base class for defining OIDC human (workforce federation) + callbacks. + """ + + @abc.abstractmethod + def fetch( + self, idp_info: OIDCIdPInfo, context: OIDCHumanCallbackContext + ) -> OIDCHumanCallbackResult: + """Convert the given BSON value into our own type.""" + + @dataclass class _OIDCProperties: - request_token_callback: Optional[Callable[..., dict]] - provider_name: Optional[str] - allowed_hosts: list[str] + request_token_callback: Optional[OIDCHumanCallback] = field(default=None) + custom_token_callback: Optional[OIDCMachineCallback] = field(default=None) + provider_name: Optional[str] = field(default=None) + allowed_hosts: list[str] = field(default_factory=list) """Mechanism properties for MONGODB-OIDC authentication.""" TOKEN_BUFFER_MINUTES = 5 -CALLBACK_TIMEOUT_SECONDS = 5 * 60 -CALLBACK_VERSION = 1 +HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60 +HUMAN_CALLBACK_VERSION = 1 +MACHINE_CALLBACK_TIMEOUT_SECONDS = 60 +MACHINE_CALLBACK_VERSION = 1 def _get_authenticator( @@ -72,22 +133,37 @@ def _get_authenticator( return credentials.cache.data +class _OIDCAWSCallback(OIDCMachineCallback): + def fetch(self, context: OIDCMachineCallbackContext) -> OIDCMachineCallbackResult: + token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") + if not token_file: + raise RuntimeError( + 'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set' + ) + with open(token_file) as fid: + return OIDCMachineCallbackResult(access_token=fid.read().strip()) + + @dataclass class _OIDCAuthenticator: username: str properties: _OIDCProperties refresh_token: Optional[str] = field(default=None) access_token: Optional[str] = field(default=None) - idp_info: Optional[dict] = field(default=None) + idp_info: Optional[OIDCIdPInfo] = field(default=None) token_gen_id: int = field(default=0) lock: threading.Lock = field(default_factory=threading.Lock) def get_current_token(self, use_callback: bool = True) -> Optional[str]: properties = self.properties - - # TODO: DRIVERS-2672, handle machine callback here as well. - cb = properties.request_token_callback if use_callback else None - cb_type = "human" + cb: Union[None, OIDCHumanCallback, OIDCMachineCallback] + resp: Union[None, OIDCHumanCallbackResult, OIDCMachineCallbackResult] + if not use_callback: + cb = None + elif properties.request_token_callback: + cb = properties.request_token_callback + elif properties.custom_token_callback: + cb = properties.custom_token_callback prev_token = self.access_token if prev_token: @@ -104,37 +180,32 @@ def get_current_token(self, use_callback: bool = True) -> Optional[str]: if new_token != prev_token: return new_token - # TODO: DRIVERS-2672 handle machine callback here. - if cb_type == "human": - context = { - "timeout_seconds": CALLBACK_TIMEOUT_SECONDS, - "version": CALLBACK_VERSION, - "refresh_token": self.refresh_token, - } - resp = cb(self.idp_info, context) - - self.validate_request_token_response(resp) - + if isinstance(cb, OIDCHumanCallback): + human_context = OIDCHumanCallbackContext( + timeout_seconds=HUMAN_CALLBACK_TIMEOUT_SECONDS, + version=HUMAN_CALLBACK_VERSION, + refresh_token=self.refresh_token, + ) + assert self.idp_info is not None + resp = cb.fetch(self.idp_info, human_context) + if not isinstance(resp, OIDCHumanCallbackResult): + raise ValueError("Callback result must be of type OIDCHumanCallbackResult") + self.refresh_token = resp.refresh_token + else: + machine_context = OIDCMachineCallbackContext( + timeout_seconds=remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS, + version=MACHINE_CALLBACK_VERSION, + ) + resp = cb.fetch(machine_context) + if not isinstance(resp, OIDCMachineCallbackResult): + raise ValueError( + "Callback result must be of type OIDCMachineCallbackResult" + ) + self.access_token = resp.access_token self.token_gen_id += 1 return self.access_token - def validate_request_token_response(self, resp: Mapping[str, Any]) -> None: - # Validate callback return value. - if not isinstance(resp, dict): - raise ValueError("OIDC callback returned invalid result") - - if "access_token" not in resp: - raise ValueError("OIDC callback did not return an access_token") - - expected = ["access_token", "refresh_token", "expires_in_seconds"] - for key in resp: - if key not in expected: - raise ValueError(f'Unexpected field in callback result "{key}"') - - self.access_token = resp["access_token"] - self.refresh_token = resp.get("refresh_token") - def principal_step_cmd(self) -> SON[str, Any]: """Get a SASL start command with an optional principal name""" # Send the SASL start with the optional principal name. @@ -154,8 +225,7 @@ def principal_step_cmd(self) -> SON[str, Any]: ) def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]: - # TODO: DRIVERS-2672, check for provider_name in self.properties here. - if self.idp_info is None: + if self.properties.request_token_callback is not None and self.idp_info is None: return self.principal_step_cmd() token = self.get_current_token(use_callback) @@ -192,8 +262,10 @@ def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: self.access_token = None - # TODO: DRIVERS-2672, check for provider_name in self.properties here. - # If so, we clear the access token and return finish_auth. + # If we are using machine callbacks, clear the access token and + # re-authenticate. + if self.properties.provider_name: + return self.authenticate(conn) # Next see if the idp info has changed. prev_idp_info = self.idp_info @@ -203,7 +275,7 @@ def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: assert resp is not None server_resp: dict = bson.decode(resp["payload"]) if "issuer" in server_resp: - self.idp_info = server_resp + self.idp_info = OIDCIdPInfo(**server_resp) # Handle the case of changed idp info. if self.idp_info != prev_idp_info: @@ -240,7 +312,7 @@ def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: server_resp: dict = bson.decode(resp["payload"]) if "issuer" in server_resp: - self.idp_info = server_resp + self.idp_info = OIDCIdPInfo(**server_resp) return self.finish_auth(resp, conn) @@ -273,4 +345,10 @@ def _authenticate_oidc( if reauthenticate: return authenticator.reauthenticate(conn) else: - return authenticator.authenticate(conn) + try: + return authenticator.authenticate(conn) + except Exception as e: + # Try one more time an an authentication failure. + if isinstance(e, OperationFailure) and e.code == 18: + return authenticator.authenticate(conn) + raise diff --git a/pymongo/common.py b/pymongo/common.py index bda294af93..16a0f50c9c 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -17,7 +17,6 @@ from __future__ import annotations import datetime -import inspect import warnings from collections import OrderedDict, abc from typing import ( @@ -41,6 +40,7 @@ from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry from bson.raw_bson import RawBSONDocument from pymongo.auth import MECHANISMS +from pymongo.auth_oidc import OIDCHumanCallback, OIDCMachineCallback from pymongo.compression_support import ( validate_compressors, validate_zlib_compression_level, @@ -438,20 +438,16 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni props[key] = str(value).lower() elif key in ["allowed_hosts"] and isinstance(value, list): props[key] = value - elif inspect.isfunction(value): - signature = inspect.signature(value) - if key == "request_token_callback": - expected_params = 2 - else: - raise ValueError(f"Unrecognized Auth mechanism function {key}") - if len(signature.parameters) != expected_params: - msg = f"{key} must accept {expected_params} parameters" - raise ValueError(msg) + elif key == "request_token_callback": + if not isinstance(value, OIDCHumanCallback): + raise ValueError("request_token_callback must be a OIDCHumanCallback object") + props[key] = value + elif key == "custom_token_callback": + if not isinstance(value, OIDCMachineCallback): + raise ValueError("custom_token_callback must be a OIDCMachineCallback object") props[key] = value else: - raise ValueError( - "Auth mechanism property values must be strings or callback functions" - ) + raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") return props value = validate_string(option, value) diff --git a/pyproject.toml b/pyproject.toml index db2c956690..f8ec5535b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,6 +219,7 @@ exclude_lines = [ "return NotImplemented", "_use_c = true", "if __name__ == '__main__':", + "if TYPE_CHECKING:" ] partial_branches = ["if (.*and +)*not _use_c( and.*)*:"] diff --git a/test/__init__.py b/test/__init__.py index 4d98a1911c..906f7f18a1 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -67,6 +67,7 @@ # for a replica set. host = os.environ.get("DB_IP", "localhost") port = int(os.environ.get("DB_PORT", 27017)) +IS_SRV = "mongodb+srv" in host db_user = os.environ.get("DB_USER", "user") db_pwd = os.environ.get("DB_PASSWORD", "password") @@ -378,7 +379,7 @@ def _init_client(self): self.auth_enabled = self._server_started_with_auth() if self.auth_enabled: - if not self.serverless: + if not self.serverless and not IS_SRV: # See if db_user already exists. if not self._check_user_provided(): _create_user(self.client.admin, db_user, db_pwd) @@ -446,7 +447,7 @@ def _init_client(self): else: self.server_parameters = self.client.admin.command("getParameter", "*") assert self.cmd_line is not None - if "enableTestCommands=1" in self.cmd_line["argv"]: + if self.server_parameters["enableTestCommands"]: self.test_commands_enabled = True elif "parsed" in self.cmd_line: params = self.cmd_line["parsed"].get("setParameter", []) @@ -482,14 +483,14 @@ def connection_attempt_info(self): @property def host(self): - if self.is_rs: + if self.is_rs and not IS_SRV: primary = self.client.primary return str(primary[0]) if primary is not None else host return host @property def port(self): - if self.is_rs: + if self.is_rs and not IS_SRV: primary = self.client.primary return primary[1] if primary is not None else port return port @@ -514,6 +515,10 @@ def storage_engine(self): # Raised if self.server_status is None. return None + def check_auth_type(self, auth_type): + auth_mechs = self.server_parameters.get("authenticationMechanisms", []) + return auth_type in auth_mechs + def _check_user_provided(self): """Return True if db_user/db_password is already an admin user.""" client: MongoClient = pymongo.MongoClient( diff --git a/test/auth/legacy/connection-string.json b/test/auth/legacy/connection-string.json index 0463a5141e..40ef630ca3 100644 --- a/test/auth/legacy/connection-string.json +++ b/test/auth/legacy/connection-string.json @@ -80,7 +80,7 @@ }, { "description": "should accept generic mechanism property (GSSAPI)", - "uri": "mongodb://user%40DOMAIN.COM@localhost/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:true", + "uri": "mongodb://user%40DOMAIN.COM@localhost/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:forward,SERVICE_HOST:example.com", "valid": true, "credential": { "username": "user@DOMAIN.COM", @@ -89,10 +89,46 @@ "mechanism": "GSSAPI", "mechanism_properties": { "SERVICE_NAME": "other", - "CANONICALIZE_HOST_NAME": true + "SERVICE_HOST": "example.com", + "CANONICALIZE_HOST_NAME": "forward" } } }, + { + "description": "should accept forwardAndReverse hostname canonicalization (GSSAPI)", + "uri": "mongodb://user%40DOMAIN.COM@localhost/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:forwardAndReverse", + "valid": true, + "credential": { + "username": "user@DOMAIN.COM", + "password": null, + "source": "$external", + "mechanism": "GSSAPI", + "mechanism_properties": { + "SERVICE_NAME": "other", + "CANONICALIZE_HOST_NAME": "forwardAndReverse" + } + } + }, + { + "description": "should accept no hostname canonicalization (GSSAPI)", + "uri": "mongodb://user%40DOMAIN.COM@localhost/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:none", + "valid": true, + "credential": { + "username": "user@DOMAIN.COM", + "password": null, + "source": "$external", + "mechanism": "GSSAPI", + "mechanism_properties": { + "SERVICE_NAME": "other", + "CANONICALIZE_HOST_NAME": "none" + } + } + }, + { + "description": "must raise an error when the hostname canonicalization is invalid", + "uri": "mongodb://user%40DOMAIN.COM@localhost/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:invalid", + "valid": false + }, { "description": "should accept the password (GSSAPI)", "uri": "mongodb://user%40DOMAIN.COM:password@localhost/?authMechanism=GSSAPI&authSource=$external", @@ -448,7 +484,9 @@ { "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], + "callback": [ + "oidcRequest" + ], "valid": true, "credential": { "username": null, @@ -456,14 +494,34 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true + "REQUEST_TOKEN_CALLBACK": true } } }, { "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", - "callback": ["oidcRequest"], + "callback": [ + "oidcRequest" + ], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": [ + "oidcRequest", + "oidcRefresh" + ], "valid": true, "credential": { "username": null, @@ -471,14 +529,17 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true + "REQUEST_TOKEN_CALLBACK": true, + "REFRESH_TOKEN_CALLBACK": true } } }, { "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], + "callback": [ + "oidcRequest" + ], "valid": true, "credential": { "username": "principalName", @@ -486,7 +547,7 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true + "REQUEST_TOKEN_CALLBACK": true } } }, @@ -500,7 +561,7 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "PROVIDER_NAME": "aws" + "PROVIDER_NAME": "aws" } } }, @@ -514,14 +575,16 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "PROVIDER_NAME": "aws" + "PROVIDER_NAME": "aws" } } }, { "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], + "callback": [ + "oidcRequest" + ], "valid": false, "credential": null }, @@ -538,8 +601,17 @@ "credential": null }, { - "description": "should throw an exception if neither deviceName nor callback specified (MONGODB-OIDC)", + "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": [ + "oidcRefresh" + ], "valid": false, "credential": null }, diff --git a/test/auth/unified/oidc-auth-with-retry.json b/test/auth/unified/oidc-auth-with-retry.json new file mode 100644 index 0000000000..aeae3288c9 --- /dev/null +++ b/test/auth/unified/oidc-auth-with-retry.json @@ -0,0 +1,170 @@ +{ + "description": "OIDC authentication with retry", + "schemaVersion": "1.18", + "runOnRequirements": [ + { + "minServerVersion": "7.0", + "auth": true, + "authMechanism": "MONGODB-OIDC" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "authMechanism": "MONGODB-OIDC", + "authMechanismProperties": { + "$$placeholder": 1 + }, + "retryReads": true, + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "test", + "documents": [ + + ] + } + ], + "tests": [ + { + "description": "A simple find operation should succeed", + "operations": [ + { + "name": "find", + "arguments": { + "filter": { + } + }, + "object": "collection0", + "expectResult": [ + + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} diff --git a/test/auth/unified/oidc-auth-without-retry.json b/test/auth/unified/oidc-auth-without-retry.json new file mode 100644 index 0000000000..ad8c93c03f --- /dev/null +++ b/test/auth/unified/oidc-auth-without-retry.json @@ -0,0 +1,175 @@ +{ + "description": "OIDC authentication without retry", + "schemaVersion": "1.18", + "runOnRequirements": [ + { + "minServerVersion": "7.0", + "auth": true, + "authMechanism": "MONGODB-OIDC" + } + ], + "createEntities": [ + { + "client": { + "id": "authClient" + } + }, + { + "client": { + "id": "client0", + "uriOptions": { + "authMechanism": "MONGODB-OIDC", + "authMechanismProperties": { + "$$placeholder": 1 + }, + "retryReads": true, + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "test", + "documents": [ + + ] + } + ], + "tests": [ + { + "description": "A simple find operation should succeed", + "operations": [ + { + "name": "find", + "arguments": { + "filter": { + } + }, + "object": "collection0", + "expectResult": [ + + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} diff --git a/test/auth/unified/reauthenticate_with_retry.json b/test/auth/unified/reauthenticate_with_retry.json index ef110562ed..e094a8b58b 100644 --- a/test/auth/unified/reauthenticate_with_retry.json +++ b/test/auth/unified/reauthenticate_with_retry.json @@ -1,6 +1,6 @@ { "description": "reauthenticate_with_retry", - "schemaVersion": "1.12", + "schemaVersion": "1.3", "runOnRequirements": [ { "minServerVersion": "6.3", diff --git a/test/auth/unified/reauthenticate_without_retry.json b/test/auth/unified/reauthenticate_without_retry.json index 6fded47634..8bbc5cc64d 100644 --- a/test/auth/unified/reauthenticate_without_retry.json +++ b/test/auth/unified/reauthenticate_without_retry.json @@ -1,6 +1,6 @@ { "description": "reauthenticate_without_retry", - "schemaVersion": "1.12", + "schemaVersion": "1.3", "runOnRequirements": [ { "minServerVersion": "6.3", diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 7055816af3..2a39f612a7 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -19,16 +19,25 @@ import sys import time import unittest +import warnings from contextlib import contextmanager +from pathlib import Path from typing import Dict sys.path[0:0] = [""] +from test.unified_format import generate_test_classes from test.utils import EventListener from bson import SON from pymongo import MongoClient from pymongo.auth import _AUTH_MAP, _authenticate_oidc +from pymongo.auth_oidc import ( + OIDCHumanCallback, + OIDCHumanCallbackResult, + OIDCMachineCallback, + OIDCMachineCallbackResult, +) from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -37,40 +46,31 @@ # Force MONGODB-OIDC to be enabled. _AUTH_MAP["MONGODB-OIDC"] = _authenticate_oidc # type:ignore +ROOT = Path(__file__).parent.parent.resolve() +TEST_PATH = ROOT / "auth" / "unified" +PROVIDER_NAME = os.environ.get("OIDC_PROVIDER_NAME", "aws") + +# Generate unified tests. +globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) -class TestAuthOIDC(unittest.TestCase): - uri: str +class OIDCTestBase(unittest.TestCase): @classmethod def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] cls.uri_multiple = os.environ["MONGODB_URI_MULTI"] cls.uri_admin = os.environ["MONGODB_URI"] - cls.token_dir = os.environ["OIDC_TOKEN_DIR"] def setUp(self): self.request_called = 0 - def create_request_cb(self, username="test_user1", sleep=0): - token_file = os.path.join(self.token_dir, username).replace(os.sep, "/") - - def request_token(server_info, context): - # Validate the info. - self.assertIn("issuer", server_info) - self.assertIn("clientId", server_info) - - # Validate the timeout. - timeout_seconds = context["timeout_seconds"] - self.assertEqual(timeout_seconds, 60 * 5) + def get_token(self, username): + """Get a token for the current provider.""" + if PROVIDER_NAME == "aws": + token_dir = os.environ["OIDC_TOKEN_DIR"] + token_file = os.path.join(token_dir, username).replace(os.sep, "/") with open(token_file) as fid: - token = fid.read() - resp = {"access_token": token, "refresh_token": token} - - time.sleep(sleep) - self.request_called += 1 - return resp - - return request_token + return fid.read() @contextmanager def fail_point(self, command_args): @@ -83,6 +83,37 @@ def fail_point(self, command_args): finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + +class TestAuthOIDCHuman(OIDCTestBase): + uri: str + + def create_request_cb(self, username="test_user1", sleep=0): + def request_token(server_info, context): + # Validate the info. + self.assertIsInstance(server_info.issuer, str) + self.assertIsInstance(server_info.clientId, str) + + # Validate the timeout. + timeout_seconds = context.timeout_seconds + self.assertEqual(timeout_seconds, 60 * 5) + token = self.get_token(username) + resp = OIDCHumanCallbackResult(access_token=token, refresh_token=token) + + time.sleep(sleep) + self.request_called += 1 + return resp + + class Inner(OIDCHumanCallback): + def fetch(self, idp_info, context): + return request_token(idp_info, context) + + return Inner() + + def create_client(self, username="test_user1"): + request_cb = self.create_request_cb(username) + props: Dict = {"request_token_callback": request_cb} + return MongoClient(self.uri_multiple, username=username, authmechanismproperties=props) + def test_connect_request_callback_single_implicit_username(self): request_token = self.create_request_cb() props: Dict = {"request_token_callback": request_token} @@ -98,20 +129,12 @@ def test_connect_request_callback_single_explicit_username(self): client.close() def test_connect_request_callback_multiple_principal_user1(self): - request_token = self.create_request_cb() - props: Dict = {"request_token_callback": request_token} - client = MongoClient( - self.uri_multiple, username="test_user1", authmechanismproperties=props - ) + client = self.create_client() client.test.test.find_one() client.close() def test_connect_request_callback_multiple_principal_user2(self): - request_token = self.create_request_cb("test_user2") - props: Dict = {"request_token_callback": request_token} - client = MongoClient( - self.uri_multiple, username="test_user2", authmechanismproperties=props - ) + client = self.create_client("test_user2") client.test.test.find_one() client.close() @@ -132,64 +155,75 @@ def test_allowed_hosts_blocked(self): client.close() props: Dict = {"request_token_callback": request_token, "allowed_hosts": ["example.com"]} - client = MongoClient( - self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False - ) + with warnings.catch_warnings(): + warnings.simplefilter("default") + client = MongoClient( + self.uri_single + "&ignored=example.com", + authmechanismproperties=props, + connect=False, + ) with self.assertRaises(ConfigurationError): client.test.test.find_one() client.close() - def test_valid_request_token_callback(self): - request_cb = self.create_request_cb() + def test_configuration_errors(self): + request_token = self.create_request_cb() - props: Dict = { - "request_token_callback": request_cb, - } - client = MongoClient(self.uri_single, authmechanismproperties=props) + class CustomCB(OIDCMachineCallback): + def fetch(self, ctx): + return None + + props: Dict = {"request_token_callback": request_token} + + # Assert that providing both a human callback and a provider raises an error. + props["PROVIDER_NAME"] = PROVIDER_NAME + with self.assertRaises(ConfigurationError): + _ = MongoClient(self.uri_single, authmechanismproperties=props) + props["custom_token_callback"] = CustomCB() + + # Assert that providing both callback types and a provider raises an error. + with self.assertRaises(ConfigurationError): + _ = MongoClient(self.uri_single, authmechanismproperties=props) + del props["PROVIDER_NAME"] + + # Assert that providing both callback types raises an error. + with self.assertRaises(ConfigurationError): + _ = MongoClient(self.uri_single, authmechanismproperties=props) + + def test_valid_request_token_callback(self): + client = self.create_client() client.test.test.find_one() client.close() - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = self.create_client() client.test.test.find_one() client.close() def test_request_callback_returns_null(self): - def request_token_null(a, b): - return None + class RequestTokenNull(OIDCHumanCallback): + def fetch(self, a, b): + return None - props: Dict = {"request_token_callback": request_token_null} + props: Dict = {"request_token_callback": RequestTokenNull()} client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() def test_request_callback_invalid_result(self): - def request_token_invalid(a, b): - return {} + class CallbackInvalidToken(OIDCHumanCallback): + def fetch(self, a, b): + return {} - props: Dict = {"request_token_callback": request_token_invalid} - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - - def request_cb_extra_value(server_info, context): - result = self.create_request_cb()(server_info, context) - result["foo"] = "bar" - return result - - props: Dict = {"request_token_callback": request_cb_extra_value} + props: Dict = {"request_token_callback": CallbackInvalidToken()} client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() def test_speculative_auth_success(self): - request_token = self.create_request_cb() - # Create a client with a request callback that returns a valid token. - props: Dict = {"request_token_callback": request_token} - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = self.create_client() # Set a fail point for saslStart commands. with self.fail_point( @@ -262,13 +296,14 @@ def test_reauthenticate_succeeds(self): def test_reauthenticate_succeeds_no_refresh(self): cb = self.create_request_cb() - def request_cb(*args, **kwargs): - result = cb(*args, **kwargs) - del result["refresh_token"] - return result + class CustomRequest(OIDCHumanCallback): + def fetch(self, *args, **kwargs): + result = cb.fetch(*args, **kwargs) + result.refresh_token = None + return result # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} + props: Dict = {"request_token_callback": CustomRequest()} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform a find operation. @@ -316,33 +351,54 @@ def test_reauthenticate_succeeds_after_refresh_fails(self): # Assert that the request callback has been called three times. self.assertEqual(self.request_called, 3) - def test_reauthenticate_fails(self): - # Create request callback that returns valid credentials. + def test_reauthentication_succeeds_multiple_connections(self): request_cb = self.create_request_cb() # Create a client with the callback. props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) - # Perform a find operation. - client.test.test.find_one() + client1 = MongoClient(self.uri_single, authmechanismproperties=props) + client2 = MongoClient(self.uri_single, authmechanismproperties=props) - # Assert that the request callback has been called once. - self.assertEqual(self.request_called, 1) + # Perform an insert operation. + client1.test.test.insert_many([{"a": 1}, {"a": 1}]) + client2.test.test.find_one() + self.assertEqual(self.request_called, 2) + + # Use the same authenticator for both clients + # to simulate a race condition with separate connections. + # We should only see one extra callback despite both connections + # needing to reauthenticate. + client2.options.pool_options._credentials.cache.data = ( + client1.options.pool_options._credentials.cache.data + ) + + client1.test.test.find_one() + client2.test.test.find_one() with self.fail_point( { - "mode": {"times": 2}, + "mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}, } ): - # Perform a find operation that fails. - with self.assertRaises(OperationFailure): - client.test.test.find_one() + client1.test.test.find_one() - # Assert that the request callback has been called twice. - self.assertEqual(self.request_called, 2) - client.close() + self.assertEqual(self.request_called, 3) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + client2.test.test.find_one() + + self.assertEqual(self.request_called, 3) + client1.close() + client2.close() + + # PyMongo specific tests, since we have multiple code paths for reauth handling. def test_reauthenticate_succeeds_bulk_write(self): request_cb = self.create_request_cb() @@ -518,14 +574,163 @@ def test_reauthenticate_succeeds_command(self): self.assertEqual(self.request_called, 2) client.close() - def test_reauthentication_succeeds_multiple_connections(self): + +class TestAuthOIDCMachine(OIDCTestBase): + uri: str + + def setUp(self): + self.request_called = 0 + + def create_request_cb(self, username="test_user1", sleep=0): + def request_token(_context): + token = self.get_token(username) + time.sleep(sleep) + self.request_called += 1 + return OIDCMachineCallbackResult(access_token=token) + + class Inner(OIDCMachineCallback): + def fetch(self, context): + return request_token(context) + + return Inner() + + def create_client(self): + request_cb = self.create_request_cb() + props: Dict = {"custom_token_callback": request_cb} + return MongoClient(self.uri_single, authmechanismproperties=props) + + def test_custom_callback(self): + client = self.create_client() + client.test.test.find_one() + client.close() + + def test_callback_is_called_during_reauthentication(self): + listener = EventListener() + + # Create request callback that returns valid credentials. request_cb = self.create_request_cb() # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} + props: Dict = {"custom_token_callback": request_cb} + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) - client1 = MongoClient(self.uri_single, authmechanismproperties=props) - client2 = MongoClient(self.uri_single, authmechanismproperties=props) + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) + + listener.reset() + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + started_events = [ + i.command_name for i in listener.started_events if not i.command_name.startswith("sasl") + ] + succeeded_events = [ + i.command_name + for i in listener.succeeded_events + if not i.command_name.startswith("sasl") + ] + failed_events = [ + i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl") + ] + + self.assertEqual( + started_events, + [ + "find", + "find", + ], + ) + self.assertEqual(succeeded_events, ["find"]) + self.assertEqual(failed_events, ["find"]) + + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) + client.close() + + def test_callback_is_called_twice_on_handshake_authentication_failure(self): + client = self.create_client() + + # Set a fail point for ``saslStart`` commands. + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) + client.close() + + def test_request_callback_invalid_result(self): + request_cb = self.create_request_cb() + props: Dict = {"custom_token_callback": request_cb, "PROVIDER_NAME": PROVIDER_NAME} + with self.assertRaises(ConfigurationError): + _ = MongoClient(self.uri_single, authmechanismproperties=props) + + def test_request_callback_returns_null(self): + class CallbackNullToken(OIDCMachineCallback): + def fetch(self, a): + return None + + props: Dict = {"custom_token_callback": CallbackNullToken()} + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_request_callback_invalid_result(self): + class CallbackTokenInvalid(OIDCMachineCallback): + def fetch(self, a): + return {} + + props: Dict = {"custom_token_callback": CallbackTokenInvalid()} + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_speculative_auth_success(self): + client1 = self.create_client() + client1.test.test.find_one() + client2 = self.create_client() + + # Prime the cache of the second client. + client2.options.pool_options._credentials.cache.data = ( + client1.options.pool_options._credentials.cache.data + ) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client2.test.test.find_one() + + # Close the clients. + client2.close() + client1.close() + + def test_reauthentication_succeeds_multiple_connections(self): + client1 = self.create_client() + client2 = self.create_client() # Perform an insert operation. client1.test.test.insert_many([{"a": 1}, {"a": 1}]) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 4976a6dd49..1d70b12496 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -26,6 +26,7 @@ from test.unified_format import generate_test_classes from pymongo import MongoClient +from pymongo.auth_oidc import OIDCHumanCallback _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") @@ -34,6 +35,11 @@ class TestAuthSpec(unittest.TestCase): pass +class SampleHumanCallback(OIDCHumanCallback): + def fetch(self, info, context): + pass + + def create_test(test_case): def run_test(self): uri = test_case["uri"] @@ -47,8 +53,7 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["request_token_callback"] = lambda x, y: 1 - del props["REQUEST_TOKEN_CALLBACK"] + props["request_token_callback"] = SampleHumanCallback() client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials if credential is None: diff --git a/test/unified_format.py b/test/unified_format.py index 99758989c9..b0625d03eb 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -139,7 +139,7 @@ } -# Build up a placeholder map. +# Build up a placeholder maps. PLACEHOLDER_MAP = {} for provider_name, provider_data in [ ("local", {"key": LOCAL_MASTER_KEY}), @@ -152,6 +152,10 @@ placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" PLACEHOLDER_MAP[placeholder] = value +PROVIDER_NAME = os.environ.get("OIDC_PROVIDER_NAME", "aws") +if PROVIDER_NAME == "aws": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"PROVIDER_NAME": "aws"} + def interrupt_loop(): global IS_INTERRUPTED @@ -226,6 +230,8 @@ def is_run_on_requirement_satisfied(requirement): if req_auth is not None: if req_auth: auth_satisfied = client_context.auth_enabled + if auth_satisfied and "authMechanism" in requirement: + auth_satisfied = client_context.check_auth_type(requirement["authMechanism"]) else: auth_satisfied = not client_context.auth_enabled @@ -900,7 +906,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): a class attribute ``TEST_SPEC``. """ - SCHEMA_VERSION = Version.from_string("1.17") + SCHEMA_VERSION = Version.from_string("1.18") RUN_ON_LOAD_BALANCER = True RUN_ON_SERVERLESS = True TEST_SPEC: Any