Skip to content

PYTHON-3845 OIDC: Implement Machine Callback Mechanism #1401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d7265bf
PYTHON-3845 OIDC: Implement Machine Callback Mechanism
blink1073 Oct 15, 2023
cc838c5
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Oct 18, 2023
4c22587
run unified tests in oidc test
blink1073 Oct 18, 2023
0485e82
fix path handling
blink1073 Oct 18, 2023
ae77133
debug
blink1073 Oct 18, 2023
e5b31cb
fix path handling
blink1073 Oct 18, 2023
fee9cf4
handle db_ip
blink1073 Oct 18, 2023
e11865a
handle srv hosts
blink1073 Oct 19, 2023
fb7470d
PYTHON-3845 OIDC: Implement Machine Callback Mechanism
blink1073 Oct 20, 2023
131de6a
undo test comment out
blink1073 Oct 20, 2023
1d7011f
change name to custom_token_callback
blink1073 Oct 31, 2023
36d6c8b
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Nov 1, 2023
9c961e9
simplify custom callback and start prose tests
blink1073 Nov 2, 2023
c0eed02
fix placeholder handling
blink1073 Nov 3, 2023
65e47a1
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Nov 7, 2023
f26c60a
wip implement OIDC tests
blink1073 Nov 8, 2023
7897053
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Nov 14, 2023
a95479c
Use callback class and update tests
blink1073 Nov 15, 2023
e55d34e
Fix typing and test
blink1073 Nov 15, 2023
480a648
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Nov 16, 2023
c085b50
fix tests
blink1073 Nov 16, 2023
07b8c6d
fix default port
blink1073 Nov 16, 2023
5b991c3
lint
blink1073 Nov 16, 2023
2cd3aa7
add reauth succeeds prose test for machine
blink1073 Nov 16, 2023
e221ebf
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Nov 17, 2023
5c0f770
use dataclasses in callbacks
blink1073 Nov 17, 2023
dae8124
Merge branch 'master' of github.com:mongodb/mongo-python-driver into …
blink1073 Nov 27, 2023
e563840
update for spec changes
blink1073 Nov 28, 2023
9a4d6b0
add updated spec tests
blink1073 Nov 28, 2023
789a55c
fix test and typing
blink1073 Nov 29, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .evergreen/run-mongodb-oidc-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ fi
export TEST_AUTH_OIDC=1
export COVERAGE=1
export AUTH="auth"
export AWS_WEB_IDENTITY_TOKEN_FILE="$OIDC_TOKEN_DIR/test_user1"
bash ./.evergreen/tox.sh -m test-eg
3 changes: 2 additions & 1 deletion .evergreen/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ set -o xtrace

AUTH=${AUTH:-noauth}
SSL=${SSL:-nossl}
TEST_ARGS="$1"
TEST_ARGS="${*:1}"
PYTHON=$(which python)
export PIP_QUIET=1 # Quiet by default

Expand All @@ -52,6 +52,7 @@ if [ "$AUTH" != "noauth" ]; then
elif [ ! -z "$TEST_AUTH_OIDC" ]; then
export DB_USER=$OIDC_ALTAS_USER
export DB_PASSWORD=$OIDC_ATLAS_PASSWORD
export DB_IP="$MONGODB_URI"
else
export DB_USER="bob"
export DB_PASSWORD="pwd123"
Expand Down
31 changes: 25 additions & 6 deletions pymongo/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
from bson.binary import Binary
from bson.son import SON
from pymongo.auth_aws import _authenticate_aws
from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties
from pymongo.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAWSCallback,
_OIDCProperties,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep

Expand Down Expand Up @@ -164,7 +169,8 @@ def _build_credentials_tuple(
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
request_token_callback = properties.get("request_token_callback")
provider_name = properties.get("PROVIDER_NAME", "")
custom_token_callback = properties.get("custom_token_callback")
provider_name = properties.get("PROVIDER_NAME")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
Expand All @@ -175,12 +181,25 @@ def _build_credentials_tuple(
"::1",
]
allowed_hosts = properties.get("allowed_hosts", default_allowed)
if not request_token_callback and provider_name != "aws":
raise ConfigurationError(
"authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'"
)
msg = "authentication with MONGODB-OIDC requires providing a request_token_callback, a provider_name, or a custom_token_callback"
if request_token_callback is not None:
if provider_name is not None or custom_token_callback is not None:
raise ConfigurationError(msg)
elif provider_name is not None:
if custom_token_callback is not None:
raise ConfigurationError(msg)
if provider_name == "aws":
custom_token_callback = _OIDCAWSCallback()
else:
raise ConfigurationError(
f"unrecognized provider_name for MONGODB-OIDC: {provider_name}"
)
elif custom_token_callback is None:
raise ConfigurationError(msg)

oidc_props = _OIDCProperties(
request_token_callback=request_token_callback,
custom_token_callback=custom_token_callback,
provider_name=provider_name,
allowed_hosts=allowed_hosts,
)
Expand Down
168 changes: 123 additions & 45 deletions pymongo/auth_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,93 @@
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations

import abc
import os
import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union

import bson
from bson.binary import Binary
from bson.son import SON
from pymongo._csot import remaining
from pymongo.errors import ConfigurationError, OperationFailure

if TYPE_CHECKING:
from pymongo.auth import MongoCredential
from pymongo.pool import Connection


@dataclass
class OIDCIdPInfo:
issuer: str
clientId: str
requestScopes: Optional[list[str]] = field(default=None)


@dataclass
class OIDCHumanCallbackContext:
timeout_seconds: float
version: int
refresh_token: Optional[str] = field(default=None)


@dataclass
class OIDCMachineCallbackContext:
timeout_seconds: float
version: int


@dataclass
class OIDCHumanCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)


@dataclass
class OIDCMachineCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)


class OIDCMachineCallback(abc.ABC):
"""A base class for defining OIDC machine (workload federation)
callbacks.
"""

@abc.abstractmethod
def fetch(self, context: OIDCMachineCallbackContext) -> OIDCMachineCallbackResult:
"""Convert the given BSON value into our own type."""


class OIDCHumanCallback(abc.ABC):
"""A base class for defining OIDC human (workforce federation)
callbacks.
"""

@abc.abstractmethod
def fetch(
self, idp_info: OIDCIdPInfo, context: OIDCHumanCallbackContext
) -> OIDCHumanCallbackResult:
"""Convert the given BSON value into our own type."""


@dataclass
class _OIDCProperties:
request_token_callback: Optional[Callable[..., dict]]
provider_name: Optional[str]
allowed_hosts: list[str]
request_token_callback: Optional[OIDCHumanCallback] = field(default=None)
custom_token_callback: Optional[OIDCMachineCallback] = field(default=None)
provider_name: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)


"""Mechanism properties for MONGODB-OIDC authentication."""

TOKEN_BUFFER_MINUTES = 5
CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
HUMAN_CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
MACHINE_CALLBACK_VERSION = 1


def _get_authenticator(
Expand Down Expand Up @@ -72,22 +133,37 @@ def _get_authenticator(
return credentials.cache.data


class _OIDCAWSCallback(OIDCMachineCallback):
def fetch(self, context: OIDCMachineCallbackContext) -> OIDCMachineCallbackResult:
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCMachineCallbackResult(access_token=fid.read().strip())


@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
refresh_token: Optional[str] = field(default=None)
access_token: Optional[str] = field(default=None)
idp_info: Optional[dict] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)

def get_current_token(self, use_callback: bool = True) -> Optional[str]:
properties = self.properties

# TODO: DRIVERS-2672, handle machine callback here as well.
cb = properties.request_token_callback if use_callback else None
cb_type = "human"
cb: Union[None, OIDCHumanCallback, OIDCMachineCallback]
resp: Union[None, OIDCHumanCallbackResult, OIDCMachineCallbackResult]
if not use_callback:
cb = None
elif properties.request_token_callback:
cb = properties.request_token_callback
elif properties.custom_token_callback:
cb = properties.custom_token_callback

prev_token = self.access_token
if prev_token:
Expand All @@ -104,37 +180,32 @@ def get_current_token(self, use_callback: bool = True) -> Optional[str]:
if new_token != prev_token:
return new_token

# TODO: DRIVERS-2672 handle machine callback here.
if cb_type == "human":
context = {
"timeout_seconds": CALLBACK_TIMEOUT_SECONDS,
"version": CALLBACK_VERSION,
"refresh_token": self.refresh_token,
}
resp = cb(self.idp_info, context)

self.validate_request_token_response(resp)

if isinstance(cb, OIDCHumanCallback):
human_context = OIDCHumanCallbackContext(
timeout_seconds=HUMAN_CALLBACK_TIMEOUT_SECONDS,
version=HUMAN_CALLBACK_VERSION,
refresh_token=self.refresh_token,
)
assert self.idp_info is not None
resp = cb.fetch(self.idp_info, human_context)
if not isinstance(resp, OIDCHumanCallbackResult):
raise ValueError("Callback result must be of type OIDCHumanCallbackResult")
self.refresh_token = resp.refresh_token
else:
machine_context = OIDCMachineCallbackContext(
timeout_seconds=remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS,
version=MACHINE_CALLBACK_VERSION,
)
resp = cb.fetch(machine_context)
if not isinstance(resp, OIDCMachineCallbackResult):
raise ValueError(
"Callback result must be of type OIDCMachineCallbackResult"
)
self.access_token = resp.access_token
self.token_gen_id += 1

return self.access_token

def validate_request_token_response(self, resp: Mapping[str, Any]) -> None:
# Validate callback return value.
if not isinstance(resp, dict):
raise ValueError("OIDC callback returned invalid result")

if "access_token" not in resp:
raise ValueError("OIDC callback did not return an access_token")

expected = ["access_token", "refresh_token", "expires_in_seconds"]
for key in resp:
if key not in expected:
raise ValueError(f'Unexpected field in callback result "{key}"')

self.access_token = resp["access_token"]
self.refresh_token = resp.get("refresh_token")

def principal_step_cmd(self) -> SON[str, Any]:
"""Get a SASL start command with an optional principal name"""
# Send the SASL start with the optional principal name.
Expand All @@ -154,8 +225,7 @@ def principal_step_cmd(self) -> SON[str, Any]:
)

def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]:
# TODO: DRIVERS-2672, check for provider_name in self.properties here.
if self.idp_info is None:
if self.properties.request_token_callback is not None and self.idp_info is None:
return self.principal_step_cmd()

token = self.get_current_token(use_callback)
Expand Down Expand Up @@ -192,8 +262,10 @@ def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:

self.access_token = None

# TODO: DRIVERS-2672, check for provider_name in self.properties here.
# If so, we clear the access token and return finish_auth.
# If we are using machine callbacks, clear the access token and
# re-authenticate.
if self.properties.provider_name:
return self.authenticate(conn)

# Next see if the idp info has changed.
prev_idp_info = self.idp_info
Expand All @@ -203,7 +275,7 @@ def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
assert resp is not None
server_resp: dict = bson.decode(resp["payload"])
if "issuer" in server_resp:
self.idp_info = server_resp
self.idp_info = OIDCIdPInfo(**server_resp)

# Handle the case of changed idp info.
if self.idp_info != prev_idp_info:
Expand Down Expand Up @@ -240,7 +312,7 @@ def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:

server_resp: dict = bson.decode(resp["payload"])
if "issuer" in server_resp:
self.idp_info = server_resp
self.idp_info = OIDCIdPInfo(**server_resp)

return self.finish_auth(resp, conn)

Expand Down Expand Up @@ -273,4 +345,10 @@ def _authenticate_oidc(
if reauthenticate:
return authenticator.reauthenticate(conn)
else:
return authenticator.authenticate(conn)
try:
return authenticator.authenticate(conn)
except Exception as e:
# Try one more time an an authentication failure.
if isinstance(e, OperationFailure) and e.code == 18:
return authenticator.authenticate(conn)
raise
22 changes: 9 additions & 13 deletions pymongo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import datetime
import inspect
import warnings
from collections import OrderedDict, abc
from typing import (
Expand All @@ -41,6 +40,7 @@
from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry
from bson.raw_bson import RawBSONDocument
from pymongo.auth import MECHANISMS
from pymongo.auth_oidc import OIDCHumanCallback, OIDCMachineCallback
from pymongo.compression_support import (
validate_compressors,
validate_zlib_compression_level,
Expand Down Expand Up @@ -438,20 +438,16 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
props[key] = str(value).lower()
elif key in ["allowed_hosts"] and isinstance(value, list):
props[key] = value
elif inspect.isfunction(value):
signature = inspect.signature(value)
if key == "request_token_callback":
expected_params = 2
else:
raise ValueError(f"Unrecognized Auth mechanism function {key}")
if len(signature.parameters) != expected_params:
msg = f"{key} must accept {expected_params} parameters"
raise ValueError(msg)
elif key == "request_token_callback":
if not isinstance(value, OIDCHumanCallback):
raise ValueError("request_token_callback must be a OIDCHumanCallback object")
props[key] = value
elif key == "custom_token_callback":
if not isinstance(value, OIDCMachineCallback):
raise ValueError("custom_token_callback must be a OIDCMachineCallback object")
props[key] = value
else:
raise ValueError(
"Auth mechanism property values must be strings or callback functions"
)
raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}")
return props

value = validate_string(option, value)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ exclude_lines = [
"return NotImplemented",
"_use_c = true",
"if __name__ == '__main__':",
"if TYPE_CHECKING:"
]
partial_branches = ["if (.*and +)*not _use_c( and.*)*:"]

Expand Down
Loading