@@ -752,7 +752,6 @@ class Connection(object):
752
752
_socket = None
753
753
754
754
_socket_impl = socket
755
- _ssl_impl = ssl
756
755
757
756
_check_hostname = False
758
757
_product_type = None
@@ -780,7 +779,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
780
779
self .endpoint = host if isinstance (host , EndPoint ) else DefaultEndPoint (host , port )
781
780
782
781
self .authenticator = authenticator
783
- self .ssl_options = ssl_options .copy () if ssl_options else None
782
+ self .ssl_options = ssl_options .copy () if ssl_options else {}
784
783
self .ssl_context = ssl_context
785
784
self .sockopts = sockopts
786
785
self .compression = compression
@@ -800,15 +799,20 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
800
799
self ._on_orphaned_stream_released = on_orphaned_stream_released
801
800
802
801
if ssl_options :
803
- self ._check_hostname = bool (self .ssl_options .pop ('check_hostname' , False ))
804
- if self ._check_hostname :
805
- if not getattr (ssl , 'match_hostname' , None ):
806
- raise RuntimeError ("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
807
- "Patch or upgrade Python to use this option." )
808
802
self .ssl_options .update (self .endpoint .ssl_options or {})
809
803
elif self .endpoint .ssl_options :
810
804
self .ssl_options = self .endpoint .ssl_options
811
805
806
+ # PYTHON-1331
807
+ #
808
+ # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()...
809
+ # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if
810
+ # we need to do so.
811
+ #
812
+ # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this
813
+ # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call.
814
+ if not self .ssl_context and self .ssl_options :
815
+ self .ssl_context = self ._build_ssl_context_from_options ()
812
816
813
817
if protocol_version >= 3 :
814
818
self .max_request_id = min (self .max_in_flight - 1 , (2 ** 15 ) - 1 )
@@ -882,15 +886,48 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs):
882
886
else :
883
887
return conn
884
888
889
+ def _build_ssl_context_from_options (self ):
890
+
891
+ # Extract a subset of names from self.ssl_options which apply to SSLContext creation
892
+ ssl_context_opt_names = ['ssl_version' , 'cert_reqs' , 'check_hostname' , 'keyfile' , 'certfile' , 'ca_certs' , 'ciphers' ]
893
+ opts = {k :self .ssl_options .get (k , None ) for k in ssl_context_opt_names if k in self .ssl_options }
894
+
895
+ # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always
896
+ # being explicit
897
+ ssl_version = opts .get ('ssl_version' , None ) or ssl .PROTOCOL_TLS_CLIENT
898
+ cert_reqs = opts .get ('cert_reqs' , None ) or ssl .CERT_REQUIRED
899
+ rv = ssl .SSLContext (protocol = int (ssl_version ))
900
+ rv .check_hostname = bool (opts .get ('check_hostname' , False ))
901
+ rv .options = int (cert_reqs )
902
+
903
+ certfile = opts .get ('certfile' , None )
904
+ keyfile = opts .get ('keyfile' , None )
905
+ if certfile :
906
+ rv .load_cert_chain (certfile , keyfile )
907
+ ca_certs = opts .get ('ca_certs' , None )
908
+ if ca_certs :
909
+ rv .load_verify_locations (ca_certs )
910
+ ciphers = opts .get ('ciphers' , None )
911
+ if ciphers :
912
+ rv .set_ciphers (ciphers )
913
+
914
+ return rv
915
+
885
916
def _wrap_socket_from_context (self ):
886
- ssl_options = self .ssl_options or {}
917
+
918
+ # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts
919
+ # of it that don't involve building an SSLContext under the covers)
920
+ wrap_socket_opt_names = ['server_side' , 'do_handshake_on_connect' , 'suppress_ragged_eofs' , 'server_hostname' ]
921
+ opts = {k :self .ssl_options .get (k , None ) for k in wrap_socket_opt_names if k in self .ssl_options }
922
+
887
923
# PYTHON-1186: set the server_hostname only if the SSLContext has
888
924
# check_hostname enabled and it is not already provided by the EndPoint ssl options
889
- if (self .ssl_context .check_hostname and
890
- 'server_hostname' not in ssl_options ):
891
- ssl_options = ssl_options .copy ()
892
- ssl_options ['server_hostname' ] = self .endpoint .address
893
- self ._socket = self .ssl_context .wrap_socket (self ._socket , ** ssl_options )
925
+ #opts['server_hostname'] = self.endpoint.address
926
+ if (self .ssl_context .check_hostname and 'server_hostname' not in opts ):
927
+ server_hostname = self .endpoint .address
928
+ opts ['server_hostname' ] = server_hostname
929
+
930
+ return self .ssl_context .wrap_socket (self ._socket , ** opts )
894
931
895
932
def _initiate_connection (self , sockaddr ):
896
933
if self .features .shard_id is not None :
@@ -904,8 +941,11 @@ def _initiate_connection(self, sockaddr):
904
941
905
942
self ._socket .connect (sockaddr )
906
943
907
- def _match_hostname (self ):
908
- ssl .match_hostname (self ._socket .getpeercert (), self .endpoint .address )
944
+ # PYTHON-1331
945
+ #
946
+ # Allow implementations specific to an event loop to add additional behaviours
947
+ def _validate_hostname (self ):
948
+ pass
909
949
910
950
def _get_socket_addresses (self ):
911
951
address , port = self .endpoint .resolve ()
@@ -927,18 +967,21 @@ def _connect_socket(self):
927
967
try :
928
968
self ._socket = self ._socket_impl .socket (af , socktype , proto )
929
969
if self .ssl_context :
930
- self ._wrap_socket_from_context ()
931
- elif self .ssl_options :
932
- if not self ._ssl_impl :
933
- raise RuntimeError ("This version of Python was not compiled with SSL support" )
934
- self ._socket = self ._ssl_impl .wrap_socket (self ._socket , ** self .ssl_options )
970
+ self ._socket = self ._wrap_socket_from_context ()
935
971
self ._socket .settimeout (self .connect_timeout )
936
972
self ._initiate_connection (sockaddr )
937
973
self ._socket .settimeout (None )
974
+
938
975
local_addr = self ._socket .getsockname ()
939
976
log .debug ("Connection %s: '%s' -> '%s'" , id (self ), local_addr , sockaddr )
977
+
978
+ # PYTHON-1331
979
+ #
980
+ # Most checking is done via the check_hostname param on the SSLContext.
981
+ # Subclasses can add additional behaviours via _validate_hostname() so
982
+ # run that here.
940
983
if self ._check_hostname :
941
- self ._match_hostname ()
984
+ self ._validate_hostname ()
942
985
sockerr = None
943
986
break
944
987
except socket .error as err :
0 commit comments