Skip to content

Commit 3e5387e

Browse files
authored
PYTHON-4539 Add SSLContext async wrap_socket support (#1740)
1 parent 1053931 commit 3e5387e

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

pymongo/asynchronous/pool.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
20+
import functools
1921
import logging
2022
import os
2123
import socket
@@ -876,12 +878,23 @@ async def _configured_socket(
876878
if _IS_SYNC:
877879
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
878880
else:
879-
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
881+
if hasattr(ssl_context, "a_wrap_socket"):
882+
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
883+
else:
884+
loop = asyncio.get_running_loop()
885+
ssl_sock = await loop.run_in_executor(
886+
None,
887+
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
888+
)
880889
else:
881890
if _IS_SYNC:
882891
ssl_sock = ssl_context.wrap_socket(sock)
883892
else:
884-
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
893+
if hasattr(ssl_context, "a_wrap_socket"):
894+
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
895+
else:
896+
loop = asyncio.get_running_loop()
897+
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
885898
except _CertificateError:
886899
sock.close()
887900
# Raise _CertificateError directly like we do after match_hostname

pymongo/synchronous/pool.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
20+
import functools
1921
import logging
2022
import os
2123
import socket
@@ -872,12 +874,23 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
872874
if _IS_SYNC:
873875
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
874876
else:
875-
ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
877+
if hasattr(ssl_context, "a_wrap_socket"):
878+
ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
879+
else:
880+
loop = asyncio.get_running_loop()
881+
ssl_sock = loop.run_in_executor(
882+
None,
883+
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
884+
)
876885
else:
877886
if _IS_SYNC:
878887
ssl_sock = ssl_context.wrap_socket(sock)
879888
else:
880-
ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
889+
if hasattr(ssl_context, "a_wrap_socket"):
890+
ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
891+
else:
892+
loop = asyncio.get_running_loop()
893+
ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
881894
except _CertificateError:
882895
sock.close()
883896
# Raise _CertificateError directly like we do after match_hostname

0 commit comments

Comments
 (0)