diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 9823d048..4d8452e9 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -13,6 +13,7 @@ import platform import re import socket +import ssl import stat import struct import time @@ -458,14 +459,60 @@ async def _get_ssl_ready_socket(host, port, *, loop): async def _create_ssl_connection(protocol_factory, host, port, *, loop, ssl_context): - sock = await _get_ssl_ready_socket(host, port, loop=loop) - try: - return await loop.create_connection( - protocol_factory, sock=sock, ssl=ssl_context, - server_hostname=host) - except Exception: - sock.close() - raise + + class TLSUpgradeProto(asyncio.Protocol): + def __init__(self): + self.on_data = loop.create_future() + + def data_received(self, data): + if data == b'S': + self.on_data.set_result(True) + else: + self.on_data.set_exception( + ConnectionError( + 'PostgreSQL server at "{}:{}" ' + 'rejected SSL upgrade'.format(host, port))) + + def connection_lost(self, exc): + if not self.on_data.done(): + if exc is None: + exc = ConnectionError('unexpected connection_lost() call') + self.on_data.set_exception(exc) + + if hasattr(loop, 'start_tls'): + tr, pr = await loop.create_connection(TLSUpgradeProto, host, port) + tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. + + try: + await pr.on_data + except Exception: + tr.close() + raise + + if ssl_context is True: + ssl_context = ssl.create_default_context() + + try: + new_tr = await loop.start_tls( + tr, pr, ssl_context, server_hostname=host) + except Exception: + tr.close() + raise + + pg_proto = protocol_factory() + pg_proto.connection_made(new_tr) + new_tr.set_protocol(pg_proto) + + return new_tr, pg_proto + else: + sock = await _get_ssl_ready_socket(host, port, loop=loop) + try: + return await loop.create_connection( + protocol_factory, sock=sock, ssl=ssl_context, + server_hostname=host) + except Exception: + sock.close() + raise async def _open_connection(*, loop, addr, params: _ConnectionParameters):