diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 584999d8e7..ae2cfd6cfc 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -35,6 +35,7 @@ jobs: inputs: requirements.txt dev_requirements.txt ignore-vulns: | GHSA-w596-4wvx-j9j6 # subversion related git pull, dependency for pytest. There is no impact here. + PYSEC-2024-48 # black vulnerability in 22.3.0, can't upgrade due to python 3.7 support, no impact lint: name: Code linters diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 8260515867..480a877d20 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -221,6 +221,7 @@ def __init__( ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, ssl_min_version: Optional[ssl.TLSVersion] = None, + ssl_ciphers: Optional[str] = None, max_connections: Optional[int] = None, single_connection_client: bool = False, health_check_interval: int = 0, @@ -314,6 +315,7 @@ def __init__( "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, "ssl_min_version": ssl_min_version, + "ssl_ciphers": ssl_ciphers, } ) # This arg only used if no pool is passed in diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index f06a277779..3bf147d7a6 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -267,6 +267,7 @@ def __init__( ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, ssl_min_version: Optional[ssl.TLSVersion] = None, + ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: @@ -326,6 +327,7 @@ def __init__( "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, "ssl_min_version": ssl_min_version, + "ssl_ciphers": ssl_ciphers, } ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 0e100e9d8d..0074bce5db 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -739,6 +739,7 @@ def __init__( ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, ssl_min_version: Optional[ssl.TLSVersion] = None, + ssl_ciphers: Optional[str] = None, **kwargs, ): self.ssl_context: RedisSSLContext = RedisSSLContext( @@ -749,6 +750,7 @@ def __init__( ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, min_version=ssl_min_version, + ciphers=ssl_ciphers, ) super().__init__(**kwargs) @@ -796,6 +798,7 @@ class RedisSSLContext: "context", "check_hostname", "min_version", + "ciphers", ) def __init__( @@ -807,6 +810,7 @@ def __init__( ca_data: Optional[str] = None, check_hostname: bool = False, min_version: Optional[ssl.TLSVersion] = None, + ciphers: Optional[str] = None, ): self.keyfile = keyfile self.certfile = certfile @@ -827,6 +831,7 @@ def __init__( self.ca_data = ca_data self.check_hostname = check_hostname self.min_version = min_version + self.ciphers = ciphers self.context: Optional[ssl.SSLContext] = None def get(self) -> ssl.SSLContext: @@ -840,6 +845,8 @@ def get(self) -> ssl.SSLContext: context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data) if self.min_version is not None: context.minimum_version = self.min_version + if self.ciphers is not None: + context.set_ciphers(self.ciphers) self.context = context return self.context diff --git a/redis/client.py b/redis/client.py index f95f4883bc..5a12f40e7b 100755 --- a/redis/client.py +++ b/redis/client.py @@ -198,6 +198,7 @@ def __init__( ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, ssl_min_version=None, + ssl_ciphers=None, max_connections=None, single_connection_client=False, health_check_interval=0, @@ -298,6 +299,7 @@ def __init__( "ssl_ocsp_context": ssl_ocsp_context, "ssl_ocsp_expected_cert": ssl_ocsp_expected_cert, "ssl_min_version": ssl_min_version, + "ssl_ciphers": ssl_ciphers, } ) connection_pool = ConnectionPool(**kwargs) diff --git a/redis/connection.py b/redis/connection.py index 55dc03f3cf..346ff3aa6b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -685,6 +685,7 @@ def __init__( ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, ssl_min_version=None, + ssl_ciphers=None, **kwargs, ): """Constructor @@ -704,6 +705,7 @@ def __init__( ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module. + ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information. Raises: RedisError @@ -737,6 +739,7 @@ def __init__( self.ssl_ocsp_context = ssl_ocsp_context self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert self.ssl_min_version = ssl_min_version + self.ssl_ciphers = ssl_ciphers super().__init__(**kwargs) def _connect(self): @@ -761,6 +764,8 @@ def _connect(self): ) if self.ssl_min_version is not None: context.minimum_version = self.ssl_min_version + if self.ssl_ciphers: + context.set_ciphers(self.ssl_ciphers) sslsock = context.wrap_socket(sock, server_hostname=self.host) if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: raise RedisError("cryptography is not installed.") diff --git a/setup.py b/setup.py index c5076e46bb..20546fcb45 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.0.3", + version="5.0.4", packages=find_packages( include=[ "redis", diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index e6cf2e4ce7..d2f165b5f5 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,6 +1,7 @@ import asyncio import binascii import datetime +import ssl import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union from urllib.parse import urlparse @@ -2951,6 +2952,59 @@ async def test_ssl_connection( async with await create_client(ssl=True, ssl_cert_reqs="none") as rc: assert await rc.ping() + @pytest.mark.parametrize( + "ssl_ciphers", + [ + "AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES128-GCM-SHA256", + ], + ) + async def test_ssl_connection_tls12_custom_ciphers( + self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client( + ssl=True, + ssl_cert_reqs="none", + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers=ssl_ciphers, + ) as rc: + assert await rc.ping() + + async def test_ssl_connection_tls12_custom_ciphers_invalid( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client( + ssl=True, + ssl_cert_reqs="none", + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers="foo:bar", + ) as rc: + with pytest.raises(RedisClusterException) as e: + assert await rc.ping() + assert "Redis Cluster cannot be connected" in str(e.value) + + @pytest.mark.parametrize( + "ssl_ciphers", + [ + "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256", + ], + ) + async def test_ssl_connection_tls13_custom_ciphers( + self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + # TLSv1.3 does not support changing the ciphers + async with await create_client( + ssl=True, + ssl_cert_reqs="none", + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers=ssl_ciphers, + ) as rc: + with pytest.raises(RedisClusterException) as e: + assert await rc.ping() + assert "Redis Cluster cannot be connected" in str(e.value) + async def test_validating_self_signed_certificate( self, create_client: Callable[..., Awaitable[RedisCluster]] ) -> None: diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 5497501258..6c902c2d05 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -50,6 +50,32 @@ async def test_uds_connect(uds_address): await _assert_connect(conn, path) +@pytest.mark.ssl +@pytest.mark.parametrize( + "ssl_ciphers", + [ + "AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES128-GCM-SHA256", + ], +) +async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers=ssl_ciphers, + ) + await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + await conn.disconnect() + + @pytest.mark.ssl @pytest.mark.parametrize( "ssl_min_version", diff --git a/tests/test_connect.py b/tests/test_connect.py index 0fdbb7005f..fcc1a05268 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -71,6 +71,31 @@ def test_tcp_ssl_connect(tcp_address, ssl_min_version): _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) +@pytest.mark.ssl +@pytest.mark.parametrize( + "ssl_ciphers", + [ + "AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES128-GCM-SHA256", + ], +) +def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers=ssl_ciphers, + ) + _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + @pytest.mark.ssl @pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") def test_tcp_ssl_version_mismatch(tcp_address): diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 465fdabb89..0e91750aa5 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -73,6 +73,68 @@ def test_validating_self_signed_string_certificate(self, request): ) assert r.ping() + @pytest.mark.parametrize( + "ssl_ciphers", + [ + "AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA", + "DHE-RSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305", + ], + ) + def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers): + ssl_url = request.config.option.redis_ssl_url + p = urlparse(ssl_url)[1].split(":") + r = redis.Redis( + host=p[0], + port=p[1], + ssl=True, + ssl_cert_reqs="none", + ssl_min_version=ssl.TLSVersion.TLSv1_3, + ssl_ciphers=ssl_ciphers, + ) + assert r.ping() + r.close() + + def test_ssl_connection_tls12_custom_ciphers_invalid(self, request): + ssl_url = request.config.option.redis_ssl_url + p = urlparse(ssl_url)[1].split(":") + r = redis.Redis( + host=p[0], + port=p[1], + ssl=True, + ssl_cert_reqs="none", + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers="foo:bar", + ) + with pytest.raises(RedisError) as e: + r.ping() + assert "No cipher can be selected" in str(e) + r.close() + + @pytest.mark.parametrize( + "ssl_ciphers", + [ + "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256", + ], + ) + def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers): + # TLSv1.3 does not support changing the ciphers + ssl_url = request.config.option.redis_ssl_url + p = urlparse(ssl_url)[1].split(":") + r = redis.Redis( + host=p[0], + port=p[1], + ssl=True, + ssl_cert_reqs="none", + ssl_min_version=ssl.TLSVersion.TLSv1_2, + ssl_ciphers=ssl_ciphers, + ) + with pytest.raises(RedisError) as e: + r.ping() + assert "No cipher can be selected" in str(e) + r.close() + def _create_oscp_conn(self, request): ssl_url = request.config.option.redis_ssl_url p = urlparse(ssl_url)[1].split(":")