Skip to content

Commit e4d8449

Browse files
authored
PYTHON-5021 - Fix usages of getaddrinfo to be non-blocking (#2059)
1 parent 8fa6750 commit e4d8449

File tree

7 files changed

+72
-16
lines changed

7 files changed

+72
-16
lines changed

pymongo/asynchronous/auth.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_authenticate_oidc,
3939
_get_authenticator,
4040
)
41+
from pymongo.asynchronous.helpers import _getaddrinfo
4142
from pymongo.auth_shared import (
4243
MongoCredential,
4344
_authenticate_scram_start,
@@ -177,15 +178,22 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
177178
return md5hash.hexdigest()
178179

179180

180-
def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
181+
async def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
181182
"""Canonicalize hostname following MIT-krb5 behavior."""
182183
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
183184
if option in [False, "none"]:
184185
return hostname
185186

186-
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
187-
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
188-
)[0]
187+
af, socktype, proto, canonname, sockaddr = (
188+
await _getaddrinfo(
189+
hostname,
190+
None,
191+
family=0,
192+
type=0,
193+
proto=socket.IPPROTO_TCP,
194+
flags=socket.AI_CANONNAME,
195+
)
196+
)[0] # type: ignore[index]
189197

190198
# For forward just to resolve the cname as dns.lookup() will not return it.
191199
if option == "forward":
@@ -213,7 +221,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti
213221
# Starting here and continuing through the while loop below - establish
214222
# the security context. See RFC 4752, Section 3.1, first paragraph.
215223
host = props.service_host or conn.address[0]
216-
host = _canonicalize_hostname(host, props.canonicalize_host_name)
224+
host = await _canonicalize_hostname(host, props.canonicalize_host_name)
217225
service = props.service_name + "@" + host
218226
if props.service_realm is not None:
219227
service = service + "@" + props.service_realm

pymongo/asynchronous/helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Miscellaneous pieces that need to be synchronized."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import builtins
20+
import socket
1921
import sys
2022
from typing import (
2123
Any,
@@ -68,6 +70,24 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
6870
return cast(F, inner)
6971

7072

73+
async def _getaddrinfo(
74+
host: Any, port: Any, **kwargs: Any
75+
) -> list[
76+
tuple[
77+
socket.AddressFamily,
78+
socket.SocketKind,
79+
int,
80+
str,
81+
tuple[str, int] | tuple[str, int, int, int],
82+
]
83+
]:
84+
if not _IS_SYNC:
85+
loop = asyncio.get_running_loop()
86+
return await loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value]
87+
else:
88+
return socket.getaddrinfo(host, port, **kwargs)
89+
90+
7191
if sys.version_info >= (3, 10):
7292
anext = builtins.anext
7393
aiter = builtins.aiter

pymongo/asynchronous/pool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from bson import DEFAULT_CODEC_OPTIONS
4141
from pymongo import _csot, helpers_shared
4242
from pymongo.asynchronous.client_session import _validate_session_write_concern
43-
from pymongo.asynchronous.helpers import _handle_reauth
43+
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
4444
from pymongo.asynchronous.network import command, receive_message
4545
from pymongo.common import (
4646
MAX_BSON_SIZE,
@@ -783,7 +783,7 @@ def __repr__(self) -> str:
783783
)
784784

785785

786-
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
786+
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
787787
"""Given (host, port) and PoolOptions, connect and return a socket object.
788788
789789
Can raise socket.error.
@@ -814,7 +814,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
814814
family = socket.AF_UNSPEC
815815

816816
err = None
817-
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
817+
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
818818
af, socktype, proto, dummy, sa = res
819819
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
820820
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
@@ -863,7 +863,7 @@ async def _configured_socket(
863863
864864
Sets socket's SSL and timeout options.
865865
"""
866-
sock = _create_connection(address, options)
866+
sock = await _create_connection(address, options)
867867
ssl_context = options._ssl_context
868868

869869
if ssl_context is None:

pymongo/synchronous/auth.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_authenticate_oidc,
4646
_get_authenticator,
4747
)
48+
from pymongo.synchronous.helpers import _getaddrinfo
4849

4950
if TYPE_CHECKING:
5051
from pymongo.hello import Hello
@@ -180,9 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
180181
if option in [False, "none"]:
181182
return hostname
182183

183-
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
184-
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
185-
)[0]
184+
af, socktype, proto, canonname, sockaddr = (
185+
_getaddrinfo(
186+
hostname,
187+
None,
188+
family=0,
189+
type=0,
190+
proto=socket.IPPROTO_TCP,
191+
flags=socket.AI_CANONNAME,
192+
)
193+
)[0] # type: ignore[index]
186194

187195
# For forward just to resolve the cname as dns.lookup() will not return it.
188196
if option == "forward":

pymongo/synchronous/helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Miscellaneous pieces that need to be synchronized."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import builtins
20+
import socket
1921
import sys
2022
from typing import (
2123
Any,
@@ -68,6 +70,24 @@ def inner(*args: Any, **kwargs: Any) -> Any:
6870
return cast(F, inner)
6971

7072

73+
def _getaddrinfo(
74+
host: Any, port: Any, **kwargs: Any
75+
) -> list[
76+
tuple[
77+
socket.AddressFamily,
78+
socket.SocketKind,
79+
int,
80+
str,
81+
tuple[str, int] | tuple[str, int, int, int],
82+
]
83+
]:
84+
if not _IS_SYNC:
85+
loop = asyncio.get_running_loop()
86+
return loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value]
87+
else:
88+
return socket.getaddrinfo(host, port, **kwargs)
89+
90+
7191
if sys.version_info >= (3, 10):
7292
next = builtins.next
7393
iter = builtins.iter

pymongo/synchronous/pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
from pymongo.socket_checker import SocketChecker
8585
from pymongo.ssl_support import HAS_SNI, SSLError
8686
from pymongo.synchronous.client_session import _validate_session_write_concern
87-
from pymongo.synchronous.helpers import _handle_reauth
87+
from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth
8888
from pymongo.synchronous.network import command, receive_message
8989

9090
if TYPE_CHECKING:
@@ -812,7 +812,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
812812
family = socket.AF_UNSPEC
813813

814814
err = None
815-
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
815+
for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
816816
af, socktype, proto, dummy, sa = res
817817
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
818818
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4

test/asynchronous/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,10 @@ async def test_gssapi_threaded(self):
275275
async def test_gssapi_canonicalize_host_name(self):
276276
# Test the low level method.
277277
assert GSSAPI_HOST is not None
278-
result = _canonicalize_hostname(GSSAPI_HOST, "forward")
278+
result = await _canonicalize_hostname(GSSAPI_HOST, "forward")
279279
if "compute-1.amazonaws.com" not in result:
280280
self.assertEqual(result, GSSAPI_HOST)
281-
result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
281+
result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
282282
self.assertEqual(result, GSSAPI_HOST)
283283

284284
# Use the equivalent named CANONICALIZE_HOST_NAME.

0 commit comments

Comments
 (0)