Skip to content

Commit 2841d0d

Browse files
committed
Use loop.start_tls() to upgrade connections to SSL
The old way of TLS upgrade (openining a connection, asking postgres to do TLS and then duping the underlying socket) seems not to work anymore on Windows with Python 3.8.
1 parent a9b3713 commit 2841d0d

File tree

2 files changed

+116
-80
lines changed

2 files changed

+116
-80
lines changed

asyncpg/connect_utils.py

Lines changed: 111 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,94 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504504
return addrs, params, config
505505

506506

507+
class TLSUpgradeProto(asyncio.Protocol):
508+
def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
509+
self.on_data = _create_future(loop)
510+
self.host = host
511+
self.port = port
512+
self.ssl_context = ssl_context
513+
self.ssl_is_advisory = ssl_is_advisory
514+
515+
def data_received(self, data):
516+
if data == b'S':
517+
self.on_data.set_result(True)
518+
elif (self.ssl_is_advisory and
519+
self.ssl_context.verify_mode == ssl_module.CERT_NONE and
520+
data == b'N'):
521+
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522+
# since the only way to get ssl_is_advisory is from
523+
# sslmode=prefer (or sslmode=allow). But be extra sure to
524+
# disallow insecure connections when the ssl context asks for
525+
# real security.
526+
self.on_data.set_result(False)
527+
else:
528+
self.on_data.set_exception(
529+
ConnectionError(
530+
f'PostgreSQL server at "{self.host}:{self.port}" '
531+
f'rejected SSL upgrade'))
532+
533+
def connection_lost(self, exc):
534+
if not self.on_data.done():
535+
if exc is None:
536+
exc = ConnectionError('unexpected connection_lost() call')
537+
self.on_data.set_exception(exc)
538+
539+
540+
async def _create_ssl_connection(protocol_factory, host, port, *,
541+
loop, ssl_context, ssl_is_advisory=False):
542+
543+
if ssl_context is True:
544+
ssl_context = ssl_module.create_default_context()
545+
546+
tr, pr = await loop.create_connection(
547+
lambda: TLSUpgradeProto(loop, host, port,
548+
ssl_context, ssl_is_advisory),
549+
host, port)
550+
551+
tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
552+
553+
try:
554+
do_ssl_upgrade = await pr.on_data
555+
except (Exception, asyncio.CancelledError):
556+
tr.close()
557+
raise
558+
559+
if hasattr(loop, 'start_tls'):
560+
if do_ssl_upgrade:
561+
try:
562+
new_tr = await loop.start_tls(
563+
tr, pr, ssl_context, server_hostname=host)
564+
except (Exception, asyncio.CancelledError):
565+
tr.close()
566+
raise
567+
else:
568+
new_tr = tr
569+
570+
pg_proto = protocol_factory()
571+
pg_proto.connection_made(new_tr)
572+
new_tr.set_protocol(pg_proto)
573+
574+
return new_tr, pg_proto
575+
else:
576+
conn_factory = functools.partial(
577+
loop.create_connection, protocol_factory)
578+
579+
if do_ssl_upgrade:
580+
conn_factory = functools.partial(
581+
conn_factory, ssl=ssl_context, server_hostname=host)
582+
583+
sock = _get_socket(tr)
584+
sock = sock.dup()
585+
_set_nodelay(sock)
586+
tr.close()
587+
588+
try:
589+
return await conn_factory(sock=sock)
590+
except (Exception, asyncio.CancelledError):
591+
sock.close()
592+
raise
593+
594+
507595
async def _connect_addr(*, addr, loop, timeout, params, config,
508596
connection_class):
509597
assert loop is not None
@@ -526,8 +614,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526614
else:
527615
connector = loop.create_connection(proto_factory, *addr)
528616

529-
connector = asyncio.ensure_future(connector)
530-
531617
before = time.monotonic()
532618
try:
533619
tr, pr = await asyncio.wait_for(
@@ -575,79 +661,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575661
raise last_error
576662

577663

578-
async def _negotiate_ssl_connection(host, port, conn_factory, *, loop, ssl,
579-
server_hostname, ssl_is_advisory=False):
580-
# Note: ssl_is_advisory only affects behavior when the server does not
581-
# accept SSLRequests. If the SSLRequest is accepted but either the SSL
582-
# negotiation fails or the PostgreSQL user isn't permitted to use SSL,
583-
# there's nothing that would attempt to reconnect with a non-SSL socket.
584-
reader, writer = await asyncio.open_connection(host, port)
585-
586-
tr = writer.transport
587-
try:
588-
sock = _get_socket(tr)
589-
_set_nodelay(sock)
590-
591-
writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
592-
await writer.drain()
593-
resp = await reader.readexactly(1)
594-
595-
if resp == b'S':
596-
conn_factory = functools.partial(
597-
conn_factory, ssl=ssl, server_hostname=server_hostname)
598-
elif (ssl_is_advisory and
599-
ssl.verify_mode == ssl_module.CERT_NONE and
600-
resp == b'N'):
601-
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
602-
# since the only way to get ssl_is_advisory is from sslmode=prefer
603-
# (or sslmode=allow). But be extra sure to disallow insecure
604-
# connections when the ssl context asks for real security.
605-
pass
606-
else:
607-
raise ConnectionError(
608-
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
609-
host, port))
610-
611-
sock = sock.dup() # Must come before tr.close()
612-
finally:
613-
writer.close()
614-
await compat.wait_closed(writer)
615-
616-
try:
617-
return await conn_factory(sock=sock) # Must come after tr.close()
618-
except (Exception, asyncio.CancelledError):
619-
sock.close()
620-
raise
664+
async def _cancel(*, loop, addr, params: _ConnectionParameters,
665+
backend_pid, backend_secret):
621666

667+
class CancelProto(asyncio.Protocol):
622668

623-
async def _create_ssl_connection(protocol_factory, host, port, *,
624-
loop, ssl_context, ssl_is_advisory=False):
625-
return await _negotiate_ssl_connection(
626-
host, port,
627-
functools.partial(loop.create_connection, protocol_factory),
628-
loop=loop,
629-
ssl=ssl_context,
630-
server_hostname=host,
631-
ssl_is_advisory=ssl_is_advisory)
669+
def __init__(self):
670+
self.on_disconnect = _create_future(loop)
632671

672+
def connection_lost(self, exc):
673+
if not self.on_disconnect.done():
674+
self.on_disconnect.set_result(True)
633675

634-
async def _open_connection(*, loop, addr, params: _ConnectionParameters):
635676
if isinstance(addr, str):
636-
r, w = await asyncio.open_unix_connection(addr)
677+
tr, pr = await loop.create_unix_connection(CancelProto, addr)
637678
else:
638679
if params.ssl:
639-
r, w = await _negotiate_ssl_connection(
680+
tr, pr = await _create_ssl_connection(
681+
CancelProto,
640682
*addr,
641-
asyncio.open_connection,
642683
loop=loop,
643-
ssl=params.ssl,
644-
server_hostname=addr[0],
684+
ssl_context=params.ssl,
645685
ssl_is_advisory=params.ssl_is_advisory)
646686
else:
647-
r, w = await asyncio.open_connection(*addr)
648-
_set_nodelay(_get_socket(w.transport))
687+
tr, pr = await loop.create_connection(
688+
CancelProto, *addr)
689+
_set_nodelay(_get_socket(tr))
690+
691+
# Pack a CancelRequest message
692+
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
649693

650-
return r, w
694+
try:
695+
tr.write(msg)
696+
await pr.on_disconnect
697+
finally:
698+
tr.close()
651699

652700

653701
def _get_socket(transport):

asyncpg/connection.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import collections
1111
import collections.abc
1212
import itertools
13-
import struct
1413
import sys
1514
import time
1615
import traceback
@@ -1186,24 +1185,16 @@ async def _cleanup_stmts(self):
11861185
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
11871186

11881187
async def _cancel(self, waiter):
1189-
r = w = None
1190-
11911188
try:
11921189
# Open new connection to the server
1193-
r, w = await connect_utils._open_connection(
1194-
loop=self._loop, addr=self._addr, params=self._params)
1195-
1196-
# Pack CancelRequest message
1197-
msg = struct.pack('!llll', 16, 80877102,
1198-
self._protocol.backend_pid,
1199-
self._protocol.backend_secret)
1200-
1201-
w.write(msg)
1202-
await r.read() # Wait until EOF
1190+
await connect_utils._cancel(
1191+
loop=self._loop, addr=self._addr, params=self._params,
1192+
backend_pid=self._protocol.backend_pid,
1193+
backend_secret=self._protocol.backend_secret)
12031194
except ConnectionResetError as ex:
12041195
# On some systems Postgres will reset the connection
12051196
# after processing the cancellation command.
1206-
if r is None and not waiter.done():
1197+
if not waiter.done():
12071198
waiter.set_exception(ex)
12081199
except asyncio.CancelledError:
12091200
# There are two scenarios in which the cancellation
@@ -1221,9 +1212,6 @@ async def _cancel(self, waiter):
12211212
compat.current_asyncio_task(self._loop))
12221213
if not waiter.done():
12231214
waiter.set_result(None)
1224-
if w is not None:
1225-
w.close()
1226-
await compat.wait_closed(w)
12271215

12281216
def _cancel_current_command(self, waiter):
12291217
self._cancellations.add(self._loop.create_task(self._cancel(waiter)))

0 commit comments

Comments
 (0)