@@ -504,6 +504,94 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504
504
return addrs , params , config
505
505
506
506
507
+ class TLSUpgradeProto (asyncio .Protocol ):
508
+ def __init__ (self , loop , host , port , ssl_context , ssl_is_advisory ):
509
+ self .on_data = _create_future (loop )
510
+ self .host = host
511
+ self .port = port
512
+ self .ssl_context = ssl_context
513
+ self .ssl_is_advisory = ssl_is_advisory
514
+
515
+ def data_received (self , data ):
516
+ if data == b'S' :
517
+ self .on_data .set_result (True )
518
+ elif (self .ssl_is_advisory and
519
+ self .ssl_context .verify_mode == ssl_module .CERT_NONE and
520
+ data == b'N' ):
521
+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522
+ # since the only way to get ssl_is_advisory is from
523
+ # sslmode=prefer (or sslmode=allow). But be extra sure to
524
+ # disallow insecure connections when the ssl context asks for
525
+ # real security.
526
+ self .on_data .set_result (False )
527
+ else :
528
+ self .on_data .set_exception (
529
+ ConnectionError (
530
+ f'PostgreSQL server at "{ self .host } :{ self .port } " '
531
+ f'rejected SSL upgrade' ))
532
+
533
+ def connection_lost (self , exc ):
534
+ if not self .on_data .done ():
535
+ if exc is None :
536
+ exc = ConnectionError ('unexpected connection_lost() call' )
537
+ self .on_data .set_exception (exc )
538
+
539
+
540
+ async def _create_ssl_connection (protocol_factory , host , port , * ,
541
+ loop , ssl_context , ssl_is_advisory = False ):
542
+
543
+ if ssl_context is True :
544
+ ssl_context = ssl_module .create_default_context ()
545
+
546
+ tr , pr = await loop .create_connection (
547
+ lambda : TLSUpgradeProto (loop , host , port ,
548
+ ssl_context , ssl_is_advisory ),
549
+ host , port )
550
+
551
+ tr .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
552
+
553
+ try :
554
+ do_ssl_upgrade = await pr .on_data
555
+ except (Exception , asyncio .CancelledError ):
556
+ tr .close ()
557
+ raise
558
+
559
+ if hasattr (loop , 'start_tls' ):
560
+ if do_ssl_upgrade :
561
+ try :
562
+ new_tr = await loop .start_tls (
563
+ tr , pr , ssl_context , server_hostname = host )
564
+ except (Exception , asyncio .CancelledError ):
565
+ tr .close ()
566
+ raise
567
+ else :
568
+ new_tr = tr
569
+
570
+ pg_proto = protocol_factory ()
571
+ pg_proto .connection_made (new_tr )
572
+ new_tr .set_protocol (pg_proto )
573
+
574
+ return new_tr , pg_proto
575
+ else :
576
+ conn_factory = functools .partial (
577
+ loop .create_connection , protocol_factory )
578
+
579
+ if do_ssl_upgrade :
580
+ conn_factory = functools .partial (
581
+ conn_factory , ssl = ssl_context , server_hostname = host )
582
+
583
+ sock = _get_socket (tr )
584
+ sock = sock .dup ()
585
+ _set_nodelay (sock )
586
+ tr .close ()
587
+
588
+ try :
589
+ return await conn_factory (sock = sock )
590
+ except (Exception , asyncio .CancelledError ):
591
+ sock .close ()
592
+ raise
593
+
594
+
507
595
async def _connect_addr (* , addr , loop , timeout , params , config ,
508
596
connection_class ):
509
597
assert loop is not None
@@ -526,8 +614,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526
614
else :
527
615
connector = loop .create_connection (proto_factory , * addr )
528
616
529
- connector = asyncio .ensure_future (connector )
530
-
531
617
before = time .monotonic ()
532
618
try :
533
619
tr , pr = await asyncio .wait_for (
@@ -575,79 +661,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575
661
raise last_error
576
662
577
663
578
- async def _negotiate_ssl_connection (host , port , conn_factory , * , loop , ssl ,
579
- server_hostname , ssl_is_advisory = False ):
580
- # Note: ssl_is_advisory only affects behavior when the server does not
581
- # accept SSLRequests. If the SSLRequest is accepted but either the SSL
582
- # negotiation fails or the PostgreSQL user isn't permitted to use SSL,
583
- # there's nothing that would attempt to reconnect with a non-SSL socket.
584
- reader , writer = await asyncio .open_connection (host , port )
585
-
586
- tr = writer .transport
587
- try :
588
- sock = _get_socket (tr )
589
- _set_nodelay (sock )
590
-
591
- writer .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
592
- await writer .drain ()
593
- resp = await reader .readexactly (1 )
594
-
595
- if resp == b'S' :
596
- conn_factory = functools .partial (
597
- conn_factory , ssl = ssl , server_hostname = server_hostname )
598
- elif (ssl_is_advisory and
599
- ssl .verify_mode == ssl_module .CERT_NONE and
600
- resp == b'N' ):
601
- # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
602
- # since the only way to get ssl_is_advisory is from sslmode=prefer
603
- # (or sslmode=allow). But be extra sure to disallow insecure
604
- # connections when the ssl context asks for real security.
605
- pass
606
- else :
607
- raise ConnectionError (
608
- 'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
609
- host , port ))
610
-
611
- sock = sock .dup () # Must come before tr.close()
612
- finally :
613
- writer .close ()
614
- await compat .wait_closed (writer )
615
-
616
- try :
617
- return await conn_factory (sock = sock ) # Must come after tr.close()
618
- except (Exception , asyncio .CancelledError ):
619
- sock .close ()
620
- raise
664
+ async def _cancel (* , loop , addr , params : _ConnectionParameters ,
665
+ backend_pid , backend_secret ):
621
666
667
+ class CancelProto (asyncio .Protocol ):
622
668
623
- async def _create_ssl_connection (protocol_factory , host , port , * ,
624
- loop , ssl_context , ssl_is_advisory = False ):
625
- return await _negotiate_ssl_connection (
626
- host , port ,
627
- functools .partial (loop .create_connection , protocol_factory ),
628
- loop = loop ,
629
- ssl = ssl_context ,
630
- server_hostname = host ,
631
- ssl_is_advisory = ssl_is_advisory )
669
+ def __init__ (self ):
670
+ self .on_disconnect = _create_future (loop )
632
671
672
+ def connection_lost (self , exc ):
673
+ if not self .on_disconnect .done ():
674
+ self .on_disconnect .set_result (True )
633
675
634
- async def _open_connection (* , loop , addr , params : _ConnectionParameters ):
635
676
if isinstance (addr , str ):
636
- r , w = await asyncio . open_unix_connection ( addr )
677
+ tr , pr = await loop . create_unix_connection ( CancelProto , addr )
637
678
else :
638
679
if params .ssl :
639
- r , w = await _negotiate_ssl_connection (
680
+ tr , pr = await _create_ssl_connection (
681
+ CancelProto ,
640
682
* addr ,
641
- asyncio .open_connection ,
642
683
loop = loop ,
643
- ssl = params .ssl ,
644
- server_hostname = addr [0 ],
684
+ ssl_context = params .ssl ,
645
685
ssl_is_advisory = params .ssl_is_advisory )
646
686
else :
647
- r , w = await asyncio .open_connection (* addr )
648
- _set_nodelay (_get_socket (w .transport ))
687
+ tr , pr = await loop .create_connection (
688
+ CancelProto , * addr )
689
+ _set_nodelay (_get_socket (tr ))
690
+
691
+ # Pack a CancelRequest message
692
+ msg = struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
649
693
650
- return r , w
694
+ try :
695
+ tr .write (msg )
696
+ await pr .on_disconnect
697
+ finally :
698
+ tr .close ()
651
699
652
700
653
701
def _get_socket (transport ):
0 commit comments