diff --git a/.mypy.ini b/.mypy.ini index 7a891e5a..5ba44640 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -12,6 +12,9 @@ ignore_missing_imports = True [mypy-pg8000] ignore_missing_imports = True +[mypy-asyncpg] +ignore_missing_imports = True + [mypy-pytds] ignore_missing_imports = True diff --git a/README.md b/README.md index b7a6829f..114a09a1 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ The Cloud SQL Python Connector is a package to be used alongside a database driv Currently supported drivers are: - [`pymysql`](https://github.com/PyMySQL/PyMySQL) (MySQL) - [`pg8000`](https://github.com/tlocke/pg8000) (PostgreSQL) + - [`asyncpg`](https://github.com/MagicStack/asyncpg) (PostgreSQL) - [`pytds`](https://github.com/denisenkom/pytds) (SQL Server) @@ -37,9 +38,16 @@ based on your database dialect. pip install "cloud-sql-python-connector[pymysql]" ``` ### Postgres +There are two different database drivers that are supported for the Postgres dialect: + +#### pg8000 ``` pip install "cloud-sql-python-connector[pg8000]" ``` +#### asyncpg +``` +pip install "cloud-sql-python-connector[asyncpg]" +``` ### SQL Server ``` pip install "cloud-sql-python-connector[pytds]" @@ -111,9 +119,9 @@ def getconn() -> pymysql.connections.Connection: conn: pymysql.connections.Connection = connector.connect( "project:region:instance", "pymysql", - user="root", - password="shhh", - db="your-db-name" + user="my-user", + password="my-password", + db="my-db-name" ) return conn @@ -188,9 +196,9 @@ def getconn() -> pymysql.connections.Connection: conn = connector.connect( "project:region:instance", "pymysql", - user="root", - password="shhh", - db="your-db-name" + user="my-user", + password="my-password", + db="my-db-name" ) return conn @@ -245,7 +253,7 @@ connector.connect( "project:region:instance", "pg8000", user="postgres-iam-user@gmail.com", - db="my_database", + db="my-db-name", enable_iam_auth=True, ) ``` @@ -258,7 +266,7 @@ Once you have followed the steps linked above, you can run the following code to connector.connect( "project:region:instance", "pytds", - db="my_database", + db="my-db-name", active_directory_auth=True, server_name="public.[instance].[location].[project].cloudsql.[domain]", ) @@ -268,13 +276,111 @@ Or, if using Private IP: connector.connect( "project:region:instance", "pytds", - db="my_database", + db="my-db-name", active_directory_auth=True, server_name="private.[instance].[location].[project].cloudsql.[domain]", ip_type=IPTypes.PRIVATE ) ``` +### Async Driver Usage +The Cloud SQL Connector is compatible with +[asyncio](https://docs.python.org/3/library/asyncio.html) to improve the speed +and efficiency of database connections through concurrency. You can use all +non-asyncio drivers through the `Connector.connect_async` function, in addition +to the following asyncio database drivers: +- [asyncpg](https://magicstack.github.io/asyncpg) (Postgres) + +The Cloud SQL Connector has a helper `create_async_connector` function that is +recommended for asyncio database connections. It returns a `Connector` +object that uses the current thread's running event loop. This is different +than `Connector()` which by default initializes a new event loop in a +background thread. + +The `create_async_connector` allows all the same input arguments as the +[Connector](#configuring-the-connector) object. + +Once a `Connector` object is returned by `create_async_connector` you can call +its `connect_async` method, just as you would the `connect` method: + +```python +import asyncpg +from google.cloud.sql.connector import create_async_connector + +async def main(): + # intialize Connector object using 'create_async_connector' + connector = await create_async_connector() + + # create connection to Cloud SQL database + conn: asyncpg.Connection = await connector.connect_async( + "project:region:instance", # Cloud SQL instance connection name + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ) + + # insert into Cloud SQL database (example) + await conn.execute("INSERT INTO ratings (title, genre, rating) VALUES ('Batman', 'Action', 8.2)") + + # query Cloud SQL database (example) + results = await conn.fetch("SELECT * from ratings") + for row in results: + # ... do something with results + + # close asyncpg connection + await conn.close + + # close Cloud SQL Connector + await connector.close_async() +``` + +For more details on interacting with an `asyncpg.Connection`, please visit +the [official documentation](https://magicstack.github.io/asyncpg/current/api/index.html). + +### Async Context Manager + +An alternative to using the `create_async_connector` function is initializing +a `Connector` as an async context manager, removing the need for explicit +calls to `connector.close_async()` to cleanup resources. + +**Note:** This alternative requires that the running event loop be +passed in as the `loop` argument to `Connector()`. + +```python +import asyncio +import asyncpg +from google.cloud.sql.connector import Connector + +async def main(): + # get current running event loop to be used with Connector + loop = asyncio.get_running_loop() + # intialize Connector object as async context manager + async with Connector(loop=loop) as connector: + + # create connection to Cloud SQL database + conn: asyncpg.Connection = await connector.connect_async( + "project:region:instance", # Cloud SQL instance connection name + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ) + + # insert into Cloud SQL database (example) + await conn.execute("INSERT INTO ratings (title, genre, rating) VALUES ('Batman', 'Action', 8.2)") + + # query Cloud SQL database (example) + results = await conn.fetch("SELECT * from ratings") + for row in results: + # ... do something with results + + # close asyncpg connection + await conn.close +``` + ## Support policy ### Major version lifecycle diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 0e932169..527e177c 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -from .connector import Connector +from .connector import Connector, create_async_connector from .instance import IPTypes -__ALL__ = [Connector, IPTypes] +__ALL__ = [create_async_connector, Connector, IPTypes] try: import pkg_resources diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py new file mode 100644 index 00000000..0e03c0a6 --- /dev/null +++ b/google/cloud/sql/connector/asyncpg.py @@ -0,0 +1,64 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import ssl +from typing import Any, TYPE_CHECKING + +SERVER_PROXY_PORT = 3307 + +if TYPE_CHECKING: + import asyncpg + + +async def connect( + ip_address: str, ctx: ssl.SSLContext, **kwargs: Any +) -> "asyncpg.Connection": + """Helper function to create an asyncpg DB-API connection object. + + :type ip_address: str + :param ip_address: A string containing an IP address for the Cloud SQL + instance. + + :type ctx: ssl.SSLContext + :param ctx: An SSLContext object created from the Cloud SQL server CA + cert and ephemeral cert. + + :type kwargs: Any + :param kwargs: Keyword arguments for establishing asyncpg connection + object to Cloud SQL instance. + + :rtype: asyncpg.Connection + :returns: An asyncpg.Connection object to a Cloud SQL instance. + """ + try: + import asyncpg + except ImportError: + raise ImportError( + 'Unable to import module "asyncpg." Please install and try again.' + ) + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + + return await asyncpg.connect( + user=user, + database=db, + password=passwd, + host=ip_address, + port=SERVER_PROXY_PORT, + ssl=ctx, + direct_tls=True, + **kwargs, + ) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index b78971e3..e6561713 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -14,7 +14,6 @@ limitations under the License. """ import asyncio -import concurrent import logging from types import TracebackType from google.cloud.sql.connector.instance import ( @@ -24,6 +23,7 @@ import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pytds as pytds +import google.cloud.sql.connector.asyncpg as asyncpg from google.cloud.sql.connector.utils import generate_keys from google.auth.credentials import Credentials from threading import Thread @@ -32,6 +32,18 @@ logger = logging.getLogger(name=__name__) +ASYNC_DRIVERS = ["asyncpg"] + + +class ConnectorLoopError(Exception): + """ + Raised when Connector.connect is called with Connector._loop + in an invalid state (event loop in current thread). + """ + + def __init__(self, *args: Any) -> None: + super(ConnectorLoopError, self).__init__(self, *args) + class Connector: """A class to configure and create connections to Cloud SQL instances. @@ -53,6 +65,11 @@ class Connector: :param credentials Credentials object used to authenticate connections to Cloud SQL server. If not specified, Application Default Credentials are used. + + :type loop: asyncio.AbstractEventLoop + :param loop + Event loop to run asyncio tasks, if not specified, defaults to + creating new event loop on background thread. """ def __init__( @@ -61,13 +78,22 @@ def __init__( enable_iam_auth: bool = False, timeout: int = 30, credentials: Optional[Credentials] = None, + loop: asyncio.AbstractEventLoop = None, ) -> None: - self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) - self._thread.start() - self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - generate_keys(), self._loop - ) + # if event loop is given, use for background tasks + if loop: + self._loop: asyncio.AbstractEventLoop = loop + self._thread: Optional[Thread] = None + self._keys: asyncio.Future = loop.create_task(generate_keys()) + # if no event loop is given, spin up new loop in background thread + else: + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=self._loop.run_forever, daemon=True) + self._thread.start() + self._keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), + loop=self._loop, + ) self._instances: Dict[str, Instance] = {} # set default params for connections @@ -102,6 +128,18 @@ def connect( :returns: A DB-API connection to the specified Cloud SQL instance. """ + try: + # check if event loop is running in current thread + if self._loop == asyncio.get_running_loop(): + raise ConnectorLoopError( + "Connector event loop is running in current thread!" + "Event loop must be attached to a different thread to prevent blocking code!" + ) + # asyncio.get_running_loop will throw RunTimeError if no running loop is present + except RuntimeError: + pass + + # if event loop is not in current thread, proceed with connection connect_task = asyncio.run_coroutine_threadsafe( self.connect_async(instance_connection_string, driver, **kwargs), self._loop ) @@ -123,7 +161,7 @@ async def connect_async( :type driver: str :param: driver: A string representing the driver to connect with. Supported drivers are - pymysql, pg8000, and pytds. + pymysql, pg8000, asyncpg, and pytds. :param kwargs: Pass in any driver-specific arguments needed to connect to the Cloud @@ -133,7 +171,6 @@ async def connect_async( :returns: A DB-API connection to the specified Cloud SQL instance. """ - # Create an Instance object from the connection string. # The Instance should verify arguments. # @@ -164,6 +201,7 @@ async def connect_async( connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, + "asyncpg": asyncpg.connect, "pytds": pytds.connect, } @@ -194,6 +232,10 @@ async def connect_async( # helper function to wrap in timeout async def get_connection() -> Any: instance_data, ip_address = await instance.connect_info(ip_type) + # async drivers are unblocking and can be awaited directly + if driver in ASYNC_DRIVERS: + return await connector(ip_address, instance_data.context, **kwargs) + # synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, instance_data.context, **kwargs ) @@ -222,15 +264,70 @@ def __exit__( """Exit context manager by closing Connector""" self.close() + async def __aenter__(self) -> Any: + """Enter async context manager by returning Connector object""" + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit async context manager by closing Connector""" + await self.close_async() + def close(self) -> None: """Close Connector by stopping tasks and releasing resources.""" - close_future = asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop) + close_future = asyncio.run_coroutine_threadsafe( + self.close_async(), loop=self._loop + ) # Will attempt to safely shut down tasks for 5s close_future.result(timeout=5) - async def _close(self) -> None: + async def close_async(self) -> None: """Helper function to cancel Instances' tasks and close aiohttp.ClientSession.""" await asyncio.gather( *[instance.close() for instance in self._instances.values()] ) + + +async def create_async_connector( + ip_type: IPTypes = IPTypes.PUBLIC, + enable_iam_auth: bool = False, + timeout: int = 30, + credentials: Optional[Credentials] = None, + loop: asyncio.AbstractEventLoop = None, +) -> Connector: + """ + Create Connector object for asyncio connections that can auto-detect + and use current thread's running event loop. + + :type ip_type: IPTypes + :param ip_type + The IP type (public or private) used to connect. IP types + can be either IPTypes.PUBLIC or IPTypes.PRIVATE. + + :type enable_iam_auth: bool + :param enable_iam_auth + Enables IAM based authentication (Postgres only). + + :type timeout: int + :param timeout + The time limit for a connection before raising a TimeoutError. + + :type credentials: google.auth.credentials.Credentials + :param credentials + Credentials object used to authenticate connections to Cloud SQL server. + If not specified, Application Default Credentials are used. + + :type loop: asyncio.AbstractEventLoop + :param loop + Event loop to run asyncio tasks, if not specified, defaults + to current thread's running event loop. + """ + # if no loop given, automatically detect running event loop + if loop is None: + loop = asyncio.get_running_loop() + return Connector(ip_type, enable_iam_auth, timeout, credentials, loop) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 718ba72b..a1282e14 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -28,7 +28,6 @@ # Importing libraries import asyncio import aiohttp -import concurrent import datetime from enum import Enum import google.auth @@ -39,7 +38,6 @@ from tempfile import TemporaryDirectory from typing import ( Any, - Awaitable, Dict, Optional, Tuple, @@ -210,7 +208,7 @@ def _client_session(self) -> aiohttp.ClientSession: return self.__client_session _credentials: Optional[Credentials] = None - _keys: Awaitable + _keys: asyncio.Future _instance_connection_string: str _user_agent_string: str @@ -227,7 +225,7 @@ def __init__( self, instance_connection_string: str, driver_name: str, - keys: concurrent.futures.Future, + keys: asyncio.Future, loop: asyncio.AbstractEventLoop, credentials: Optional[Credentials] = None, enable_iam_auth: bool = False, @@ -251,7 +249,7 @@ def __init__( self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}" self._loop = loop - self._keys = asyncio.wrap_future(keys, loop=self._loop) + self._keys = keys # validate credentials type if not isinstance(credentials, Credentials) and credentials is not None: raise CredentialsTypeError( diff --git a/requirements-test.txt b/requirements-test.txt index 33e42470..5794edb8 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -15,5 +15,6 @@ types-mock==4.0.15 twine==4.0.1 PyMySQL==1.0.2 pg8000==1.29.1 +asyncpg==0.26.0 python-tds==1.11.0 aioresponses==0.7.3 diff --git a/setup.py b/setup.py index 90b631dc..b803e8ef 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,8 @@ extras_require={ "pymysql": ["PyMySQL==1.0.2"], "pg8000": ["pg8000==1.29.1"], - "pytds": ["python-tds==1.11.0"] + "pytds": ["python-tds==1.11.0"], + "asyncpg": ["asyncpg==0.26.0"] }, python_requires=">=3.7", include_package_data=True, diff --git a/tests/conftest.py b/tests/conftest.py index 45b6c8d2..11813b88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ import pytest # noqa F401 Needed to run the tests from threading import Thread -from typing import Any, Generator, AsyncGenerator +from typing import Any, Generator, AsyncGenerator, Tuple from google.auth.credentials import Credentials, with_scopes_if_required from google.oauth2 import service_account from aioresponses import aioresponses @@ -154,9 +154,9 @@ async def instance( Instance with mocked API calls. """ # generate client key pair - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) - key_task = asyncio.wrap_future(keys, loop=event_loop) - _, client_key = await key_task + keys = event_loop.create_task(generate_keys()) + _, client_key = await keys + with patch("google.auth.default") as mock_auth: mock_auth.return_value = fake_credentials, None # mock Cloud SQL Admin API calls @@ -192,7 +192,21 @@ async def connector(fake_credentials: Credentials) -> AsyncGenerator[Connector, mock_auth.return_value = fake_credentials, None # mock Cloud SQL Admin API calls mock_instance = FakeCSQLInstance(project, region, instance_name) - _, client_key = connector._keys.result() + + async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]: + """ + Helper method to await keys of Connector in tests prior to + initializing an Instance object. + """ + return await future + + # converting asyncio.Future into concurrent.Future + # await keys in background thread so that .result() is set + # required because keys are needed for mocks, but are not awaited + # in the code until Instance() is initialized + _, client_key = asyncio.run_coroutine_threadsafe( + wait_for_keys(connector._keys), connector._loop + ).result() with aioresponses() as mocked: mocked.get( f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings", diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py new file mode 100644 index 00000000..35a1de13 --- /dev/null +++ b/tests/system/test_asyncpg_connection.py @@ -0,0 +1,64 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import uuid +from typing import AsyncGenerator + +import asyncpg +import pytest +from google.cloud.sql.connector import create_async_connector + +table_name = f"books_{uuid.uuid4().hex}" + + +@pytest.fixture(name="conn") +async def setup() -> AsyncGenerator: + # initialize Cloud SQL Python Connector object + connector = await create_async_connector() + conn: asyncpg.Connection = await connector.connect_async( + os.environ["POSTGRES_CONNECTION_NAME"], + "asyncpg", + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASS"], + db=os.environ["POSTGRES_DB"], + ) + await conn.execute( + f"CREATE TABLE IF NOT EXISTS {table_name}" + " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" + ) + + yield conn + + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + # close asyncpg connection + await conn.close() + # cleanup Connector object + await connector.close_async() + + +@pytest.mark.asyncio +async def test_connection_with_asyncpg(conn: asyncpg.Connection) -> None: + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" + ) + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')" + ) + + rows = await conn.fetch(f"SELECT title FROM {table_name} ORDER BY ID") + titles = [row[0] for row in rows] + + assert titles == ["Book One", "Book Two"] diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py new file mode 100644 index 00000000..1229f9f1 --- /dev/null +++ b/tests/system/test_asyncpg_iam_auth.py @@ -0,0 +1,64 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import uuid +from typing import AsyncGenerator + +import asyncpg +import pytest +from google.cloud.sql.connector import create_async_connector + +table_name = f"books_{uuid.uuid4().hex}" + + +@pytest.fixture(name="conn") +async def setup() -> AsyncGenerator: + # initialize Cloud SQL Python Connector object + connector = await create_async_connector() + conn: asyncpg.Connection = await connector.connect_async( + os.environ["POSTGRES_IAM_CONNECTION_NAME"], + "asyncpg", + user=os.environ["POSTGRES_IAM_USER"], + db=os.environ["POSTGRES_DB"], + enable_iam_auth=True, + ) + await conn.execute( + f"CREATE TABLE IF NOT EXISTS {table_name}" + " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" + ) + + yield conn + + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + # close asyncpg connection + await conn.close() + # cleanup Connector object + await connector.close_async() + + +@pytest.mark.asyncio +async def test_connection_with_asyncpg_iam_auth(conn: asyncpg.Connection) -> None: + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" + ) + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')" + ) + + rows = await conn.fetch(f"SELECT title FROM {table_name} ORDER BY ID") + titles = [row[0] for row in rows] + + assert titles == ["Book One", "Book Two"] diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index d044a78d..909aa9c4 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import asyncio import os import pymysql import sqlalchemy @@ -21,6 +22,7 @@ from google.cloud.sql.connector import Connector import datetime import concurrent.futures +from threading import Thread def init_connection_engine( @@ -122,3 +124,20 @@ def test_connector_as_context_manager() -> None: with pool.connect() as conn: conn.execute("SELECT 1") + + +def test_connector_with_custom_loop() -> None: + """Test that Connector can be used with custom loop in background thread.""" + # create new event loop and start it in thread + loop = asyncio.new_event_loop() + thread = Thread(target=loop.run_forever, daemon=True) + thread.start() + + with Connector(loop=loop) as connector: + pool = init_connection_engine(connector) + + with pool.connect() as conn: + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + # assert that Connector does not start its own thread + assert connector._thread is None diff --git a/tests/unit/test_asyncpg.py b/tests/unit/test_asyncpg.py new file mode 100644 index 00000000..e08dc8ed --- /dev/null +++ b/tests/unit/test_asyncpg.py @@ -0,0 +1,34 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import ssl +import pytest +from typing import Any +from mock import patch, AsyncMock + +from google.cloud.sql.connector.asyncpg import connect + + +@pytest.mark.asyncio +@patch("asyncpg.connect", new_callable=AsyncMock) +async def test_asyncpg(mock_connect: AsyncMock, kwargs: Any) -> None: + """Test to verify that asyncpg gets to proper connection call.""" + ip_addr = "0.0.0.0" + context = ssl.create_default_context() + mock_connect.return_value = True + connection = await connect(ip_addr, context, **kwargs) + assert connection is True + # verify that driver connection call would be made + assert mock_connect.assert_called_once diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index ea91a465..c48787f5 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -17,7 +17,8 @@ import pytest # noqa F401 Needed to run the tests import asyncio -from google.cloud.sql.connector import Connector, IPTypes +from google.cloud.sql.connector import Connector, IPTypes, create_async_connector +from google.cloud.sql.connector.connector import ConnectorLoopError from mock import patch from typing import Any @@ -77,6 +78,33 @@ def test_connect_enable_iam_auth_error() -> None: connector._instances = {} +def test_connect_with_unsupported_driver(connector: Connector) -> None: + # try to connect using unsupported driver, should raise KeyError + with pytest.raises(KeyError) as exc_info: + connector.connect( + "my-project:my-region:my-instance", + "bad_driver", + ) + # assert custom error message for unsupported driver is present + assert exc_info.value.args[0] == "Driver 'bad_driver' is not supported." + connector.close() + + +@pytest.mark.asyncio +async def test_connect_ConnectorLoopError() -> None: + """Test that ConnectorLoopError is thrown when Connector.connect + is called with event loop running in current thread.""" + current_loop = asyncio.get_running_loop() + connector = Connector(loop=current_loop) + # try to connect using current thread's loop, should raise error + pytest.raises( + ConnectorLoopError, + connector.connect, + "my-project:my-region:my-instance", + "pg8000", + ) + + def test_Connector_Init() -> None: """Test that Connector __init__ sets default properties properly.""" connector = Connector() @@ -87,6 +115,28 @@ def test_Connector_Init() -> None: connector.close() +def test_Connector_Init_context_manager() -> None: + """Test that Connector as context manager sets default properties properly.""" + with Connector() as connector: + assert connector._ip_type == IPTypes.PUBLIC + assert connector._enable_iam_auth is False + assert connector._timeout == 30 + assert connector._credentials is None + + +@pytest.mark.asyncio +async def test_Connector_Init_async_context_manager() -> None: + """Test that Connector as async context manager sets default properties + properly.""" + loop = asyncio.get_running_loop() + async with Connector(loop=loop) as connector: + assert connector._ip_type == IPTypes.PUBLIC + assert connector._enable_iam_auth is False + assert connector._timeout == 30 + assert connector._credentials is None + assert connector._loop == loop + + def test_Connector_connect(connector: Connector) -> None: """Test that Connector.connect can properly return a DB API connection.""" connect_string = "my-project:my-region:my-instance" @@ -98,3 +148,12 @@ def test_Connector_connect(connector: Connector) -> None: ) # verify connector made connection call assert connection is True + + +@pytest.mark.asyncio +async def test_create_async_connector() -> None: + """Test that create_async_connector properly initializes connector + object using current thread's event loop""" + connector = await create_async_connector() + assert connector._loop == asyncio.get_running_loop() + await connector.close_async() diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 30b1d6fa..63cb6b09 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -50,7 +50,9 @@ async def test_Instance_init( """ connect_string = "test-project:test-region:test-instance" - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) with patch("google.auth.default") as mock_auth: mock_auth.return_value = fake_credentials, None instance = Instance(connect_string, "pymysql", keys, event_loop) @@ -75,7 +77,9 @@ async def test_Instance_init_bad_credentials( throws proper error for bad credentials arg type. """ connect_string = "test-project:test-region:test-instance" - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) with pytest.raises(CredentialsTypeError): instance = Instance(connect_string, "pymysql", keys, event_loop, credentials=1) await instance.close() @@ -356,7 +360,9 @@ async def test_ClientResponseError( Test that detailed error message is applied to ClientResponseError. """ # mock Cloud SQL Admin API calls with exceptions - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" with aioresponses() as mocked: