@@ -733,7 +733,6 @@ class Connection(object):
733
733
_socket = None
734
734
735
735
_socket_impl = socket
736
- _ssl_impl = ssl
737
736
738
737
_check_hostname = False
739
738
_product_type = None
@@ -757,7 +756,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
757
756
self .endpoint = host if isinstance (host , EndPoint ) else DefaultEndPoint (host , port )
758
757
759
758
self .authenticator = authenticator
760
- self .ssl_options = ssl_options .copy () if ssl_options else None
759
+ self .ssl_options = ssl_options .copy () if ssl_options else {}
761
760
self .ssl_context = ssl_context
762
761
self .sockopts = sockopts
763
762
self .compression = compression
@@ -777,15 +776,20 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
777
776
self ._on_orphaned_stream_released = on_orphaned_stream_released
778
777
779
778
if ssl_options :
780
- self ._check_hostname = bool (self .ssl_options .pop ('check_hostname' , False ))
781
- if self ._check_hostname :
782
- if not getattr (ssl , 'match_hostname' , None ):
783
- raise RuntimeError ("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
784
- "Patch or upgrade Python to use this option." )
785
779
self .ssl_options .update (self .endpoint .ssl_options or {})
786
780
elif self .endpoint .ssl_options :
787
781
self .ssl_options = self .endpoint .ssl_options
788
782
783
+ # PYTHON-1331
784
+ #
785
+ # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()...
786
+ # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if
787
+ # we need to do so.
788
+ #
789
+ # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this
790
+ # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call.
791
+ if not self .ssl_context and self .ssl_options :
792
+ self .ssl_context = self ._build_ssl_context_from_options ()
789
793
790
794
if protocol_version >= 3 :
791
795
self .max_request_id = min (self .max_in_flight - 1 , (2 ** 15 ) - 1 )
@@ -852,21 +856,57 @@ def factory(cls, endpoint, timeout, *args, **kwargs):
852
856
else :
853
857
return conn
854
858
859
+ def _build_ssl_context_from_options (self ):
860
+
861
+ # Extract a subset of names from self.ssl_options which apply to SSLContext creation
862
+ ssl_context_opt_names = ['ssl_version' , 'cert_reqs' , 'check_hostname' , 'keyfile' , 'certfile' , 'ca_certs' , 'ciphers' ]
863
+ opts = {k :self .ssl_options .get (k , None ) for k in ssl_context_opt_names if k in self .ssl_options }
864
+
865
+ # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always
866
+ # being explicit
867
+ ssl_version = opts .get ('ssl_version' , None ) or ssl .PROTOCOL_TLS_CLIENT
868
+ cert_reqs = opts .get ('cert_reqs' , None ) or ssl .CERT_REQUIRED
869
+ rv = ssl .SSLContext (protocol = int (ssl_version ))
870
+ rv .check_hostname = bool (opts .get ('check_hostname' , False ))
871
+ rv .options = int (cert_reqs )
872
+
873
+ certfile = opts .get ('certfile' , None )
874
+ keyfile = opts .get ('keyfile' , None )
875
+ if certfile :
876
+ rv .load_cert_chain (certfile , keyfile )
877
+ ca_certs = opts .get ('ca_certs' , None )
878
+ if ca_certs :
879
+ rv .load_verify_locations (ca_certs )
880
+ ciphers = opts .get ('ciphers' , None )
881
+ if ciphers :
882
+ rv .set_ciphers (ciphers )
883
+
884
+ return rv
885
+
855
886
def _wrap_socket_from_context (self ):
856
- ssl_options = self .ssl_options or {}
887
+
888
+ # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts
889
+ # of it that don't involve building an SSLContext under the covers)
890
+ wrap_socket_opt_names = ['server_side' , 'do_handshake_on_connect' , 'suppress_ragged_eofs' , 'server_hostname' ]
891
+ opts = {k :self .ssl_options .get (k , None ) for k in wrap_socket_opt_names if k in self .ssl_options }
892
+
857
893
# PYTHON-1186: set the server_hostname only if the SSLContext has
858
894
# check_hostname enabled and it is not already provided by the EndPoint ssl options
859
- if (self .ssl_context .check_hostname and
860
- 'server_hostname' not in ssl_options ):
861
- ssl_options = ssl_options .copy ()
862
- ssl_options ['server_hostname' ] = self .endpoint .address
863
- self ._socket = self .ssl_context .wrap_socket (self ._socket , ** ssl_options )
895
+ #opts['server_hostname'] = self.endpoint.address
896
+ if (self .ssl_context .check_hostname and 'server_hostname' not in opts ):
897
+ server_hostname = self .endpoint .address
898
+ opts ['server_hostname' ] = server_hostname
899
+
900
+ return self .ssl_context .wrap_socket (self ._socket , ** opts )
864
901
865
902
def _initiate_connection (self , sockaddr ):
866
903
self ._socket .connect (sockaddr )
867
904
868
- def _match_hostname (self ):
869
- ssl .match_hostname (self ._socket .getpeercert (), self .endpoint .address )
905
+ # PYTHON-1331
906
+ #
907
+ # Allow implementations specific to an event loop to add additional behaviours
908
+ def _validate_hostname (self ):
909
+ pass
870
910
871
911
def _get_socket_addresses (self ):
872
912
address , port = self .endpoint .resolve ()
@@ -887,16 +927,18 @@ def _connect_socket(self):
887
927
try :
888
928
self ._socket = self ._socket_impl .socket (af , socktype , proto )
889
929
if self .ssl_context :
890
- self ._wrap_socket_from_context ()
891
- elif self .ssl_options :
892
- if not self ._ssl_impl :
893
- raise RuntimeError ("This version of Python was not compiled with SSL support" )
894
- self ._socket = self ._ssl_impl .wrap_socket (self ._socket , ** self .ssl_options )
930
+ self ._socket = self ._wrap_socket_from_context ()
895
931
self ._socket .settimeout (self .connect_timeout )
896
932
self ._initiate_connection (sockaddr )
897
933
self ._socket .settimeout (None )
934
+
935
+ # PYTHON-1331
936
+ #
937
+ # Most checking is done via the check_hostname param on the SSLContext.
938
+ # Subclasses can add additional behaviours via _validate_hostname() so
939
+ # run that here.
898
940
if self ._check_hostname :
899
- self ._match_hostname ()
941
+ self ._validate_hostname ()
900
942
sockerr = None
901
943
break
902
944
except socket .error as err :
0 commit comments