diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 3b31dbb13f..e76f56a832 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -1010,6 +1010,32 @@ task_groups: tasks: - testazurekms-task + - name: testazureoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - func: make files executable + - command: shell.exec + params: + shell: bash + script: |- + set -o errexit + ${PREPARE_SHELL} + export AZUREOIDC_VMNAME_PREFIX="PYTHON_DRIVER" + $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/create-and-setup-vm.sh + teardown_task: + - command: shell.exec + params: + shell: bash + script: |- + ${PREPARE_SHELL} + $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/delete-vm.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-azure-latest + - name: test_aws_lambda_task_group setup_group: - func: fetch source @@ -1978,6 +2004,22 @@ tasks: - func: "run load-balancer" - func: "run tests" + - name: "oidc-auth-test-azure-latest" + commands: + - command: shell.exec + params: + shell: bash + script: |- + set -o errexit + ${PREPARE_SHELL} + cd src + git add . + git commit -m "add files" + export AZUREOIDC_DRIVERS_TAR_FILE=/tmp/mongo-python-driver.tgz + git archive -o $AZUREOIDC_DRIVERS_TAR_FILE HEAD + export AZUREOIDC_TEST_CMD="source ./env.sh && export OIDC_PROVIDER_NAME=azure && ./.evergreen/run-mongodb-oidc-test.sh" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/run-driver-test.sh + - name: "test-fips-standalone" tags: ["fips"] commands: @@ -3036,6 +3078,13 @@ buildvariants: tasks: - name: "oidc-auth-test-latest" +- name: testazureoidc-variant + display_name: "Azure OIDC" + run_on: ubuntu2004-small + tasks: + - name: testazureoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README + - matrix_name: "aws-auth-test" matrix_spec: platform: [ubuntu-20.04] diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 75fafb448b..ac7d29e5e9 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -5,44 +5,69 @@ set -o errexit # Exit the script with error if any of the commands fail echo "Running MONGODB-OIDC authentication tests" -# Make sure DRIVERS_TOOLS is set. -if [ -z "$DRIVERS_TOOLS" ]; then - echo "Must specify DRIVERS_TOOLS" - exit 1 -fi +OIDC_PROVIDER_NAME=${OIDC_PROVIDER_NAME:-"aws"} -# Get the drivers secrets. Use an existing secrets file first. -if [ ! -f "./secrets-export.sh" ]; then - bash ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup_secrets.sh drivers/oidc -fi -source ./secrets-export.sh +if [ $OIDC_PROVIDER_NAME == "aws" ]; then + # Make sure DRIVERS_TOOLS is set. + if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi -# # If the file did not have our creds, get them from the vault. -if [ -z "$OIDC_ATLAS_URI_SINGLE" ]; then - bash ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup_secrets.sh drivers/oidc + # Get the drivers secrets. Use an existing secrets file first. + if [ ! -f "./secrets-export.sh" ]; then + bash ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup_secrets.sh drivers/oidc + fi source ./secrets-export.sh -fi -# Make the OIDC tokens. -set -x -pushd ${DRIVERS_TOOLS}/.evergreen/auth_oidc -. ./oidc_get_tokens.sh -popd + # # If the file did not have our creds, get them from the vault. + if [ -z "$OIDC_ATLAS_URI_SINGLE" ]; then + bash ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup_secrets.sh drivers/oidc + source ./secrets-export.sh + fi -# Set up variables and run the test. -if [ -n "$LOCAL_OIDC_SERVER" ]; then - export MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} - export MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" - export MONGODB_URI_MULTI="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" -else + # Make the OIDC tokens. + set -x + pushd ${DRIVERS_TOOLS}/.evergreen/auth_oidc + . ./oidc_get_tokens.sh + popd + + # Set up variables and run the test. + if [ -n "$LOCAL_OIDC_SERVER" ]; then + export MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} + export MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" + export MONGODB_URI_MULTI="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" + else + set +x # turn off xtrace for this portion + export MONGODB_URI="$OIDC_ATLAS_URI_SINGLE" + export MONGODB_URI_SINGLE="$OIDC_ATLAS_URI_SINGLE/?authMechanism=MONGODB-OIDC" + export MONGODB_URI_MULTI="$OIDC_ATLAS_URI_MULTI/?authMechanism=MONGODB-OIDC" + set -x + fi + export AWS_WEB_IDENTITY_TOKEN_FILE="$OIDC_TOKEN_DIR/test_user1" + export OIDC_ADMIN_USER=$OIDC_ALTAS_USER + export OIDC_ADMIN_PWD=$OIDC_ATLAS_PASSWORD + +elif [ $OIDC_PROVIDER_NAME == "azure" ]; then + if [ -z "${AZUREOIDC_AUDIENCE}" ]; then + echo "Must specify an AZUREOIDC_AUDIENCE" + exit 1 + fi set +x # turn off xtrace for this portion - export MONGODB_URI="$OIDC_ATLAS_URI_SINGLE" - export MONGODB_URI_SINGLE="$OIDC_ATLAS_URI_SINGLE/?authMechanism=MONGODB-OIDC" - export MONGODB_URI_MULTI="$OIDC_ATLAS_URI_MULTI/?authMechanism=MONGODB-OIDC" + export OIDC_ADMIN_USER=$AZUREOIDC_USERNAME + export OIDC_ADMIN_PWD=pwd123 set -x + export MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} + MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" + MONGODB_URI_SINGLE="${MONGODB_URI_SINGLE}&authMechanismProperties=PROVIDER_NAME:azure" + export MONGODB_URI_SINGLE="${MONGODB_URI_SINGLE},TOKEN_AUDIENCE:${AZUREOIDC_AUDIENCE}" + export MONGODB_URI_MULTI=$MONGODB_URI_SINGLE +else + echo "Unrecognized OIDC_PROVIDER_NAME $OIDC_PROVIDER_NAME" + exit 1 fi export TEST_AUTH_OIDC=1 export COVERAGE=1 export AUTH="auth" -bash ./.evergreen/tox.sh -m test-eg +bash ./.evergreen/tox.sh -m test-eg -- "${@:1}" diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index ef71f29179..54073ffb34 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -30,7 +30,7 @@ set -o xtrace AUTH=${AUTH:-noauth} SSL=${SSL:-nossl} -TEST_ARGS="$1" +TEST_ARGS="${*:1}" PYTHON=$(which python) export PIP_QUIET=1 # Quiet by default @@ -50,8 +50,9 @@ if [ "$AUTH" != "noauth" ]; then export DB_USER=$SERVERLESS_ATLAS_USER export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD elif [ ! -z "$TEST_AUTH_OIDC" ]; then - export DB_USER=$OIDC_ALTAS_USER - export DB_PASSWORD=$OIDC_ATLAS_PASSWORD + export DB_USER=$OIDC_ADMIN_USER + export DB_PASSWORD=$OIDC_ADMIN_PWD + export DB_IP="$MONGODB_URI" else export DB_USER="bob" export DB_PASSWORD="pwd123" @@ -205,7 +206,6 @@ fi if [ -n "$TEST_AUTH_OIDC" ]; then python -m pip install ".[aws]" - TEST_ARGS="test/auth_oidc/test_auth_oidc.py" fi diff --git a/pymongo/_azure_helpers.py b/pymongo/_azure_helpers.py new file mode 100644 index 0000000000..661e4ce37a --- /dev/null +++ b/pymongo/_azure_helpers.py @@ -0,0 +1,56 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Azure helpers.""" +from __future__ import annotations + +import json +from typing import Any, Optional +from urllib.request import Request, urlopen + + +def _get_azure_response( + resource: str, object_id: Optional[str] = None, timeout: float = 5 +) -> dict[str, Any]: + url = "http://169.254.169.254/metadata/identity/oauth2/token" + url += "?api-version=2018-02-01" + url += f"&resource={resource}" + if object_id: + url += f"&object_id={object_id}" + headers = {"Metadata": "true", "Accept": "application/json"} + request = Request(url, headers=headers) # noqa: S310 + print("fetching url", url) # noqa: T201 + try: + with urlopen(request, timeout=timeout) as response: # noqa: S310 + status = response.status + body = response.read().decode("utf8") + except Exception as e: + msg = "Failed to acquire IMDS access token: %s" % e + raise ValueError(msg) from None + + if status != 200: + msg = "Failed to acquire IMDS access token." + raise ValueError(msg) + try: + data = json.loads(body) + except Exception: + raise ValueError("Azure IMDS response must be in JSON format.") from None + + for key in ["access_token", "expires_in"]: + if not data.get(key): + msg = "Azure IMDS response must contain %s, but was %s." + msg = msg % (key, body) + raise ValueError(msg) + + return data diff --git a/pymongo/auth.py b/pymongo/auth.py index a2c9c29980..9d70e2911a 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -37,7 +37,13 @@ from bson.binary import Binary from pymongo.auth_aws import _authenticate_aws -from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties +from pymongo.auth_oidc import ( + _authenticate_oidc, + _get_authenticator, + _OIDCAWSCallback, + _OIDCAzureCallback, + _OIDCProperties, +) from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -162,8 +168,10 @@ def _build_credentials_tuple( return MongoCredential(mech, "$external", user, passwd, aws_props, None) elif mech == "MONGODB-OIDC": properties = extra.get("authmechanismproperties", {}) - request_token_callback = properties.get("request_token_callback") - provider_name = properties.get("PROVIDER_NAME", "") + callback = properties.get("OIDC_CALLBACK") + human_callback = properties.get("OIDC_HUMAN_CALLBACK") + provider_name = properties.get("PROVIDER_NAME") + token_audience = properties.get("TOKEN_AUDIENCE", "") default_allowed = [ "*.mongodb.net", "*.mongodb-dev.net", @@ -173,13 +181,40 @@ def _build_credentials_tuple( "127.0.0.1", "::1", ] - allowed_hosts = properties.get("allowed_hosts", default_allowed) - if not request_token_callback and provider_name != "aws": - raise ConfigurationError( - "authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'" - ) + allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) + msg = "authentication with MONGODB-OIDC requires providing either a callback or a provider_name" + if passwd is not None: + msg = "password is not supported by MONGODB-OIDC" + raise ConfigurationError(msg) + if callback or human_callback: + if provider_name is not None: + raise ConfigurationError(msg) + if callback and human_callback: + msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" + raise ConfigurationError(msg) + elif provider_name is not None: + if provider_name == "aws": + if user is not None: + msg = "AWS provider for MONGODB-OIDC does not support username" + raise ConfigurationError(msg) + callback = _OIDCAWSCallback() + elif provider_name == "azure": + passwd = None + if not token_audience: + raise ConfigurationError( + "Azure provider for MONGODB-OIDC requires a TOKEN_AUDIENCE auth mechanism property" + ) + callback = _OIDCAzureCallback(token_audience, user) + else: + raise ConfigurationError( + f"unrecognized provider_name for MONGODB-OIDC: {provider_name}" + ) + else: + raise ConfigurationError(msg) + oidc_props = _OIDCProperties( - request_token_callback=request_token_callback, + callback=callback, + human_callback=human_callback, provider_name=provider_name, allowed_hosts=allowed_hosts, ) @@ -522,6 +557,7 @@ def _authenticate_default(credentials: MongoCredential, conn: Connection) -> Non "MONGODB-CR": _authenticate_mongo_cr, "MONGODB-X509": _authenticate_x509, "MONGODB-AWS": _authenticate_aws, + "MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item] "PLAIN": _authenticate_plain, "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"), "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"), @@ -582,7 +618,7 @@ def speculate_command(self) -> MutableMapping[str, Any]: class _OIDCContext(_AuthContext): def speculate_command(self) -> Optional[MutableMapping[str, Any]]: authenticator = _get_authenticator(self.credentials, self.address) - cmd = authenticator.auth_start_cmd(False) + cmd = authenticator.get_spec_auth_cmd() if cmd is None: return None cmd["db"] = self.credentials.source diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 357cb62fbd..9170e013b3 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -15,13 +15,17 @@ """MONGODB-OIDC Authentication helpers.""" from __future__ import annotations +import abc +import os import threading +import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union import bson from bson.binary import Binary -from bson.son import SON +from pymongo._azure_helpers import _get_azure_response +from pymongo._csot import remaining from pymongo.errors import ConfigurationError, OperationFailure if TYPE_CHECKING: @@ -29,18 +33,51 @@ from pymongo.pool import Connection +@dataclass +class OIDCIdPInfo: + issuer: str + clientId: str + requestScopes: Optional[list[str]] = field(default=None) + + +@dataclass +class OIDCCallbackContext: + timeout_seconds: float + version: int + refresh_token: Optional[str] = field(default=None) + idp_info: Optional[OIDCIdPInfo] = field(default=None) + + +@dataclass +class OIDCCallbackResult: + access_token: str + expires_in_seconds: Optional[float] = field(default=None) + refresh_token: Optional[str] = field(default=None) + + +class OIDCCallback(abc.ABC): + """A base class for defining OIDC callbacks.""" + + @abc.abstractmethod + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + """Convert the given BSON value into our own type.""" + + @dataclass class _OIDCProperties: - request_token_callback: Optional[Callable[..., dict]] - provider_name: Optional[str] - allowed_hosts: list[str] + callback: Optional[OIDCCallback] = field(default=None) + human_callback: Optional[OIDCCallback] = field(default=None) + provider_name: Optional[str] = field(default=None) + allowed_hosts: list[str] = field(default_factory=list) """Mechanism properties for MONGODB-OIDC authentication.""" TOKEN_BUFFER_MINUTES = 5 -CALLBACK_TIMEOUT_SECONDS = 5 * 60 +HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60 CALLBACK_VERSION = 1 +MACHINE_CALLBACK_TIMEOUT_SECONDS = 60 +TIME_BETWEEN_CALLS_SECONDS = 0.1 def _get_authenticator( @@ -72,28 +109,126 @@ def _get_authenticator( return credentials.cache.data +class _OIDCAWSCallback(OIDCCallback): + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") + if not token_file: + raise RuntimeError( + 'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set' + ) + with open(token_file) as fid: + return OIDCCallbackResult(access_token=fid.read().strip()) + + +class _OIDCAzureCallback(OIDCCallback): + def __init__(self, token_audience: str, username: Optional[str]) -> None: + self.token_audience = token_audience + self.username = username + + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + resp = _get_azure_response(self.token_audience, self.username, context.timeout_seconds) + return OIDCCallbackResult( + access_token=resp["access_token"], expires_in_seconds=resp["expires_in"] + ) + + @dataclass class _OIDCAuthenticator: username: str properties: _OIDCProperties refresh_token: Optional[str] = field(default=None) access_token: Optional[str] = field(default=None) - idp_info: Optional[dict] = field(default=None) + idp_info: Optional[OIDCIdPInfo] = field(default=None) token_gen_id: int = field(default=0) lock: threading.Lock = field(default_factory=threading.Lock) + last_call_time: float = field(default=0) + + def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: + """Handle a reauthenticate from the server.""" + # Invalidate the token for the connection. + self._invalidate(conn) + # Call the appropriate auth logic for the callback type. + if self.properties.callback: + return self._authenticate_machine(conn) + return self._authenticate_human(conn) + + def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: + """Handle an initial authenticate request.""" + # First handle speculative auth. + # If it succeeded, we are done. + ctx = conn.auth_ctx + if ctx and ctx.speculate_succeeded(): + resp = ctx.speculative_authenticate + if resp and resp["done"]: + conn.oidc_token_gen_id = self.token_gen_id + return resp + + # If spec auth failed, call the appropriate auth logic for the callback type. + # We cannot assume that the token is invalid, because a proxy may have been + # involved that stripped the speculative auth information. + if self.properties.callback: + return self._authenticate_machine(conn) + return self._authenticate_human(conn) + + def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]: + """Get the appropriate speculative auth command.""" + if not self.access_token: + return None + return self._get_start_command({"jwt": self.access_token}) + + def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]: + # If there is a cached access token, try to authenticate with it. If + # authentication fails, it's possible the cached access token is expired. In + # that case, invalidate the access token, fetch a new access token, and try + # to authenticate again. + if self.access_token: + try: + return self._sasl_start_jwt(conn) + except Exception: # noqa: S110 + pass + return self._sasl_start_jwt(conn) + + def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]: + # If we have a cached access token, try a JwtStepRequest. + if self.access_token: + try: + return self._sasl_start_jwt(conn) + except Exception: # noqa: S110 + pass - def get_current_token(self, use_callback: bool = True) -> Optional[str]: + # If we have a cached refresh token, try a JwtStepRequest with that. + if self.refresh_token: + try: + return self._sasl_start_jwt(conn) + except Exception: # noqa: S110 + pass + + # Start a new Two-Step SASL conversation. + # Run a PrincipalStepRequest to get the IdpInfo. + cmd = self._get_start_command(None) + start_resp = self._run_command(conn, cmd) + # Attempt to authenticate with a JwtStepRequest. + return self._sasl_continue_jwt(conn, start_resp) + + def _get_access_token(self) -> Optional[str]: properties = self.properties + cb: Union[None, OIDCCallback] + resp: OIDCCallbackResult - # TODO: DRIVERS-2672, handle machine callback here as well. - cb = properties.request_token_callback if use_callback else None - cb_type = "human" + is_human = properties.human_callback is not None + if is_human and self.idp_info is None: + return None + + if properties.callback: + cb = properties.callback + if properties.human_callback: + cb = properties.human_callback prev_token = self.access_token if prev_token: return prev_token - if not use_callback and not prev_token: + if cb is None and not prev_token: return None if not prev_token and cb is not None: @@ -104,163 +239,85 @@ def get_current_token(self, use_callback: bool = True) -> Optional[str]: if new_token != prev_token: return new_token - # TODO: DRIVERS-2672 handle machine callback here. - if cb_type == "human": - context = { - "timeout_seconds": CALLBACK_TIMEOUT_SECONDS, - "version": CALLBACK_VERSION, - "refresh_token": self.refresh_token, - } - resp = cb(self.idp_info, context) - - self.validate_request_token_response(resp) - + # Ensure that we are waiting a min time between callback invocations. + delta = time.time() - self.last_call_time + if delta < TIME_BETWEEN_CALLS_SECONDS: + time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta) + self.last_call_time = time.time() + + if is_human: + timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS + assert self.idp_info is not None + else: + timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS) + context = OIDCCallbackContext( + timeout_seconds=timeout, + version=CALLBACK_VERSION, + refresh_token=self.refresh_token, + idp_info=self.idp_info, + ) + resp = cb.fetch(context) + if not isinstance(resp, OIDCCallbackResult): + raise ValueError("Callback result must be of type OIDCCallbackResult") + self.refresh_token = resp.refresh_token + self.access_token = resp.access_token self.token_gen_id += 1 return self.access_token - def validate_request_token_response(self, resp: Mapping[str, Any]) -> None: - # Validate callback return value. - if not isinstance(resp, dict): - raise ValueError("OIDC callback returned invalid result") - - if "access_token" not in resp: - raise ValueError("OIDC callback did not return an access_token") - - expected = ["access_token", "refresh_token", "expires_in_seconds"] - for key in resp: - if key not in expected: - raise ValueError(f'Unexpected field in callback result "{key}"') - - self.access_token = resp["access_token"] - self.refresh_token = resp.get("refresh_token") - - def principal_step_cmd(self) -> SON[str, Any]: - """Get a SASL start command with an optional principal name""" - # Send the SASL start with the optional principal name. - payload = {} - - principal_name = self.username - if principal_name: - payload["n"] = principal_name - - return SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ("autoAuthorize", 1), - ] - ) - - def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]: - # TODO: DRIVERS-2672, check for provider_name in self.properties here. - if self.idp_info is None: - return self.principal_step_cmd() - - token = self.get_current_token(use_callback) - if not token: - return None - bin_payload = Binary(bson.encode({"jwt": token})) - return SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", bin_payload), - ] - ) - - def run_command( - self, conn: Connection, cmd: MutableMapping[str, Any] - ) -> Optional[Mapping[str, Any]]: + def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]: try: return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] except OperationFailure: - self.access_token = None + self._invalidate(conn) raise - def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: - """Handle a reauthenticate from the server.""" - # First see if we have the a newer token on the authenticator. - prev_id = conn.oidc_token_gen_id or 0 - # If we've already changed tokens, make one optimistic attempt. - if (prev_id < self.token_gen_id) and self.access_token: - try: - return self.authenticate(conn) - except OperationFailure: - pass - + def _invalidate(self, conn: Connection) -> None: + # Ignore the invalidation if a token gen id is given and is less than our + # current token gen id. + token_gen_id = conn.oidc_token_gen_id or 0 + if token_gen_id is not None and token_gen_id < self.token_gen_id: + return self.access_token = None - # TODO: DRIVERS-2672, check for provider_name in self.properties here. - # If so, we clear the access token and return finish_auth. - - # Next see if the idp info has changed. - prev_idp_info = self.idp_info - self.idp_info = None - cmd = self.principal_step_cmd() - resp = self.run_command(conn, cmd) - assert resp is not None - server_resp: dict = bson.decode(resp["payload"]) - if "issuer" in server_resp: - self.idp_info = server_resp - - # Handle the case of changed idp info. - if self.idp_info != prev_idp_info: - self.access_token = None - self.refresh_token = None - - # If we have a refresh token, try using that. - if self.refresh_token: - try: - return self.finish_auth(resp, conn) - except OperationFailure: - self.refresh_token = None - # If that fails, try again without the refresh token. - return self.authenticate(conn) - - # If we don't have a refresh token, just try once. - return self.finish_auth(resp, conn) - - def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: - ctx = conn.auth_ctx - cmd = None - - if ctx and ctx.speculate_succeeded(): - resp = ctx.speculative_authenticate - else: - cmd = self.auth_start_cmd() - assert cmd is not None - resp = self.run_command(conn, cmd) - - assert resp is not None - if resp["done"]: - conn.oidc_token_gen_id = self.token_gen_id - return None - - server_resp: dict = bson.decode(resp["payload"]) - if "issuer" in server_resp: - self.idp_info = server_resp - - return self.finish_auth(resp, conn) + def _sasl_continue_jwt( + self, conn: Connection, start_resp: Mapping[str, Any] + ) -> Mapping[str, Any]: + self.access_token = None + self.refresh_token = None + start_payload: dict = bson.decode(start_resp["payload"]) + if "issuer" in start_payload: + self.idp_info = OIDCIdPInfo(**start_payload) + access_token = self._get_access_token() + conn.oidc_token_gen_id = self.token_gen_id + cmd = self._get_continue_command({"jwt": access_token}, start_resp) + return self._run_command(conn, cmd) - def finish_auth( - self, orig_resp: Mapping[str, Any], conn: Connection - ) -> Optional[Mapping[str, Any]]: - conversation_id = orig_resp["conversationId"] - token = self.get_current_token() + def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]: + access_token = self._get_access_token() conn.oidc_token_gen_id = self.token_gen_id - bin_payload = Binary(bson.encode({"jwt": token})) - cmd = { + cmd = self._get_start_command({"jwt": access_token}) + return self._run_command(conn, cmd) + + def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]: + if payload is None: + principal_name = self.username + if principal_name: + payload = {"n": principal_name} + else: + payload = {} + bin_payload = Binary(bson.encode(payload)) + return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload} + + def _get_continue_command( + self, payload: Mapping[str, Any], start_resp: Mapping[str, Any] + ) -> MutableMapping[str, Any]: + bin_payload = Binary(bson.encode(payload)) + return { "saslContinue": 1, - "conversationId": conversation_id, "payload": bin_payload, + "conversationId": start_resp["conversationId"], } - resp = self.run_command(conn, cmd) - assert resp is not None - if not resp["done"]: - raise OperationFailure("SASL conversation failed to complete.") - return resp def _authenticate_oidc( diff --git a/pymongo/common.py b/pymongo/common.py index 41d1e1050e..88dd7d65ee 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -17,7 +17,6 @@ from __future__ import annotations import datetime -import inspect import warnings from collections import OrderedDict, abc from difflib import get_close_matches @@ -42,6 +41,7 @@ from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry from bson.raw_bson import RawBSONDocument from pymongo.auth import MECHANISMS +from pymongo.auth_oidc import OIDCCallback from pymongo.compression_support import ( validate_compressors, validate_zlib_compression_level, @@ -425,6 +425,8 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]] "SERVICE_REALM", "AWS_SESSION_TOKEN", "PROVIDER_NAME", + "TOKEN_AUDIENCE", + "ALLOWED_HOSTS", ] ) @@ -440,22 +442,14 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni props[key] = value elif isinstance(value, bool): props[key] = str(value).lower() - elif key in ["allowed_hosts"] and isinstance(value, list): + elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): props[key] = value - elif inspect.isfunction(value): - signature = inspect.signature(value) - if key == "request_token_callback": - expected_params = 2 - else: - raise ValueError(f"Unrecognized Auth mechanism function {key}") - if len(signature.parameters) != expected_params: - msg = f"{key} must accept {expected_params} parameters" - raise ValueError(msg) + elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: + if not isinstance(value, OIDCCallback): + raise ValueError("callback must be an OIDCCallback object") props[key] = value else: - raise ValueError( - "Auth mechanism property values must be strings or callback functions" - ) + raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") return props value = validate_string(option, value) diff --git a/pyproject.toml b/pyproject.toml index b182a16f7c..46ffa8ab6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -222,6 +222,7 @@ exclude_lines = [ "return NotImplemented", "_use_c = true", "if __name__ == '__main__':", + "if TYPE_CHECKING:" ] partial_branches = ["if (.*and +)*not _use_c( and.*)*:"] diff --git a/test/__init__.py b/test/__init__.py index 2978c05d89..fc20f68df6 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -67,6 +67,7 @@ # for a replica set. host = os.environ.get("DB_IP", "localhost") port = int(os.environ.get("DB_PORT", 27017)) +IS_SRV = "mongodb+srv" in host db_user = os.environ.get("DB_USER", "user") db_pwd = os.environ.get("DB_PASSWORD", "password") @@ -384,7 +385,7 @@ def _init_client(self): self.auth_enabled = self._server_started_with_auth() if self.auth_enabled: - if not self.serverless: + if not self.serverless and not IS_SRV: # See if db_user already exists. if not self._check_user_provided(): _create_user(self.client.admin, db_user, db_pwd) @@ -452,7 +453,7 @@ def _init_client(self): else: self.server_parameters = self.client.admin.command("getParameter", "*") assert self.cmd_line is not None - if "enableTestCommands=1" in self.cmd_line["argv"]: + if self.server_parameters["enableTestCommands"]: self.test_commands_enabled = True elif "parsed" in self.cmd_line: params = self.cmd_line["parsed"].get("setParameter", []) @@ -488,14 +489,14 @@ def connection_attempt_info(self): @property def host(self): - if self.is_rs: + if self.is_rs and not IS_SRV: primary = self.client.primary return str(primary[0]) if primary is not None else host return host @property def port(self): - if self.is_rs: + if self.is_rs and not IS_SRV: primary = self.client.primary return primary[1] if primary is not None else port return port @@ -520,6 +521,10 @@ def storage_engine(self): # Raised if self.server_status is None. return None + def check_auth_type(self, auth_type): + auth_mechs = self.server_parameters.get("authenticationMechanisms", []) + return auth_type in auth_mechs + def _check_user_provided(self): """Return True if db_user/db_password is already an admin user.""" client: MongoClient = pymongo.MongoClient( diff --git a/test/auth/legacy/connection-string.json b/test/auth/legacy/connection-string.json index 0463a5141e..4fca1c33d1 100644 --- a/test/auth/legacy/connection-string.json +++ b/test/auth/legacy/connection-string.json @@ -446,9 +446,8 @@ } }, { - "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], + "description": "should recognise the mechanism with aws provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", "valid": true, "credential": { "username": null, @@ -456,14 +455,13 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true + "PROVIDER_NAME": "aws" } } }, { - "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", - "callback": ["oidcRequest"], + "description": "should recognise the mechanism when auth source is explicitly specified and with provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", "valid": true, "credential": { "username": null, @@ -471,28 +469,43 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true + "PROVIDER_NAME": "aws" } } }, { - "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", - "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], - "valid": true, - "credential": { - "username": "principalName", - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true - } - } + "description": "should throw an exception if username and password is specified for aws provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "valid": false, + "credential": null }, { - "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "description": "should throw an exception if username is specified for aws provider (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:aws", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified provider is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:invalid", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception custom callback is chosen but no callback is provided (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:custom", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither provider nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should recognise the mechanism with azure provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:azure,TOKEN_AUDIENCE:foo", "valid": true, "credential": { "username": null, @@ -500,52 +513,35 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "PROVIDER_NAME": "aws" + "PROVIDER_NAME": "azure", + "TOKEN_AUDIENCE": "foo" } } }, { - "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", + "description": "should accept a username with azure provider (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:azure,TOKEN_AUDIENCE:foo", "valid": true, "credential": { - "username": null, + "username": "user", "password": null, "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "PROVIDER_NAME": "aws" + "PROVIDER_NAME": "azure", + "TOKEN_AUDIENCE": "foo" } } }, { - "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", - "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], - "valid": false, - "credential": null - }, - { - "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", - "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", - "valid": false, - "credential": null - }, - { - "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", - "valid": false, - "credential": null - }, - { - "description": "should throw an exception if neither deviceName nor callback specified (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "description": "should accept a username and throw an error for a password with azure provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:azure,TOKEN_AUDIENCE:foo", "valid": false, "credential": null }, { - "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "description": "should throw and exception if no token audience is given for azure provider (MONGODB-OIDC)", + "uri": "mongodb://username@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:azure", "valid": false, "credential": null } diff --git a/test/auth/unified/mongodb-oidc-no-retry.json b/test/auth/unified/mongodb-oidc-no-retry.json new file mode 100644 index 0000000000..116ca3eea9 --- /dev/null +++ b/test/auth/unified/mongodb-oidc-no-retry.json @@ -0,0 +1,601 @@ +{ + "description": "MONGODB-OIDC authentication with retry disabled", + "schemaVersion": "1.19", + "runOnRequirements": [ + { + "minServerVersion": "7.0", + "auth": true, + "authMechanism": "MONGODB-OIDC" + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "client": { + "id": "client0", + "uriOptions": { + "authMechanism": "MONGODB-OIDC", + "authMechanismProperties": { + "$$placeholder": 1 + }, + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "test", + "documents": [ + + ] + } + ], + "tests": [ + { + "description": "A read operation should succeed", + "operations": [ + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + } + }, + "expectResult": [ + + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "A write operation should succeed", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Read commands should reauthenticate and retry when a ReauthenticationRequired error happens", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + } + }, + "expectResult": [ + + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write commands should reauthenticate and retry when a ReauthenticationRequired error happens", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Handshake with cached token should use speculative authentication", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "closeConnection": true + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + }, + "expectError": { + "isClientError": true + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "saslStart" + ], + "errorCode": 20 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Handshake without cached token should not use speculative authentication", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "saslStart" + ], + "errorCode": 20 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + }, + "expectError": { + "errorCode": 20 + } + } + ] + }, + { + "description": "Read commands should fail if reauthentication fails", + "operations": [ + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + } + }, + "expectResult": [ + + ] + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "find", + "saslStart" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + } + }, + "expectError": { + "errorCode": 391 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write commands should fail if reauthentication fails", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "insert", + "saslStart" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 2, + "x": 2 + } + }, + "expectError": { + "errorCode": 391 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 2, + "x": 2 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] + } \ No newline at end of file diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 7055816af3..5374f8ab3b 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -17,60 +17,62 @@ import os import sys +import threading import time import unittest +import warnings from contextlib import contextmanager +from pathlib import Path from typing import Dict sys.path[0:0] = [""] +from test.unified_format import generate_test_classes from test.utils import EventListener from bson import SON from pymongo import MongoClient -from pymongo.auth import _AUTH_MAP, _authenticate_oidc +from pymongo.auth_oidc import ( + OIDCCallback, + OIDCCallbackResult, +) +from pymongo.azure_helpers import _get_azure_response from pymongo.cursor import CursorType -from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure from pymongo.hello import HelloCompat from pymongo.operations import InsertOne +from pymongo.uri_parser import parse_uri -# Force MONGODB-OIDC to be enabled. -_AUTH_MAP["MONGODB-OIDC"] = _authenticate_oidc # type:ignore +ROOT = Path(__file__).parent.parent.resolve() +TEST_PATH = ROOT / "auth" / "unified" +PROVIDER_NAME = os.environ.get("OIDC_PROVIDER_NAME", "aws") -class TestAuthOIDC(unittest.TestCase): - uri: str +# Generate unified tests. +globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) + +class OIDCTestBase(unittest.TestCase): @classmethod def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] cls.uri_multiple = os.environ["MONGODB_URI_MULTI"] cls.uri_admin = os.environ["MONGODB_URI"] - cls.token_dir = os.environ["OIDC_TOKEN_DIR"] def setUp(self): self.request_called = 0 - def create_request_cb(self, username="test_user1", sleep=0): - token_file = os.path.join(self.token_dir, username).replace(os.sep, "/") - - def request_token(server_info, context): - # Validate the info. - self.assertIn("issuer", server_info) - self.assertIn("clientId", server_info) - - # Validate the timeout. - timeout_seconds = context["timeout_seconds"] - self.assertEqual(timeout_seconds, 60 * 5) + def get_token(self, username=None): + """Get a token for the current provider.""" + if PROVIDER_NAME == "aws": + token_dir = os.environ["OIDC_TOKEN_DIR"] + token_file = os.path.join(token_dir, username).replace(os.sep, "/") with open(token_file) as fid: - token = fid.read() - resp = {"access_token": token, "refresh_token": token} - - time.sleep(sleep) - self.request_called += 1 - return resp - - return request_token + return fid.read() + elif PROVIDER_NAME == "azure": + opts = parse_uri(self.uri_single)["options"] + token_aud = opts["authmechanismproperties"]["TOKEN_AUDIENCE"] + return _get_azure_response(token_aud, username)["access_token"] @contextmanager def fail_point(self, command_args): @@ -83,156 +85,218 @@ def fail_point(self, command_args): finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") - def test_connect_request_callback_single_implicit_username(self): - request_token = self.create_request_cb() - props: Dict = {"request_token_callback": request_token} - client = MongoClient(self.uri_single, authmechanismproperties=props) + +class TestAuthOIDCHuman(OIDCTestBase): + uri: str + + @classmethod + def setUpClass(cls): + if PROVIDER_NAME != "aws": + raise unittest.SkipTest("Human workflows are only tested with the aws provider") + super().setUpClass() + + def create_request_cb(self, username="test_user1", sleep=0): + def request_token(context): + # Validate the info. + self.assertIsInstance(context.idp_info.issuer, str) + self.assertIsInstance(context.idp_info.clientId, str) + + # Validate the timeout. + timeout_seconds = context.timeout_seconds + self.assertEqual(timeout_seconds, 60 * 5) + token = self.get_token(username) + resp = OIDCCallbackResult(access_token=token, refresh_token=token) + + time.sleep(sleep) + self.request_called += 1 + return resp + + class Inner(OIDCCallback): + def fetch(self, context): + return request_token(context) + + return Inner() + + def create_client(self, *args, **kwargs): + username = kwargs.get("username", "test_user1") + request_cb = kwargs.pop("request_cb", self.create_request_cb(username=username)) + props = kwargs.pop("authmechanismproperties", {"OIDC_HUMAN_CALLBACK": request_cb}) + kwargs["retryReads"] = False + if not len(args): + args = [self.uri_single] + return MongoClient(*args, authmechanismproperties=props, **kwargs) + + def test_1_1_single_principal_implicit_username(self): + # Create default OIDC client with authMechanism=MONGODB-OIDC. + client = self.create_client() + # Perform a find operation that succeeds. client.test.test.find_one() + # Close the client. client.close() - def test_connect_request_callback_single_explicit_username(self): - request_token = self.create_request_cb() - props: Dict = {"request_token_callback": request_token} - client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) + def test_1_2_single_principal_explicit_username(self): + # Create a client with MONGODB_URI_SINGLE, a username of test_user1, authMechanism=MONGODB-OIDC, and the OIDC human callback. + client = self.create_client(username="test_user1") + # Perform a find operation that succeeds. client.test.test.find_one() + # Close the client.. client.close() - def test_connect_request_callback_multiple_principal_user1(self): - request_token = self.create_request_cb() - props: Dict = {"request_token_callback": request_token} - client = MongoClient( - self.uri_multiple, username="test_user1", authmechanismproperties=props - ) + def test_1_3_multiple_principal_user_1(self): + # Create a client with MONGODB_URI_MULTI, a username of test_user1, authMechanism=MONGODB-OIDC, and the OIDC human callback. + client = self.create_client(self.uri_multiple, username="test_user1") + # Perform a find operation that succeeds. client.test.test.find_one() + # Close the client. client.close() - def test_connect_request_callback_multiple_principal_user2(self): - request_token = self.create_request_cb("test_user2") - props: Dict = {"request_token_callback": request_token} - client = MongoClient( - self.uri_multiple, username="test_user2", authmechanismproperties=props - ) + def test_1_4_multiple_principal_user_2(self): + # Create a human callback that reads in the generated test_user2 token file. + # Create a client with MONGODB_URI_MULTI, a username of test_user2, authMechanism=MONGODB-OIDC, and the OIDC human callback. + client = self.create_client(self.uri_multiple, username="test_user2") + # Perform a find operation that succeeds. client.test.test.find_one() + # Close the client. client.close() - def test_connect_request_callback_multiple_no_username(self): - request_token = self.create_request_cb() - props: Dict = {"request_token_callback": request_token} - client = MongoClient(self.uri_multiple, authmechanismproperties=props) + def test_1_5_multiple_principal_no_user(self): + # Create a client with MONGODB_URI_MULTI, no username, authMechanism=MONGODB-OIDC, and the OIDC human callback. + client = self.create_client(self.uri_multiple) + # Assert that a find operation fails. with self.assertRaises(OperationFailure): client.test.test.find_one() + # Close the client. client.close() - def test_allowed_hosts_blocked(self): + def test_1_6_allowed_hosts_blocked(self): + # Create a default OIDC client, with an ALLOWED_HOSTS that is an empty list. request_token = self.create_request_cb() - props: Dict = {"request_token_callback": request_token, "allowed_hosts": []} - client = MongoClient(self.uri_single, authmechanismproperties=props) - with self.assertRaises(ConfigurationError): - client.test.test.find_one() - client.close() - - props: Dict = {"request_token_callback": request_token, "allowed_hosts": ["example.com"]} - client = MongoClient( - self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False - ) + props: Dict = {"OIDC_HUMAN_CALLBACK": request_token, "ALLOWED_HOSTS": []} + client = self.create_client(authmechanismproperties=props) + # Assert that a find operation fails with a client-side error. with self.assertRaises(ConfigurationError): client.test.test.find_one() + # Close the client. client.close() - def test_valid_request_token_callback(self): - request_cb = self.create_request_cb() - + # Create a client that uses the URL mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com, + # a human callback, and an ALLOWED_HOSTS that contains ["example.com"]. props: Dict = { - "request_token_callback": request_cb, + "OIDC_HUMAN_CALLBACK": request_token, + "ALLOWED_HOSTS": ["example.com"], } - client = MongoClient(self.uri_single, authmechanismproperties=props) - client.test.test.find_one() + with warnings.catch_warnings(): + warnings.simplefilter("default") + client = self.create_client( + self.uri_single + "&ignored=example.com", + authmechanismproperties=props, + connect=False, + ) + # Assert that a find operation fails with a client-side error. + with self.assertRaises(ConfigurationError): + client.test.test.find_one() + # Close the client. client.close() - client = MongoClient(self.uri_single, authmechanismproperties=props) + def test_2_1_valid_callback_inputs(self): + # Create a MongoClient with a human callback that validates its inputs and returns a valid access token. + client = self.create_client() + # Perform a find operation that succeeds. Verify that the human callback was called with the appropriate inputs, including the timeout parameter if possible. + # Ensure that there are no unexpected fields. client.test.test.find_one() + # Close the client. client.close() - def test_request_callback_returns_null(self): - def request_token_null(a, b): - return None + def test_2_2_OIDC_HUMAN_CALLBACK_returns_missing_data(self): + # Create a MongoClient with a human callback that returns data not conforming to the OIDCCredential with missing fields. + class CustomCB(OIDCCallback): + def fetch(self, ctx): + return dict() - props: Dict = {"request_token_callback": request_token_null} - client = MongoClient(self.uri_single, authMechanismProperties=props) + client = self.create_client(request_cb=CustomCB()) + # Perform a find operation that fails. with self.assertRaises(ValueError): client.test.test.find_one() + # Close the client. client.close() - def test_request_callback_invalid_result(self): - def request_token_invalid(a, b): - return {} - - props: Dict = {"request_token_callback": request_token_invalid} - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() + def test_3_1_uses_speculative_authentication_if_there_is_a_cached_token(self): + # Create a client with a human callback that returns a valid token. + client = self.create_client() - def request_cb_extra_value(server_info, context): - result = self.create_request_cb()(server_info, context) - result["foo"] = "bar" - return result + # Set a fail point for ``find`` commands. + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391, "closeConnection": True}, + } + ): + # Perform a ``find`` operation that fails. + with self.assertRaises(AutoReconnect): + client.test.test.find_one() - props: Dict = {"request_token_callback": request_cb_extra_value} - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): + # Set a fail point for ``saslStart`` commands. + with self.fail_point( + { + "mode": "alwaysOn", + "data": {"failCommands": ["saslStart"], "errorCode": 20}, + } + ): + # Perform a ``find`` operation that succeeds client.test.test.find_one() - client.close() - def test_speculative_auth_success(self): - request_token = self.create_request_cb() + # Close the client. + client.close() - # Create a client with a request callback that returns a valid token. - props: Dict = {"request_token_callback": request_token} - client = MongoClient(self.uri_single, authmechanismproperties=props) + def test_3_2_does_not_use_speculative_authentication_if_there_is_no_cached_token(self): + # Create a ``MongoClient`` with a human callback that returns a valid token + client = self.create_client() - # Set a fail point for saslStart commands. + # Set a fail point for ``saslStart`` commands. with self.fail_point( { - "mode": {"times": 2}, - "data": {"failCommands": ["saslStart"], "errorCode": 18}, + "mode": "alwaysOn", + "data": {"failCommands": ["saslStart"], "errorCode": 20}, } ): - # Perform a find operation. - client.test.test.find_one() + # Perform a ``find`` operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() # Close the client. client.close() - def test_reauthenticate_succeeds(self): + def test_4_1_reauthenticate_succeeds(self): + # Create a default OIDC client and add an event listener. + # The following assumes that the driver does not emit saslStart or saslContinue events. + # If the driver does emit those events, ignore/filter them for the purposes of this test. listener = EventListener() + client = self.create_client(event_listeners=[listener]) - # Create request callback that returns valid credentials. - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient( - self.uri_single, event_listeners=[listener], authmechanismproperties=props - ) - - # Perform a find operation. + # Perform a find operation that succeeds. client.test.test.find_one() - # Assert that the request callback has been called once. + # Assert that the human callback has been called once. self.assertEqual(self.request_called, 1) + # Clear the listener state if possible. listener.reset() + # Force a reauthenication using a fail point. with self.fail_point( { "mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}, } ): - # Perform a find operation. + # Perform another find operation that succeeds. client.test.test.find_one() + # Assert that the human callback has been called twice. + self.assertEqual(self.request_called, 2) + + # Assert that the ordering of list started events is [find, find]. + # Note that if the listener stat could not be cleared then there will be an extra find command. started_events = [ i.command_name for i in listener.started_events if not i.command_name.startswith("sasl") ] @@ -252,104 +316,164 @@ def test_reauthenticate_succeeds(self): "find", ], ) + # Assert that the list of command succeeded events is [find]. self.assertEqual(succeeded_events, ["find"]) + # Assert that a find operation failed once during the command execution. self.assertEqual(failed_events, ["find"]) - - # Assert that the request callback has been called twice. - self.assertEqual(self.request_called, 2) + # Close the client. client.close() - def test_reauthenticate_succeeds_no_refresh(self): + def test_4_2_reauthenticate_succeeds_no_refresh(self): + # Create a default OIDC client with a human callback that does not return a refresh token. cb = self.create_request_cb() - def request_cb(*args, **kwargs): - result = cb(*args, **kwargs) - del result["refresh_token"] - return result + class CustomRequest(OIDCCallback): + def fetch(self, *args, **kwargs): + result = cb.fetch(*args, **kwargs) + result.refresh_token = None + return result - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = self.create_client(request_cb=CustomRequest()) - # Perform a find operation. + # Perform a find operation that succeeds. client.test.test.find_one() - # Assert that the request callback has been called once. + # Assert that the human callback has been called once. self.assertEqual(self.request_called, 1) + # Force a reauthenication using a fail point. with self.fail_point( { "mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}, } ): - # Perform a find operation. + # Perform a find operation that succeeds. client.test.test.find_one() - # Assert that the request callback has been called twice. + # Assert that the human callback has been called twice. self.assertEqual(self.request_called, 2) + # Close the client. client.close() - def test_reauthenticate_succeeds_after_refresh_fails(self): - # Create request callback that returns valid credentials. - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + def test_4_3_reauthenticate_succeeds_after_refresh_fails(self): + # Create a client with a human callback that returns a valid token. + client = self.create_client() - # Perform a find operation. + # Perform a find operation that succeeds. client.test.test.find_one() - # Assert that the request callback has been called once. + # Assert that the human callback has been called once. self.assertEqual(self.request_called, 1) + # Force a reauthenication using a fail point. with self.fail_point( { "mode": {"times": 2}, - "data": {"failCommands": ["find", "saslContinue"], "errorCode": 391}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, } ): - # Perform a find operation. + # Perform a find operation that succeeds. client.test.test.find_one() - # Assert that the request callback has been called three times. + # Assert that the human callback has been called 3 times. self.assertEqual(self.request_called, 3) - def test_reauthenticate_fails(self): - # Create request callback that returns valid credentials. - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + # Close the client. + client.close() - # Perform a find operation. + def test_4_4_reauthenticate_fails(self): + # Create a client with a human callback that returns a valid token. + client = self.create_client() + # Perform a find operation that succeeds (to force a speculative auth). client.test.test.find_one() - - # Assert that the request callback has been called once. + # Assert that the human callback has been called once. self.assertEqual(self.request_called, 1) - + # Force a reauthentication using a failCommand. with self.fail_point( { - "mode": {"times": 2}, - "data": {"failCommands": ["find"], "errorCode": 391}, + "mode": {"times": 3}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, } ): # Perform a find operation that fails. with self.assertRaises(OperationFailure): client.test.test.find_one() - - # Assert that the request callback has been called twice. + # Assert that the human callback has been called two times. self.assertEqual(self.request_called, 2) + # Close the client. client.close() - def test_reauthenticate_succeeds_bulk_write(self): + def test_request_callback_returns_null(self): + class RequestTokenNull(OIDCCallback): + def fetch(self, a): + return None + + client = self.create_client(request_cb=RequestTokenNull()) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_request_callback_invalid_result(self): + class CallbackInvalidToken(OIDCCallback): + def fetch(self, a): + return {} + + client = self.create_client(request_cb=CallbackInvalidToken()) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_reauthentication_succeeds_multiple_connections(self): request_cb = self.create_request_cb() # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + client1 = self.create_client(request_cb=request_cb) + client2 = self.create_client(request_cb=request_cb) + + # Perform an insert operation. + client1.test.test.insert_many([{"a": 1}, {"a": 1}]) + client2.test.test.find_one() + self.assertEqual(self.request_called, 2) + + # Use the same authenticator for both clients + # to simulate a race condition with separate connections. + # We should only see one extra callback despite both connections + # needing to reauthenticate. + client2.options.pool_options._credentials.cache.data = ( + client1.options.pool_options._credentials.cache.data + ) + + client1.test.test.find_one() + client2.test.test.find_one() + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + client1.test.test.find_one() + + self.assertEqual(self.request_called, 3) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + client2.test.test.find_one() + + self.assertEqual(self.request_called, 3) + client1.close() + client2.close() + + # PyMongo specific tests, since we have multiple code paths for reauth handling. + + def test_reauthenticate_succeeds_bulk_write(self): + # Create a client. + client = self.create_client() # Perform a find operation. client.test.test.find_one() @@ -364,24 +488,21 @@ def test_reauthenticate_succeeds_bulk_write(self): } ): # Perform a bulk write operation. - client.test.test.bulk_write([InsertOne({})]) + client.test.test.bulk_write([InsertOne({})]) # type:ignore[type-var] # Assert that the request callback has been called twice. self.assertEqual(self.request_called, 2) client.close() def test_reauthenticate_succeeds_bulk_read(self): - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + # Create a client. + client = self.create_client() # Perform a find operation. client.test.test.find_one() # Perform a bulk write operation. - client.test.test.bulk_write([InsertOne({})]) + client.test.test.bulk_write([InsertOne({})]) # type:ignore[type-var] # Assert that the request callback has been called once. self.assertEqual(self.request_called, 1) @@ -401,11 +522,8 @@ def test_reauthenticate_succeeds_bulk_read(self): client.close() def test_reauthenticate_succeeds_cursor(self): - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + # Create a client. + client = self.create_client() # Perform an insert operation. client.test.test.insert_one({"a": 1}) @@ -428,11 +546,8 @@ def test_reauthenticate_succeeds_cursor(self): client.close() def test_reauthenticate_succeeds_get_more(self): - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + # Create a client. + client = self.create_client() # Perform an insert operation. client.test.test.insert_many([{"a": 1}, {"a": 1}]) @@ -456,17 +571,13 @@ def test_reauthenticate_succeeds_get_more(self): def test_reauthenticate_succeeds_get_more_exhaust(self): # Ensure no mongos - props = {"request_token_callback": self.create_request_cb()} - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = self.create_client() hello = client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") != "isdbgrid": raise unittest.SkipTest("Must not be a mongos") - request_cb = self.create_request_cb() - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = self.create_client() # Perform an insert operation. client.test.test.insert_many([{"a": 1}, {"a": 1}]) @@ -489,13 +600,8 @@ def test_reauthenticate_succeeds_get_more_exhaust(self): client.close() def test_reauthenticate_succeeds_command(self): - request_cb = self.create_request_cb() - - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} - - print("start of test") - client = MongoClient(self.uri_single, authmechanismproperties=props) + # Create a client. + client = self.create_client() # Perform an insert operation. client.test.test.insert_one({"a": 1}) @@ -518,14 +624,235 @@ def test_reauthenticate_succeeds_command(self): self.assertEqual(self.request_called, 2) client.close() - def test_reauthentication_succeeds_multiple_connections(self): + +class TestAuthOIDCMachine(OIDCTestBase): + uri: str + + def setUp(self): + self.request_called = 0 + if PROVIDER_NAME == "aws": + self.default_username = "test_user1" + else: + self.default_username = None + + def create_request_cb(self, username=None, sleep=0): + if username is None: + username = self.default_username + + def request_token(context): + assert isinstance(context.timeout_seconds, int) + assert context.version == 1 + assert context.refresh_token is None + assert context.idp_info is None + token = self.get_token(username) + time.sleep(sleep) + self.request_called += 1 + return OIDCCallbackResult(access_token=token) + + class Inner(OIDCCallback): + def fetch(self, context): + return request_token(context) + + return Inner() + + def create_client(self, *args, **kwargs): + request_cb = kwargs.pop("request_cb", self.create_request_cb()) + props = kwargs.pop("authmechanismproperties", {"OIDC_CALLBACK": request_cb}) + kwargs["retryReads"] = False + if not len(args): + args = [self.uri_single] + return MongoClient(*args, authmechanismproperties=props, **kwargs) + + def test_1_1_callback_is_called_during_reauthentication(self): + # Create a ``MongoClient`` configured with a custom OIDC callback that + # implements the provider logic. + client = self.create_client() + # Perform a ``find`` operation that succeeds. + client.test.test.find_one() + # Assert that the callback was called 1 time. + self.assertEqual(self.request_called, 1) + # Close the client. + client.close() + + def test_1_2_callback_is_called_once_for_multiple_connections(self): + # Create a ``MongoClient`` configured with a custom OIDC callback that + # implements the provider logic. + client = self.create_client() + + # Start 10 threads and run 100 find operations in each thread that all succeed. + def target(): + for _ in range(100): + client.test.test.find_one() + + threads = [] + for _ in range(10): + thread = threading.Thread(target=target) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + # Assert that the callback was called 1 time. + self.assertEqual(self.request_called, 1) + # Close the client. + client.close() + + def test_2_1_valid_callback_inputs(self): + # Create a MongoClient configured with an OIDC callback that validates its inputs and returns a valid access token. + client = self.create_client() + # Perform a find operation that succeeds. + client.test.test.find_one() + # Assert that the OIDC callback was called with the appropriate inputs, including the timeout parameter if possible. Ensure that there are no unexpected fields. + self.assertEqual(self.request_called, 1) + # Close the client. + client.close() + + def test_2_2_oidc_callback_returns_null(self): + # Create a MongoClient configured with an OIDC callback that returns null. + class CallbackNullToken(OIDCCallback): + def fetch(self, a): + return None + + client = self.create_client(request_cb=CallbackNullToken()) + # Perform a find operation that fails. + with self.assertRaises(ValueError): + client.test.test.find_one() + # Close the client. + client.close() + + def test_2_3_oidc_callback_returns_missing_data(self): + # Create a MongoClient configured with an OIDC callback that returns data not conforming to the OIDCCredential with missing fields. + class CustomCallback(OIDCCallback): + count = 0 + + def fetch(self, a): + self.count += 1 + return object() + + client = self.create_client(request_cb=CustomCallback()) + # Perform a find operation that fails. + with self.assertRaises(ValueError): + client.test.test.find_one() + # Close the client. + client.close() + + def test_2_4_oidc_callback_returns_invalid_data(self): + # Create a MongoClient configured with an OIDC callback that returns data not conforming to the OIDCCredential with extra fields. + class CustomCallback(OIDCCallback): + count = 0 + + def fetch(self, a): + self.count += 1 + return OIDCCallbackResult(access_token="bad value") + + client = self.create_client(request_cb=CustomCallback()) + # Perform a ``find`` operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + # Close the client. + client.close() + + def test_2_5_invalid_client_configuration_with_callback(self): + # Create a MongoClient configured with an OIDC callback and auth mechanism property PROVIDER_NAME:aws. request_cb = self.create_request_cb() + props: Dict = {"OIDC_CALLBACK": request_cb, "PROVIDER_NAME": "aws"} + # Assert it returns a client configuration error. + with self.assertRaises(ConfigurationError): + self.create_client(authmechanismproperties=props) - # Create a client with the callback. - props: Dict = {"request_token_callback": request_cb} + def test_3_1_authentication_failure_with_cached_tokens_fetch_a_new_token_and_retry(self): + # Create a MongoClient and an OIDC callback that implements the provider logic. + client = self.create_client() + # Poison the cache with an invalid access token. + # Set a fail point for ``find`` command. + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391, "closeConnection": True}, + } + ): + # Perform a ``find`` operation that fails. This is to force the ``MongoClient`` + # to cache an access token. + with self.assertRaises(AutoReconnect): + client.test.test.find_one() + # Poison the cache of the client. + client.options.pool_options._credentials.cache.data.access_token = "bad" + # Reset the request count. + self.request_called = 0 + # Verify that a find succeeds. + client.test.test.find_one() + # Verify that the callback was called 1 time. + self.assertEqual(self.request_called, 1) + # Close the client. + client.close() + + def test_3_2_authentication_failures_without_cached_tokens_returns_an_error(self): + # Create a MongoClient configured with retryReads=false and an OIDC callback that always returns invalid access tokens. + class CustomCallback(OIDCCallback): + count = 0 - client1 = MongoClient(self.uri_single, authmechanismproperties=props) - client2 = MongoClient(self.uri_single, authmechanismproperties=props) + def fetch(self, a): + self.count += 1 + return OIDCCallbackResult(access_token="bad value") + + callback = CustomCallback() + client = self.create_client(request_cb=callback) + # Perform a ``find`` operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + # Verify that the callback was called 1 time. + self.assertEqual(callback.count, 1) + # Close the client. + client.close() + + def test_4_reauthentication(self): + # Create a ``MongoClient`` configured with a custom OIDC callback that + # implements the provider logic. + client = self.create_client() + + # Set a fail point for the find command. + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a ``find`` operation that succeeds. + client.test.test.find_one() + + # Verify that the callback was called 2 times (once during the connection + # handshake, and again during reauthentication). + self.assertEqual(self.request_called, 2) + + # Close the client. + client.close() + + def test_speculative_auth_success(self): + client1 = self.create_client() + client1.test.test.find_one() + client2 = self.create_client() + + # Prime the cache of the second client. + client2.options.pool_options._credentials.cache.data = ( + client1.options.pool_options._credentials.cache.data + ) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client2.test.test.find_one() + + # Close the clients. + client2.close() + client1.close() + + def test_reauthentication_succeeds_multiple_connections(self): + client1 = self.create_client() + client2 = self.create_client() # Perform an insert operation. client1.test.test.insert_many([{"a": 1}, {"a": 1}]) @@ -565,6 +892,30 @@ def test_reauthentication_succeeds_multiple_connections(self): client1.close() client2.close() + def test_azure_no_username(self): + if PROVIDER_NAME != "azure": + raise unittest.SkipTest("Test is only supported on Azure") + opts = parse_uri(self.uri_single)["options"] + token_aud = opts["authmechanismproperties"]["TOKEN_AUDIENCE"] + + props = dict(TOKEN_AUDIENCE=token_aud, PROVIDER_NAME="azure") + client = self.create_client(authMechanismProperties=props) + client.test.test.find_one() + client.close() + + def test_azure_bad_username(self): + if PROVIDER_NAME != "azure": + raise unittest.SkipTest("Test is only supported on Azure") + + opts = parse_uri(self.uri_single)["options"] + token_aud = opts["authmechanismproperties"]["TOKEN_AUDIENCE"] + + props = dict(TOKEN_AUDIENCE=token_aud, PROVIDER_NAME="azure") + client = self.create_client(username="bad", authmechanismproperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + if __name__ == "__main__": unittest.main() diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 4976a6dd49..ba0e2f5e33 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -19,6 +19,7 @@ import json import os import sys +import warnings sys.path[0:0] = [""] @@ -26,6 +27,7 @@ from test.unified_format import generate_test_classes from pymongo import MongoClient +from pymongo.auth_oidc import OIDCCallback _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") @@ -34,6 +36,11 @@ class TestAuthSpec(unittest.TestCase): pass +class SampleHumanCallback(OIDCCallback): + def fetch(self, context): + pass + + def create_test(test_case): def run_test(self): uri = test_case["uri"] @@ -41,14 +48,15 @@ def run_test(self): credential = test_case.get("credential") if not valid: - self.assertRaises(Exception, MongoClient, uri, connect=False) + with warnings.catch_warnings(): + warnings.simplefilter("default") + self.assertRaises(Exception, MongoClient, uri, connect=False) else: props = {} if credential: props = credential["mechanism_properties"] or {} - if props.get("REQUEST_TOKEN_CALLBACK"): - props["request_token_callback"] = lambda x, y: 1 - del props["REQUEST_TOKEN_CALLBACK"] + if props.get("CALLBACK"): + props["callback"] = SampleHumanCallback() client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials if credential is None: @@ -80,10 +88,8 @@ def run_test(self): ) elif "PROVIDER_NAME" in expected: self.assertEqual(actual.provider_name, expected["PROVIDER_NAME"]) - elif "request_token_callback" in expected: - self.assertEqual( - actual.request_token_callback, expected["request_token_callback"] - ) + elif "callback" in expected: + self.assertEqual(actual.callback, expected["callback"]) else: self.fail(f"Unhandled property: {key}") else: diff --git a/test/unified_format.py b/test/unified_format.py index 2fd2579be2..a8e437e44e 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -140,7 +140,7 @@ } -# Build up a placeholder map. +# Build up a placeholder maps. PLACEHOLDER_MAP = {} for provider_name, provider_data in [ ("local", {"key": LOCAL_MASTER_KEY}), @@ -159,6 +159,15 @@ placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" PLACEHOLDER_MAP[placeholder] = value +PROVIDER_NAME = os.environ.get("OIDC_PROVIDER_NAME", "aws") +if PROVIDER_NAME == "aws": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"PROVIDER_NAME": "aws"} +elif PROVIDER_NAME == "azure": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "PROVIDER_NAME": "azure", + "TOKEN_AUDIENCE": os.environ["AZUREOIDC_AUDIENCE"], + } + def interrupt_loop(): global IS_INTERRUPTED @@ -233,6 +242,8 @@ def is_run_on_requirement_satisfied(requirement): if req_auth is not None: if req_auth: auth_satisfied = client_context.auth_enabled + if auth_satisfied and "authMechanism" in requirement: + auth_satisfied = client_context.check_auth_type(requirement["authMechanism"]) else: auth_satisfied = not client_context.auth_enabled @@ -933,7 +944,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): a class attribute ``TEST_SPEC``. """ - SCHEMA_VERSION = Version.from_string("1.18") + SCHEMA_VERSION = Version.from_string("1.19") RUN_ON_LOAD_BALANCER = True RUN_ON_SERVERLESS = True TEST_SPEC: Any diff --git a/test/utils.py b/test/utils.py index 85e952bf8b..2c6ec86bc8 100644 --- a/test/utils.py +++ b/test/utils.py @@ -558,7 +558,7 @@ def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs client_options.update(kwargs) uri = _connection_string(host) - if client_context.auth_enabled and authenticate: + if client_context.auth_enabled and authenticate and "authMechanism" not in kwargs: # Only add the default username or password if one is not provided. res = parse_uri(uri) if (