@@ -752,7 +752,6 @@ class Connection(object):
752
752
_socket = None
753
753
754
754
_socket_impl = socket
755
- _ssl_impl = ssl
756
755
757
756
_check_hostname = False
758
757
_product_type = None
@@ -780,7 +779,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
780
779
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
781
780
782
781
self.authenticator = authenticator
783
- self .ssl_options = ssl_options .copy () if ssl_options else None
782
+ self.ssl_options = ssl_options.copy() if ssl_options else {}
784
783
self.ssl_context = ssl_context
785
784
self.sockopts = sockopts
786
785
self.compression = compression
@@ -800,15 +799,20 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
800
799
self._on_orphaned_stream_released = on_orphaned_stream_released
801
800
802
801
if ssl_options:
803
- self ._check_hostname = bool (self .ssl_options .pop ('check_hostname' , False ))
804
- if self ._check_hostname :
805
- if not getattr (ssl , 'match_hostname' , None ):
806
- raise RuntimeError ("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
807
- "Patch or upgrade Python to use this option." )
808
802
self.ssl_options.update(self.endpoint.ssl_options or {})
809
803
elif self.endpoint.ssl_options:
810
804
self.ssl_options = self.endpoint.ssl_options
811
805
806
+ # PYTHON-1331
807
+ #
808
+ # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()...
809
+ # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if
810
+ # we need to do so.
811
+ #
812
+ # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this
813
+ # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call.
814
+ if not self.ssl_context and self.ssl_options:
815
+ self.ssl_context = self._build_ssl_context_from_options()
812
816
813
817
if protocol_version >= 3:
814
818
self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1)
@@ -882,15 +886,48 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs):
882
886
else:
883
887
return conn
884
888
889
+ def _build_ssl_context_from_options(self):
890
+
891
+ # Extract a subset of names from self.ssl_options which apply to SSLContext creation
892
+ ssl_context_opt_names = ['ssl_version', 'cert_reqs', 'check_hostname', 'keyfile', 'certfile', 'ca_certs', 'ciphers']
893
+ opts = {k:self.ssl_options.get(k, None) for k in ssl_context_opt_names if k in self.ssl_options}
894
+
895
+ # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always
896
+ # being explicit
897
+ ssl_version = opts.get('ssl_version', None) or ssl.PROTOCOL_TLS_CLIENT
898
+ cert_reqs = opts.get('cert_reqs', None) or ssl.CERT_REQUIRED
899
+ rv = ssl.SSLContext(protocol=int(ssl_version))
900
+ rv.check_hostname = bool(opts.get('check_hostname', False))
901
+ rv.options = int(cert_reqs)
902
+
903
+ certfile = opts.get('certfile', None)
904
+ keyfile = opts.get('keyfile', None)
905
+ if certfile:
906
+ rv.load_cert_chain(certfile, keyfile)
907
+ ca_certs = opts.get('ca_certs', None)
908
+ if ca_certs:
909
+ rv.load_verify_locations(ca_certs)
910
+ ciphers = opts.get('ciphers', None)
911
+ if ciphers:
912
+ rv.set_ciphers(ciphers)
913
+
914
+ return rv
915
+
885
916
def _wrap_socket_from_context(self):
886
- ssl_options = self .ssl_options or {}
917
+
918
+ # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts
919
+ # of it that don't involve building an SSLContext under the covers)
920
+ wrap_socket_opt_names = ['server_side', 'do_handshake_on_connect', 'suppress_ragged_eofs', 'server_hostname']
921
+ opts = {k:self.ssl_options.get(k, None) for k in wrap_socket_opt_names if k in self.ssl_options}
922
+
887
923
# PYTHON-1186: set the server_hostname only if the SSLContext has
888
924
# check_hostname enabled and it is not already provided by the EndPoint ssl options
889
- if (self .ssl_context .check_hostname and
890
- 'server_hostname' not in ssl_options ):
891
- ssl_options = ssl_options .copy ()
892
- ssl_options ['server_hostname' ] = self .endpoint .address
893
- self ._socket = self .ssl_context .wrap_socket (self ._socket , ** ssl_options )
925
+ #opts['server_hostname'] = self.endpoint.address
926
+ if (self.ssl_context.check_hostname and 'server_hostname' not in opts):
927
+ server_hostname = self.endpoint.address
928
+ opts['server_hostname'] = server_hostname
929
+
930
+ return self.ssl_context.wrap_socket(self._socket, **opts)
894
931
895
932
def _initiate_connection(self, sockaddr):
896
933
if self.features.shard_id is not None:
@@ -904,8 +941,11 @@ def _initiate_connection(self, sockaddr):
904
941
905
942
self._socket.connect(sockaddr)
906
943
907
- def _match_hostname (self ):
908
- ssl .match_hostname (self ._socket .getpeercert (), self .endpoint .address )
944
+ # PYTHON-1331
945
+ #
946
+ # Allow implementations specific to an event loop to add additional behaviours
947
+ def _validate_hostname(self):
948
+ pass
909
949
910
950
def _get_socket_addresses(self):
911
951
address, port = self.endpoint.resolve()
@@ -927,18 +967,21 @@ def _connect_socket(self):
927
967
try:
928
968
self._socket = self._socket_impl.socket(af, socktype, proto)
929
969
if self.ssl_context:
930
- self ._wrap_socket_from_context ()
931
- elif self .ssl_options :
932
- if not self ._ssl_impl :
933
- raise RuntimeError ("This version of Python was not compiled with SSL support" )
934
- self ._socket = self ._ssl_impl .wrap_socket (self ._socket , ** self .ssl_options )
970
+ self._socket = self._wrap_socket_from_context()
935
971
self._socket.settimeout(self.connect_timeout)
936
972
self._initiate_connection(sockaddr)
937
973
self._socket.settimeout(None)
974
+
938
975
local_addr = self._socket.getsockname()
939
976
log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
977
+
978
+ # PYTHON-1331
979
+ #
980
+ # Most checking is done via the check_hostname param on the SSLContext.
981
+ # Subclasses can add additional behaviours via _validate_hostname() so
982
+ # run that here.
940
983
if self._check_hostname:
941
- self ._match_hostname ()
984
+ self._validate_hostname ()
942
985
sockerr = None
943
986
break
944
987
except socket.error as err:
0 commit comments