diff --git a/CHANGELOG.md b/CHANGELOG.md index 97f6b1c..72f0624 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,18 @@ Changelog ========= +v2.1.3 (2024-07-31) +------------------- +- Added support for a new browser authentication plugin called + BrowserIdcAuthPlugin to facilitate single-sign-on integration with AWS + IAM Identity Center. [Brooke White] +- Chore: publish inline type annotations (#224) [James Dow, James Dow] + + Allow inline type hints to be packaged and distributed + following PEP561 specification + https://peps.python.org/pep-0561/#specification + + v2.1.2 (2024-06-19) ------------------- - Temporarily reverted the following commit which caused connection diff --git a/README.rst b/README.rst index 46d062e..932440d 100644 --- a/README.rst +++ b/README.rst @@ -333,12 +333,18 @@ Connection Parameters +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ | iam_disable_cache | bool | This option specifies whether the IAM credentials are cached. By default the IAM credentials are cached. This improves performance when requests to the API gateway are throttled. | FALSE | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ +| idc_client_display_name | str | The client display name to be used in user consent in IdC browser auth. This is an optional value. The default value is "Amazon Redshift Python connector". | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ +| idc_region | str | The AWS region where AWS identity center instance is located. It is required for the IdC browser auth plugin. | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ | identity_namespace | str | The identity namespace to be used for the IdC browser auth plugin and IdP token auth plugin. It is an optional value if there is only one IdC instance existing or if default identity namespace is set on the cluster - else it is required. | None | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ | idp_response_timeout | int | The timeout for retrieving SAML assertion from IdP | 120 | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ | idp_tenant | str | The IdP tenant | None | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ +| issuer_url | str | The issuer url for the AWS IdC access portal. It is required for the IdC browser auth plugin. | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ | listen_port | int | The listen port IdP will send the SAML assertion to | 7890 | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------+----------+ | login_to_rp | str | Only for AdfsCredentialsProvider. Used to specify the loginToRp when performing IdpInitiatedSignOn as apart of form based authentication. | urn:amazon:webservices | No | diff --git a/redshift_connector/__init__.py b/redshift_connector/__init__.py index c2b6a76..49320b0 100644 --- a/redshift_connector/__init__.py +++ b/redshift_connector/__init__.py @@ -57,6 +57,8 @@ _logger: logging.Logger = logging.getLogger(__name__) IDC_PLUGINS_LIST = ( + "redshift_connector.plugin.BrowserIdcAuthPlugin", + "BrowserIdcAuthPlugin", "redshift_connector.plugin.IdpTokenAuthPlugin", "IdpTokenAuthPlugin", ) @@ -65,6 +67,8 @@ "BrowserAzureOAuth2CredentialsProvider", "redshift_connector.plugin.BasicJwtCredentialsProvider", "BasicJwtCredentialsProvider", + "redshift_connector.plugin.BrowserIdcAuthPlugin", + "BrowserIdcAuthPlugin", "redshift_connector.plugin.IdpTokenAuthPlugin", "IdpTokenAuthPlugin", ) @@ -158,6 +162,9 @@ def connect( serverless_work_group: typing.Optional[str] = None, group_federation: typing.Optional[bool] = None, identity_namespace: typing.Optional[str] = None, + idc_client_display_name: typing.Optional[str] = None, + idc_region: typing.Optional[str] = None, + issuer_url: typing.Optional[str] = None, token: typing.Optional[str] = None, token_type: typing.Optional[str] = None, ) -> Connection: @@ -265,6 +272,12 @@ def connect( Use the IDP Groups in the Redshift. Default value False. identity_namespace: Optional[str] The identity namespace to be used with IdC auth plugin. Default value is None. + idc_client_display_name: Optional[str] + The client display name to be used in user consent in IdC browser auth. Default value is `Amazon Redshift Python connector`. + idc_region: Optional[str] + The AWS region where IdC instance is located. Default value is None. + issuer_url: Optional[str] + The issuer url for the AWS IdC access portal. Default value is None. token: Optional[str] The access token to be used with IdC basic credentials provider plugin. Default value is None. token_type: Optional[str] @@ -296,10 +309,13 @@ def connect( info.put("host", host) info.put("iam", iam) info.put("iam_disable_cache", iam_disable_cache) + info.put("idc_client_display_name", idc_client_display_name) + info.put("idc_region", idc_region) info.put("identity_namespace", identity_namespace) info.put("idp_host", idp_host) info.put("idp_response_timeout", idp_response_timeout) info.put("idp_tenant", idp_tenant) + info.put("issuer_url", issuer_url) info.put("is_serverless", is_serverless) info.put("listen_port", listen_port) info.put("login_url", login_url) @@ -398,6 +414,7 @@ def connect( numeric_to_float=info.numeric_to_float, identity_namespace=info.identity_namespace, token_type=info.token_type, + idc_client_display_name=info.idc_client_display_name, ) diff --git a/redshift_connector/core.py b/redshift_connector/core.py index d5a3ed5..3c12530 100644 --- a/redshift_connector/core.py +++ b/redshift_connector/core.py @@ -433,6 +433,7 @@ def __init__( numeric_to_float: bool = False, identity_namespace: typing.Optional[str] = None, token_type: typing.Optional[str] = None, + idc_client_display_name: typing.Optional[str] = None, ): """ Creates a :class:`Connection` to an Amazon Redshift cluster. For more information on establishing a connection to an Amazon Redshift cluster using `federated API access `_ see our examples page. @@ -481,6 +482,8 @@ def __init__( The identity namespace to be used with IdC auth plugin. Default value is None. token_type: Optional[str] The token type to be used for authentication using IdP Token auth plugin + idc_client_display_name: Optional[str] + The client display name to be used for user consent in IdC browser auth plugin. """ self.merge_socket_read = True @@ -561,7 +564,10 @@ def get_calling_module() -> str: redshift_native_auth = True init_params["idp_type"] = "AzureAD" - if credentials_provider.split(".")[-1] in ("IdpTokenAuthPlugin",): + if credentials_provider.split(".")[-1] in ( + "IdpTokenAuthPlugin", + "BrowserIdcAuthPlugin", + ): redshift_native_auth = True self.set_idc_plugins_params( init_params, credentials_provider, identity_namespace, token_type @@ -2617,6 +2623,7 @@ def set_idc_plugins_params( credentials_provider: typing.Optional[str] = None, identity_namespace: typing.Optional[str] = None, token_type: typing.Optional[str] = None, + idc_client_display_name: typing.Optional[str] = None, ) -> None: plugin_name = typing.cast(str, credentials_provider).split(".")[-1] init_params["idp_type"] = "AwsIdc" @@ -2624,5 +2631,10 @@ def set_idc_plugins_params( if identity_namespace: init_params["identity_namespace"] = identity_namespace - if token_type: + if plugin_name == "BrowserIdcAuthPlugin": + init_params["token_type"] = "ACCESS_TOKEN" + elif token_type: init_params["token_type"] = token_type + + if idc_client_display_name: + init_params["idc_client_display_name"] = idc_client_display_name diff --git a/redshift_connector/plugin/__init__.py b/redshift_connector/plugin/__init__.py index 04541dd..74edb4e 100644 --- a/redshift_connector/plugin/__init__.py +++ b/redshift_connector/plugin/__init__.py @@ -4,6 +4,7 @@ from .browser_azure_oauth2_credentials_provider import ( BrowserAzureOAuth2CredentialsProvider, ) +from .browser_idc_auth_plugin import BrowserIdcAuthPlugin from .browser_saml_credentials_provider import BrowserSamlCredentialsProvider from .common_credentials_provider import CommonCredentialsProvider from .idp_credentials_provider import IdpCredentialsProvider diff --git a/redshift_connector/plugin/browser_idc_auth_plugin.py b/redshift_connector/plugin/browser_idc_auth_plugin.py new file mode 100644 index 0000000..a3a8335 --- /dev/null +++ b/redshift_connector/plugin/browser_idc_auth_plugin.py @@ -0,0 +1,414 @@ +import base64 +import concurrent.futures +import hashlib +import logging +import os +import socket +import threading +import time +import typing +import webbrowser +from enum import Enum +from urllib.parse import urlencode, urlunsplit + +import boto3 +from botocore.exceptions import ClientError + +from redshift_connector.error import InterfaceError +from redshift_connector.plugin.common_credentials_provider import ( + CommonCredentialsProvider, +) +from redshift_connector.redshift_property import RedshiftProperty + +logging.getLogger(__name__).addHandler(logging.NullHandler()) +_logger: logging.Logger = logging.getLogger(__name__) + + +class BrowserIdcAuthPlugin(CommonCredentialsProvider): + """ + Class to get IdC Token using SSO OIDC APIs + """ + + class OAuthParamNames(Enum): + """ + Defines OAuth parameter names used when requesting IdC token from the IdC + """ + + STATE_PARAMETER_NAME = "state" + AUTH_CODE_PARAMETER_NAME = "code" + REDIRECT_PARAMETER_NAME = "redirect_uri" + CLIENT_ID_PARAMETER_NAME = "client_id" + RESPONSE_TYPE_PARAMETER_NAME = "response_type" + GRANT_TYPE_PARAMETER_NAME = "grant_type" + SCOPE_PARAMETER_NAME = "scopes" + CODE_CHALLENGE_PARAMETER_NAME = "code_challenge" + CHALLENGE_METHOD_PARAMETER_NAME = "code_challenge_method" + + IDC_CLIENT_DISPLAY_NAME = "Amazon Redshift Python connector" + CLIENT_TYPE = "public" + CREATE_TOKEN_INTERVAL = 1 + CODE_VERIFIER_LENGTH = 60 + CURRENT_INTERACTION_SCHEMA = "https" + OIDC_SCHEMA = "oidc" + AMAZON_COM_SCHEMA = "amazonaws.com" + REDSHIFT_IDC_CONNECT_SCOPE = "redshift:connect" + AUTH_CODE_GRANT_TYPE = "authorization_code" + REDIRECT_URI = "http://127.0.0.1" + AUTHORIZE_ENDPOINT = "/authorize" + CHALLENGE_METHOD = "S256" + DEFAULT_RESPONSE_TIMEOUT = 120 + DEFAULT_LISTEN_PORT = 7890 + STATE_LENGTH = 10 + + def __init__(self: "BrowserIdcAuthPlugin") -> None: + super().__init__() + self.idp_response_timeout: int = self.DEFAULT_RESPONSE_TIMEOUT + self.idc_client_display_name: str = self.IDC_CLIENT_DISPLAY_NAME + self.listen_port: int = self.DEFAULT_LISTEN_PORT + self.register_client_cache: typing.Dict[str, dict] = {} + self.idc_region: typing.Optional[str] = None + self.issuer_url: typing.Optional[str] = None + self.redirect_uri: typing.Optional[str] = None + self.sso_oidc_client: boto3.client = None + self.auth_code: typing.Optional[str] = None + + def add_parameter( + self: "BrowserIdcAuthPlugin", + info: RedshiftProperty, + ) -> None: + """ + Adds parameters to the BrowserIdcAuthPlugin + :param info: RedshiftProperty object containing the parameters to be added to the BrowserIdcAuthPlugin. + :return: None. + """ + super().add_parameter(info) + self.issuer_url = info.issuer_url + _logger.debug("Setting issuer_url = {}".format(self.issuer_url)) + self.idc_region = info.idc_region + _logger.debug("Setting idc_region = {}".format(self.idc_region)) + if info.idp_response_timeout and info.idp_response_timeout > 10: + self.idp_response_timeout = info.idp_response_timeout + _logger.debug("Setting idp_response_timeout = {}".format(self.idp_response_timeout)) + self.listen_port = info.listen_port + _logger.debug("Setting listen_port = {}".format(self.listen_port)) + if info.idc_client_display_name: + self.idc_client_display_name = info.idc_client_display_name + _logger.debug("Setting idc_client_display_name = {}".format(self.idc_client_display_name)) + + def check_required_parameters(self: "BrowserIdcAuthPlugin") -> None: + """ + Checks if the required parameters are set. + :return: None. + :raises InterfaceError: Raised when the parameters are not valid. + """ + super().check_required_parameters() + if not self.issuer_url: + _logger.error("IdC authentication failed: issuer_url needs to be provided in connection params") + raise InterfaceError( + "IdC authentication failed: The issuer_url must be included in the connection parameters." + ) + if not self.idc_region: + _logger.error("IdC authentication failed: idc_region needs to be provided in connection params") + raise InterfaceError( + "IdC authentication failed: The idc_region must be included in the connection parameters." + ) + + def get_auth_token(self: "BrowserIdcAuthPlugin") -> str: + """ + Returns the auth token as per plugin specific implementation. + :return: str. + """ + return self.get_idc_token() + + def get_idc_token(self: "BrowserIdcAuthPlugin") -> str: + """ + Returns the IdC token using SSO OIDC APIs. + :return: str. + """ + _logger.debug("BrowserIdcAuthPlugin.get_idc_token") + try: + self.check_required_parameters() + + self.sso_oidc_client = boto3.client("sso-oidc", region_name=self.idc_region) + self.redirect_uri = self.REDIRECT_URI + ":" + str(self.listen_port) + + register_client_result: typing.Dict[str, typing.Any] = self.register_client() + code_verifier: str = self.generate_code_verifier() + code_challenge: str = self.generate_code_challenge(code_verifier) + auth_code: str = self.fetch_authorization_code(code_challenge, register_client_result) + access_token = self.fetch_access_token(register_client_result, auth_code, code_verifier) + + return access_token + + except InterfaceError as e: + raise + except Exception as e: + _logger.debug("An error occurred while trying to obtain an IdC token : {}".format(str(e))) + raise InterfaceError("There was an error during authentication.") + + def register_client(self: "BrowserIdcAuthPlugin") -> typing.Dict[str, typing.Any]: + """ + Registers the client with IdC. + :param client_type: str + The client type to be used for registering the client. + :return: dict + The register client result from IdC + """ + _logger.debug("BrowserIdcAuthPlugin.register_client") + register_client_cache_key: str = f"{self.idc_client_display_name}:{self.idc_region}:{self.listen_port}" + + if ( + register_client_cache_key in self.register_client_cache + and self.register_client_cache[register_client_cache_key]["clientSecretExpiresAt"] > time.time() + ): + _logger.debug( + "Valid registerClient result found from cache with expiration time: {}".format( + str(self.register_client_cache[register_client_cache_key]["clientSecretExpiresAt"]) + ) + ) + return self.register_client_cache[register_client_cache_key] + + try: + register_client_result: typing.Dict[str, typing.Any] = self.sso_oidc_client.register_client( + clientName=self.idc_client_display_name, + clientType=self.CLIENT_TYPE, + scopes=[self.REDSHIFT_IDC_CONNECT_SCOPE], + issuerUrl=self.issuer_url, + redirectUris=[self.redirect_uri], + grantTypes=[self.AUTH_CODE_GRANT_TYPE], + ) + self.register_client_cache[register_client_cache_key] = register_client_result + _logger.debug( + "Added entry to client cache with expiry: {}".format( + str(register_client_result["clientSecretExpiresAt"]) + ) + ) + return register_client_result + except ClientError as e: + raise InterfaceError("IdC authentication failed : Error registering client with IdC.") + + def generate_code_verifier(self: "BrowserIdcAuthPlugin") -> str: + """ + Generates a random code verifier of length 60. + :return: str + Returns the generated code verifier. + """ + _logger.debug("BrowserIdcAuthPlugin.generate_code_verifier") + random_bytes = os.urandom(self.CODE_VERIFIER_LENGTH) + base64_encoded = base64.urlsafe_b64encode(random_bytes).decode("utf-8") + base64_encoded_no_newline = base64_encoded.replace("\n", "") + code_verifier = base64_encoded_no_newline.replace("=", "") + return code_verifier + + def generate_code_challenge(self: "BrowserIdcAuthPlugin", code_verifier: str) -> str: + """ + Generates a random code verifier + :param code_verifier: str + The code_verifier is used to generate the code_challenge. + :return: dict + Returns the generated base64 encoded code challenge. + """ + _logger.debug("BrowserIdcAuthPlugin.generate_code_challenge") + sha256_hash = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(sha256_hash).rstrip(b"=").decode("ascii") + return code_challenge + + def fetch_authorization_code( + self: "BrowserIdcAuthPlugin", code_challenge: str, register_client_result: typing.Dict[str, typing.Any] + ) -> str: + """ + Fetches IdC authorization code using the default browser. + :param code_challenge: str + The generated code challenge. + :param register_client_result: dict + The register client result from IdC. + :return: str + The IdC authorization code obtained from the browser. + """ + state = self.generate_random_state() + listen_socket: socket.socket = self.get_listen_socket(self.listen_port) + + try: + listen_socket.settimeout(float(self.idp_response_timeout)) + server_thread = threading.Thread(target=self.run_server, args=(listen_socket, state)) + server_thread.start() + + self.open_browser(state, register_client_result["clientId"], code_challenge) + + server_thread.join() + + return str(self.auth_code) + except socket.timeout: + raise InterfaceError("IdC authentication failed : Timeout while retrieving authorization code.") + except Exception as e: + raise e + finally: + listen_socket.close() + + def fetch_access_token( + self: "BrowserIdcAuthPlugin", + register_client_result: typing.Dict[str, typing.Any], + auth_code: str, + code_verifier: str, + ) -> str: + """ + Fetches IdC access token using SSO OIDC APIs. + :param register_client_result: dict + The register client result from IdC. + :param auth_code: str + The authorization code result from IdC. + :param grant_type: str + The grant type to be used for fetch IdC access token. + :return: str + The IdC access token obtained from fetching IdC access token. + :raises InterfaceError: Raised when the IdC access token is not fetched successfully. + """ + _logger.debug("BrowserIdcAuthPlugin.fetch_access_token") + polling_end_time: float = time.time() + self.idp_response_timeout + polling_interval_in_sec: int = self.CREATE_TOKEN_INTERVAL + + while time.time() < polling_end_time: + try: + _logger.debug("Calling IdC method create_token") + response: typing.Dict[str, typing.Any] = self.sso_oidc_client.create_token( + clientId=register_client_result["clientId"], + clientSecret=register_client_result["clientSecret"], + code=auth_code, + grantType=self.AUTH_CODE_GRANT_TYPE, + codeVerifier=code_verifier, + redirectUri=self.redirect_uri, + ) + if not response["accessToken"]: + raise InterfaceError("IdC authentication failed : The credential token couldn't be created.") + _logger.debug("Length of IdC accessToken: {}".format(len(response["accessToken"]))) + return response["accessToken"] + except ClientError as e: + if e.response["Error"]["Code"] == "AuthorizationPendingException": + _logger.debug("Browser authorization pending from user") + time.sleep(polling_interval_in_sec) + else: + raise InterfaceError( + "IdC authentication failed : Unexpected error occured while fetching access token." + ) + + raise InterfaceError("IdC authentication failed : The request timed out. Authentication wasn't completed.") + + def generate_random_state(self: "BrowserIdcAuthPlugin") -> str: + random_bytes = os.urandom(self.STATE_LENGTH) + random_state = base64.urlsafe_b64encode(random_bytes).decode("utf-8").rstrip("=") + return random_state + + def run_server( + self: "BrowserIdcAuthPlugin", + listen_socket: socket.socket, + state: str, + ): + """ + Runs a server on localhost to listen for the IdC's response with authorization code. + :param listen_socket: socket.socket + The socket on which the method listens for a response + :param idp_response_timeout: int + The maximum time to listen on the socket, specified in seconds + :param state: str + The state generated by the client. This must match the state received from the IdC server + :return: None + """ + conn, addr = listen_socket.accept() + size: int = 102400 + with conn: + while True: + part: bytes = conn.recv(size) + decoded_part = part.decode() + state_idx: int = decoded_part.find( + "{}=".format(BrowserIdcAuthPlugin.OAuthParamNames.STATE_PARAMETER_NAME.value) + ) + + if state_idx > -1: + received_state: str = decoded_part[state_idx + 6 : decoded_part.find("&", state_idx)] + parsed_state: str = received_state[: received_state.find(" ")] + + if parsed_state != state: + exec_msg = "Incoming state {received} does not match the outgoing state {expected}".format( + received=parsed_state, expected=state + ) + _logger.debug(exec_msg) + raise InterfaceError(exec_msg) + + code_idx: int = decoded_part.find( + "{}=".format(BrowserIdcAuthPlugin.OAuthParamNames.AUTH_CODE_PARAMETER_NAME.value) + ) + + if code_idx < 0: + _logger.debug("No authorization code found") + raise InterfaceError("No authorization code found") + received_code: str = decoded_part[code_idx + 5 : state_idx - 1] + + if received_code == "": + _logger.debug("No valid authorization code found") + raise InterfaceError("No valid authorization code found") + conn.send(self.close_window_http_resp()) + self.auth_code = received_code + return + + def open_browser(self: "BrowserIdcAuthPlugin", state: str, client_id: str, code_challenge: str) -> None: + """ + Opens the default browser to allow user authentication with IdC + :param state: str + The state generated by the client + :return: None. + """ + url: str = self.get_authorization_token_url(state, client_id, code_challenge) + + if url is None: + BrowserIdcAuthPlugin.handle_missing_required_property("issuer_url") + self.validate_url(url) + + _logger.debug("Authorization code request URI: {}".format(url)) + + try: + webbrowser.open(url) + except: + _logger.debug("Unable to open the browser. Webbrowser environment is not supported") + + def get_authorization_token_url( + self: "BrowserIdcAuthPlugin", state: str, client_id: str, code_challenge: str + ) -> str: + """ + Returns a URL used for requesting authentication token from IdC + """ + _logger.debug("BrowserIdcAuthPlugin.get_authorization_token_url") + + params: typing.Dict[str, str] = { + BrowserIdcAuthPlugin.OAuthParamNames.RESPONSE_TYPE_PARAMETER_NAME.value: "code", + BrowserIdcAuthPlugin.OAuthParamNames.CLIENT_ID_PARAMETER_NAME.value: client_id, + BrowserIdcAuthPlugin.OAuthParamNames.REDIRECT_PARAMETER_NAME.value: str(self.redirect_uri), + BrowserIdcAuthPlugin.OAuthParamNames.STATE_PARAMETER_NAME.value: state, + BrowserIdcAuthPlugin.OAuthParamNames.SCOPE_PARAMETER_NAME.value: self.REDSHIFT_IDC_CONNECT_SCOPE, + BrowserIdcAuthPlugin.OAuthParamNames.CODE_CHALLENGE_PARAMETER_NAME.value: code_challenge, + BrowserIdcAuthPlugin.OAuthParamNames.CHALLENGE_METHOD_PARAMETER_NAME.value: self.CHALLENGE_METHOD, + } + + encoded_params: str = urlencode(params) + idc_host = self.OIDC_SCHEMA + "." + str(self.idc_region) + "." + self.AMAZON_COM_SCHEMA + + return urlunsplit( + ( + self.CURRENT_INTERACTION_SCHEMA, + idc_host, + self.AUTHORIZE_ENDPOINT, + encoded_params, + "", + ) + ) + + def get_listen_socket(self: "BrowserIdcAuthPlugin", listen_port: int) -> socket.socket: + """ + Returns a listen socket used for user authentication + """ + s: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + _logger.debug("Attempting socket bind on port {}".format(str(listen_port))) + s.bind(("127.0.0.1", listen_port)) + s.listen() + _logger.debug("Socket bound to port {}".format(s.getsockname()[1])) + return s diff --git a/redshift_connector/plugin/common_credentials_provider.py b/redshift_connector/plugin/common_credentials_provider.py index a3a6f3a..0a98a22 100644 --- a/redshift_connector/plugin/common_credentials_provider.py +++ b/redshift_connector/plugin/common_credentials_provider.py @@ -85,3 +85,6 @@ def set_group_federation(self: "CommonCredentialsProvider", group_federation: bo def get_sub_type(self: "CommonCredentialsProvider") -> int: return IamHelper.IDC_PLUGIN + + def get_cache_key(self: "CommonCredentialsProvider") -> str: + return "" diff --git a/redshift_connector/redshift_property.py b/redshift_connector/redshift_property.py index 0b35678..271e784 100644 --- a/redshift_connector/redshift_property.py +++ b/redshift_connector/redshift_property.py @@ -60,6 +60,8 @@ def __init__(self: "RedshiftProperty", **kwargs): self.host: str = "" self.iam: bool = False self.iam_disable_cache: bool = False + self.idc_client_display_name: typing.Optional[str] = None + self.idc_region: typing.Optional[str] = None self.identity_namespace: typing.Optional[str] = None # The IdP (identity provider) host you are using to authenticate into Redshift. self.idp_host: typing.Optional[str] = None @@ -69,6 +71,7 @@ def __init__(self: "RedshiftProperty", **kwargs): self.idp_tenant: typing.Optional[str] = None # The port used by an IdP (identity provider). self.idpPort: int = 443 + self.issuer_url: typing.Optional[str] = None self.listen_port: int = 7890 # property for specifying loginToRp used by AdfsCredentialsProvider self.login_to_rp: str = "urn:amazon:webservices" diff --git a/redshift_connector/utils/logging_utils.py b/redshift_connector/utils/logging_utils.py index e1f1e72..195f2a9 100644 --- a/redshift_connector/utils/logging_utils.py +++ b/redshift_connector/utils/logging_utils.py @@ -37,11 +37,14 @@ def mask_secure_info_in_props(info: "RedshiftProperty") -> "RedshiftProperty": "host", "iam", "iam_disable_cache", + "idc_client_display_name", + "idc_region", "identity_namespace", "idp_host", "idpPort", "idp_response_timeout", "idp_tenant", + "issuer_url", "is_serverless", "listen_port", "login_url", diff --git a/redshift_connector/version.py b/redshift_connector/version.py index a940811..c0c559a 100644 --- a/redshift_connector/version.py +++ b/redshift_connector/version.py @@ -2,4 +2,4 @@ # 1) we don't load dependencies by storing it in __init__.py # 2) we can import it in setup.py for the same reason # 3) we can import it into your module module -__version__ = "2.1.2" +__version__ = "2.1.3" diff --git a/setup.py b/setup.py index bdf4b8a..86b99fa 100644 --- a/setup.py +++ b/setup.py @@ -128,9 +128,7 @@ def get_tag(self): ], keywords="redshift dbapi", include_package_data=True, - package_data={ - "redshift-connector": ["*.py", "*.crt", "LICENSE", "NOTICE", "py.typed"] - }, + package_data={"redshift-connector": ["*.py", "*.crt", "LICENSE", "NOTICE", "py.typed"]}, packages=find_packages(exclude=["test*"]), cmdclass=custom_cmds, ) diff --git a/test/__init__.py b/test/__init__.py index cc8b14c..f565369 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -10,5 +10,6 @@ okta_browser_idp, okta_idp, ping_browser_idp, + redshift_browser_idc, redshift_idp_token_auth_plugin, ) diff --git a/test/conftest.py b/test/conftest.py index a323516..f58dc7a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -311,6 +311,7 @@ def redshift_native_browser_azure_oauth2_idp() -> typing.Dict[str, typing.Union[ } return db_connect + @pytest.fixture(scope="class") def redshift_idp_token_auth_plugin() -> typing.Dict[str, typing.Optional[str]]: db_connect = { @@ -327,6 +328,26 @@ def redshift_idp_token_auth_plugin() -> typing.Dict[str, typing.Optional[str]]: return db_connect +@pytest.fixture(scope="class") +def redshift_browser_idc() -> typing.Dict[str, typing.Union[str, typing.Optional[bool], int]]: + db_connect = { + "host": conf.get("redshift-browser-idc", "host", fallback=None), + "region": conf.get("redshift-browser-idc", "region", fallback=None), + "database": conf.get("redshift-browser-idc", "database", fallback="dev"), + "credentials_provider": conf.get( + "redshift-browser-idc", "credentials_provider", fallback="BrowserIdcAuthPlugin" + ), + "issuer_url": conf.get("redshift-browser-idc", "issuer_url", fallback=None), + "idc_region": conf.get("redshift-browser-idc", "idc_region", fallback=None), + "idp_response_timeout": conf.getint("redshift-browser-idc", "idp_response_timeout", fallback=120), + "listen_port": conf.get("redshift-browser-idc", "listen_port", fallback=7890), + "idc_client_display_name": conf.get( + "redshift-browser-idc", "idc_client_display_name", fallback="Amazon Redshift Python connector" + ), + } + return db_connect + + @pytest.fixture def con(request, db_kwargs) -> redshift_connector.Connection: conn: redshift_connector.Connection = redshift_connector.connect(**db_kwargs) diff --git a/test/integration/__init__.py b/test/integration/__init__.py index ed87c66..3d6dba5 100644 --- a/test/integration/__init__.py +++ b/test/integration/__init__.py @@ -10,6 +10,7 @@ okta_browser_idp, okta_idp, ping_browser_idp, + redshift_browser_idc, redshift_idp_token_auth_plugin, redshift_native_browser_azure_oauth2_idp, ) diff --git a/test/unit/plugin/test_browser_idc_auth_plugin.py b/test/unit/plugin/test_browser_idc_auth_plugin.py new file mode 100644 index 0000000..71e3f7c --- /dev/null +++ b/test/unit/plugin/test_browser_idc_auth_plugin.py @@ -0,0 +1,294 @@ +import socket +import time +import typing +from unittest.mock import MagicMock + +import pytest +from botocore.exceptions import ClientError +from pytest_mock import mocker + +from redshift_connector.error import InterfaceError +from redshift_connector.plugin.browser_idc_auth_plugin import BrowserIdcAuthPlugin +from redshift_connector.redshift_property import RedshiftProperty + + +def make_valid_browser_idc_provider() -> typing.Tuple[BrowserIdcAuthPlugin, RedshiftProperty]: + rp: RedshiftProperty = RedshiftProperty() + rp.idc_region = "some_region" + rp.issuer_url = "some_url" + rp.idp_response_timeout = 100 + rp.listen_port = 8000 + cp: BrowserIdcAuthPlugin = BrowserIdcAuthPlugin() + cp.add_parameter(rp) + return cp, rp + + +def valid_browser_without_optional_parameter() -> typing.Tuple[BrowserIdcAuthPlugin, RedshiftProperty]: + rp: RedshiftProperty = RedshiftProperty() + rp.idc_region = "some_region" + rp.issuer_url = "some_url" + cp: BrowserIdcAuthPlugin = BrowserIdcAuthPlugin() + cp.add_parameter(rp) + return cp, rp + + +def test_add_parameter_sets_browser_idc_specific(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + assert idc_credentials_provider.idc_region == rp.idc_region + assert idc_credentials_provider.issuer_url == rp.issuer_url + assert idc_credentials_provider.idp_response_timeout == rp.idp_response_timeout + assert idc_credentials_provider.listen_port == rp.listen_port + + +def test_add_parameter_sets_default(): + idc_credentials_provider, rp = valid_browser_without_optional_parameter() + assert idc_credentials_provider.idp_response_timeout == 120 + assert idc_credentials_provider.listen_port == 7890 + assert idc_credentials_provider.idc_client_display_name == "Amazon Redshift Python connector" + + +@pytest.mark.parametrize("value", [None, ""]) +def test_check_required_parameters_raises_if_issuer_url_missing(value): + idc_credentials_provider, _ = make_valid_browser_idc_provider() + idc_credentials_provider.issuer_url = value + + with pytest.raises( + InterfaceError, match="IdC authentication failed: The issuer_url must be included in the connection parameters." + ): + idc_credentials_provider.get_auth_token() + + +@pytest.mark.parametrize("value", [None, ""]) +def test_check_required_parameters_raises_if_idc_region_missing(value): + idc_credentials_provider, _ = make_valid_browser_idc_provider() + idc_credentials_provider.idc_region = value + + with pytest.raises( + InterfaceError, match="IdC authentication failed: The idc_region must be included in the connection parameters." + ): + idc_credentials_provider.get_auth_token() + + +def test_valid_register_client(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + "clientSecretExpiresAt": time.time() + 60, + } + + mocked_boto_client = MagicMock() + mocked_boto_client.register_client.return_value = mocked_register_client_result + + idc_credentials_provider.sso_oidc_client = mocked_boto_client + + register_client_result = idc_credentials_provider.register_client() + + assert register_client_result == mocked_register_client_result + + +def test_register_client_interface_exception(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + + mocked_boto_client = MagicMock() + error_response = { + "Error": {"Code": "400", "Message": "IdC authentication failed : Error registering client with IdC."} + } + + operation_name = "RegisterClient" + mocked_boto_client.register_client.side_effect = ClientError(error_response, operation_name) + idc_credentials_provider.sso_oidc_client = mocked_boto_client + + with pytest.raises(InterfaceError, match="IdC authentication failed : Error registering client with IdC."): + idc_credentials_provider.register_client() + + +def test_register_client_cache(): + idc_cred, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + "clientSecretExpiresAt": time.time() + 60, + } + mocked_cache_key: str = f"{idc_cred.idc_client_display_name}:{idc_cred.idc_region}:{idc_cred.listen_port}" + mocked_cache: typing.Dict[str, dict] = { + mocked_cache_key: mocked_register_client_result, + } + + idc_cred.register_client_cache = mocked_cache + + register_client_result = idc_cred.register_client() + + assert register_client_result == mocked_register_client_result + + +def test_register_client_cache_expired(): + idc_cred, rp = make_valid_browser_idc_provider() + mocked_client_expired_result: typing.Dict[str, typing.Any] = { + "clientId": "expiredClientId", + "clientSecret": "expiredClientSecret", + "clientSecretExpiresAt": time.time(), + } + mocked_cache_key: str = f"{idc_cred.idc_client_display_name}:{idc_cred.idc_region}:{idc_cred.listen_port}" + mocked_cache: typing.Dict[str, dict] = { + mocked_cache_key: mocked_client_expired_result, + } + idc_cred.register_client_cache = mocked_cache + + mocked_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + "clientSecretExpiresAt": time.time() + 60, + } + + mocked_boto_client = MagicMock() + mocked_boto_client.register_client.return_value = mocked_client_result + idc_cred.sso_oidc_client = mocked_boto_client + register_client_result = idc_cred.register_client() + + assert register_client_result["clientId"] == "mockedClientId" + + +def test_register_client_exception_handling(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + + mocker.patch.object(idc_credentials_provider, "register_client", side_effect=Exception("Some error")) + + with pytest.raises(InterfaceError): + idc_credentials_provider.get_auth_token() + + +def test_fetch_authorization_code_exception_handling(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + } + + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object(idc_credentials_provider, "fetch_authorization_code", side_effect=Exception("Some error")) + + with pytest.raises(InterfaceError): + idc_credentials_provider.get_auth_token() + + +def test_fetch_authorization_code_exception(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + } + + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object(idc_credentials_provider, "fetch_authorization_code", side_effect=Exception("Some error")) + + with pytest.raises(Exception): + idc_credentials_provider.fetch_authorization_code() + + +def test_valid_fetch_access_token(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + idc_credentials_provider.redirect_uri = "http://127.0.0.1:8000" + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + } + mocked_access_token_result: typing.Dict[str, typing.Any] = { + "accessToken": "validAccessToken", + } + mocked_auth_code: str = "validAuthCode" + mocked_verifier: str = "validVerifier" + expected_access_token = "validAccessToken" + + mocked_boto_client = MagicMock() + mocked_boto_client.create_token.return_value = mocked_access_token_result + + idc_credentials_provider.sso_oidc_client = mocked_boto_client + + accessToken = idc_credentials_provider.fetch_access_token( + mocked_register_client_result, mocked_auth_code, mocked_verifier + ) + + assert accessToken == expected_access_token + + +def test_fetch_access_token_exception_handling(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + } + mocked_fetch_authorization_code_result: str = {"mockedAuthCode"} + + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object( + idc_credentials_provider, "fetch_authorization_code", return_value=mocked_fetch_authorization_code_result + ) + mocker.patch.object(idc_credentials_provider, "fetch_access_token", side_effect=Exception("Unexpected error")) + + with pytest.raises(InterfaceError): + idc_credentials_provider.get_auth_token() + + +def test_get_auth_token_fetches_idc_token(mocker): + # Mock the dependencies and their return values + idc_credentials_provider, rp = make_valid_browser_idc_provider() + + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret", + } + mocked_fetch_authorization_code_result: str = {"mockedAuthCode"} + expected_idc_token: str = "mockedAccessToken" + + mocker.patch("boto3.client") # Mocking boto3.client + + # Mocking the response of internal methods + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object( + idc_credentials_provider, "fetch_authorization_code", return_value=mocked_fetch_authorization_code_result + ) + mocker.patch.object(idc_credentials_provider, "fetch_access_token", return_value=expected_idc_token) + + # Call the method under test + test_result_idc_token: str = idc_credentials_provider.get_auth_token() + + assert test_result_idc_token == expected_idc_token + + +def test_authorization_token_url(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_state: str = "mockedState" + mocked_client_id: str = "mockedClientId" + mocked_code_challenge: str = "mockedCodeChallenge" + expected_url = "https://oidc.some_region.amazonaws.com/authorize?response_type=code&client_id=mockedClientId&redirect_uri=None&state=mockedState&scopes=redshift%3Aconnect&code_challenge=mockedCodeChallenge&code_challenge_method=S256" + + url: str = idc_credentials_provider.get_authorization_token_url( + mocked_state, mocked_client_id, mocked_code_challenge + ) + assert url == expected_url + + +def test_generate_random_state(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + state: str = idc_credentials_provider.generate_random_state() + + assert len(state) == 14 + + +def test_get_listen_socket(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_port: str = 8000 + expected_socket = "('127.0.0.1', 8000)" + + listen_socket: socket.socket = idc_credentials_provider.get_listen_socket(mocked_port) + assert str(listen_socket.getsockname()) == expected_socket + + +def test_open_browser(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_port: str = 8000 + expected_socket = "('127.0.0.1', 8000)" + + listen_socket: socket.socket = idc_credentials_provider.get_listen_socket(mocked_port) + assert str(listen_socket.getsockname()) == expected_socket diff --git a/test/unit/plugin/test_credentials_providers.py b/test/unit/plugin/test_credentials_providers.py index b36ce81..743e0b4 100644 --- a/test/unit/plugin/test_credentials_providers.py +++ b/test/unit/plugin/test_credentials_providers.py @@ -12,6 +12,7 @@ okta_browser_idp, okta_idp, ping_browser_idp, + redshift_browser_idc, redshift_idp_token_auth_plugin, )