Skip to content

Commit d69b5f6

Browse files
committed
Async pyopenssl support
1 parent 7f71430 commit d69b5f6

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

pymongo/network_layer.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async def _async_receive_ssl(
127127
) -> memoryview:
128128
mv = memoryview(bytearray(length))
129129
fd = conn.fileno()
130-
read = 0
130+
total_read = 0
131131

132132
def _is_ready(fut: Future) -> None:
133133
loop.remove_writer(fd)
@@ -136,11 +136,12 @@ def _is_ready(fut: Future) -> None:
136136
return
137137
fut.set_result(None)
138138

139-
while read < length:
139+
while total_read < length:
140140
try:
141-
read += conn.recv_into(mv[read:])
141+
read = conn.recv_into(mv[total_read:])
142142
if read == 0:
143143
raise OSError("connection closed")
144+
total_read += read
144145
except BLOCKING_IO_ERRORS as exc:
145146
fd = conn.fileno()
146147
# Check for closed socket.
@@ -228,15 +229,19 @@ async def async_receive_data(
228229
else:
229230
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
230231
tasks = [read_task, cancellation_task]
231-
result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
232-
if len(result[1]) == 2:
232+
done, pending = await asyncio.wait(
233+
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
234+
)
235+
for task in pending:
236+
task.cancel()
237+
if len(done) == 0:
233238
raise socket.timeout("timed out")
234-
finished = next(iter(result[0]))
235-
next(iter(result[1])).cancel()
236-
if finished == read_task:
237-
return finished.result() # type: ignore[return-value]
238-
else:
239-
raise _OperationCancelled("operation cancelled")
239+
for task in done:
240+
if task == read_task:
241+
return read_task.result()
242+
else:
243+
raise _OperationCancelled("operation cancelled")
244+
return None # type: ignore[return-value]
240245
finally:
241246
sock.settimeout(sock_timeout)
242247

pymongo/pyopenssl_context.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,16 @@ def _ragged_eof(exc: BaseException) -> bool:
105105
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
106106
class _sslConn(_SSL.Connection):
107107
def __init__(
108-
self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
108+
self,
109+
ctx: _SSL.Context,
110+
sock: Optional[_socket.socket],
111+
suppress_ragged_eofs: bool,
112+
is_async: bool = False,
109113
):
110114
self.socket_checker = _SocketChecker()
111115
self.suppress_ragged_eofs = suppress_ragged_eofs
112116
super().__init__(ctx, sock)
117+
self._is_async = is_async
113118

114119
def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
115120
timeout = self.gettimeout()
@@ -119,6 +124,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
119124
try:
120125
return call(*args, **kwargs)
121126
except BLOCKING_IO_ERRORS as exc:
127+
if self._is_async:
128+
raise exc
122129
# Check for closed socket.
123130
if self.fileno() == -1:
124131
if timeout and _time.monotonic() - start > timeout:
@@ -381,7 +388,7 @@ async def a_wrap_socket(
381388
"""Wrap an existing Python socket connection and return a TLS socket
382389
object.
383390
"""
384-
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
391+
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True)
385392
loop = asyncio.get_running_loop()
386393
if session:
387394
ssl_conn.set_session(session)

0 commit comments

Comments
 (0)