Skip to content

Commit 326f351

Browse files
authored
RESP3 tests (#2780)
* fix command response in resp3 * linters * acl_log & acl_getuser * client_info * test_commands and test_asyncio/test_commands * fix test_command_parser * fix asyncio/test_connection/test_invalid_response * linters * all the tests * push handler sharded pubsub * Use assert_resp_response wherever possible * fix test_xreadgroup * fix cluster_zdiffstore and cluster_zinter * fix review comments * fix review comments * linters
1 parent e8fc092 commit 326f351

19 files changed

+812
-705
lines changed

redis/asyncio/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,13 +671,13 @@ def __init__(
671671
if self.encoder is None:
672672
self.encoder = self.connection_pool.get_encoder()
673673
if self.encoder.decode_responses:
674-
self.health_check_response: Iterable[Union[str, bytes]] = [
675-
"pong",
674+
self.health_check_response = [
675+
["pong", self.HEALTH_CHECK_MESSAGE],
676676
self.HEALTH_CHECK_MESSAGE,
677677
]
678678
else:
679679
self.health_check_response = [
680-
b"pong",
680+
[b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
681681
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
682682
]
683683
if self.push_handler_func is None:
@@ -807,7 +807,7 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
807807
conn, conn.read_response, timeout=read_timeout, push_request=True
808808
)
809809

810-
if conn.health_check_interval and response == self.health_check_response:
810+
if conn.health_check_interval and response in self.health_check_response:
811811
# ignore the health check message as user might not expect it
812812
return None
813813
return response

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def __init__(
319319
kwargs.update({"retry": self.retry})
320320

321321
kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy()
322+
if kwargs.get("protocol") in ["3", 3]:
323+
kwargs["response_callbacks"].update(self.__class__.RESP3_RESPONSE_CALLBACKS)
322324
self.connection_kwargs = kwargs
323325

324326
if startup_nodes:

redis/asyncio/connection.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,16 +333,36 @@ def _error_message(self, exception):
333333
async def on_connect(self) -> None:
334334
"""Initialize the connection, authenticate and select a database"""
335335
self._parser.on_connect(self)
336+
parser = self._parser
336337

338+
auth_args = None
337339
# if credential provider or username and/or password are set, authenticate
338340
if self.credential_provider or (self.username or self.password):
339341
cred_provider = (
340342
self.credential_provider
341343
or UsernamePasswordCredentialProvider(self.username, self.password)
342344
)
343345
auth_args = cred_provider.get_credentials()
344-
# avoid checking health here -- PING will fail if we try
345-
# to check the health prior to the AUTH
346+
# if resp version is specified and we have auth args,
347+
# we need to send them via HELLO
348+
if auth_args and self.protocol not in [2, "2"]:
349+
if isinstance(self._parser, _AsyncRESP2Parser):
350+
self.set_parser(_AsyncRESP3Parser)
351+
# update cluster exception classes
352+
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
353+
self._parser.on_connect(self)
354+
if len(auth_args) == 1:
355+
auth_args = ["default", auth_args[0]]
356+
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
357+
response = await self.read_response()
358+
if response.get(b"proto") not in [2, "2"] and response.get("proto") not in [
359+
2,
360+
"2",
361+
]:
362+
raise ConnectionError("Invalid RESP version")
363+
# avoid checking health here -- PING will fail if we try
364+
# to check the health prior to the AUTH
365+
elif auth_args:
346366
await self.send_command("AUTH", *auth_args, check_health=False)
347367

348368
try:
@@ -359,9 +379,11 @@ async def on_connect(self) -> None:
359379
raise AuthenticationError("Invalid Username or Password")
360380

361381
# if resp version is specified, switch to it
362-
if self.protocol != 2:
382+
elif self.protocol != 2:
363383
if isinstance(self._parser, _AsyncRESP2Parser):
364384
self.set_parser(_AsyncRESP3Parser)
385+
# update cluster exception classes
386+
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
365387
self._parser.on_connect(self)
366388
await self.send_command("HELLO", self.protocol)
367389
response = await self.read_response()

redis/client.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,15 @@ def parse_xinfo_stream(response, **options):
331331
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
332332
else:
333333
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
334-
data["groups"] = [
335-
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
336-
]
334+
if isinstance(data["groups"][0], list):
335+
data["groups"] = [
336+
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
337+
]
338+
else:
339+
data["groups"] = [
340+
{str_if_bytes(k): v for k, v in group.items()}
341+
for group in data["groups"]
342+
]
337343
return data
338344

339345

@@ -581,14 +587,15 @@ def parse_command_resp3(response, **options):
581587
cmd_name = str_if_bytes(command[0])
582588
cmd_dict["name"] = cmd_name
583589
cmd_dict["arity"] = command[1]
584-
cmd_dict["flags"] = command[2]
590+
cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]}
585591
cmd_dict["first_key_pos"] = command[3]
586592
cmd_dict["last_key_pos"] = command[4]
587593
cmd_dict["step_count"] = command[5]
588594
cmd_dict["acl_categories"] = command[6]
589-
cmd_dict["tips"] = command[7]
590-
cmd_dict["key_specifications"] = command[8]
591-
cmd_dict["subcommands"] = command[9]
595+
if len(command) > 7:
596+
cmd_dict["tips"] = command[7]
597+
cmd_dict["key_specifications"] = command[8]
598+
cmd_dict["subcommands"] = command[9]
592599

593600
commands[cmd_name] = cmd_dict
594601
return commands
@@ -626,17 +633,20 @@ def parse_acl_getuser(response, **options):
626633
if data["channels"] == [""]:
627634
data["channels"] = []
628635
if "selectors" in data:
629-
data["selectors"] = [
630-
list(map(str_if_bytes, selector)) for selector in data["selectors"]
631-
]
636+
if data["selectors"] != [] and isinstance(data["selectors"][0], list):
637+
data["selectors"] = [
638+
list(map(str_if_bytes, selector)) for selector in data["selectors"]
639+
]
640+
elif data["selectors"] != []:
641+
data["selectors"] = [
642+
{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()}
643+
for selector in data["selectors"]
644+
]
632645

633646
# split 'commands' into separate 'categories' and 'commands' lists
634647
commands, categories = [], []
635648
for command in data["commands"].split(" "):
636-
if "@" in command:
637-
categories.append(command)
638-
else:
639-
commands.append(command)
649+
categories.append(command) if "@" in command else commands.append(command)
640650

641651
data["commands"] = commands
642652
data["categories"] = categories

redis/cluster.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from redis.parsers import CommandsParser, Encoder
3434
from redis.retry import Retry
3535
from redis.utils import (
36+
HIREDIS_AVAILABLE,
3637
dict_merge,
3738
list_keys_to_dict,
3839
merge_result,
@@ -1608,7 +1609,15 @@ class ClusterPubSub(PubSub):
16081609
https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html
16091610
"""
16101611

1611-
def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
1612+
def __init__(
1613+
self,
1614+
redis_cluster,
1615+
node=None,
1616+
host=None,
1617+
port=None,
1618+
push_handler_func=None,
1619+
**kwargs,
1620+
):
16121621
"""
16131622
When a pubsub instance is created without specifying a node, a single
16141623
node will be transparently chosen for the pubsub connection on the
@@ -1633,7 +1642,10 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
16331642
self.node_pubsub_mapping = {}
16341643
self._pubsubs_generator = self._pubsubs_generator()
16351644
super().__init__(
1636-
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
1645+
connection_pool=connection_pool,
1646+
encoder=redis_cluster.encoder,
1647+
push_handler_func=push_handler_func,
1648+
**kwargs,
16371649
)
16381650

16391651
def set_pubsub_node(self, cluster, node=None, host=None, port=None):
@@ -1717,14 +1729,18 @@ def execute_command(self, *args):
17171729
# register a callback that re-subscribes to any channels we
17181730
# were listening to when we were disconnected
17191731
self.connection.register_connect_callback(self.on_connect)
1732+
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
1733+
self.connection._parser.set_push_handler(self.push_handler_func)
17201734
connection = self.connection
17211735
self._execute(connection, connection.send_command, *args)
17221736

17231737
def _get_node_pubsub(self, node):
17241738
try:
17251739
return self.node_pubsub_mapping[node.name]
17261740
except KeyError:
1727-
pubsub = node.redis_connection.pubsub()
1741+
pubsub = node.redis_connection.pubsub(
1742+
push_handler_func=self.push_handler_func
1743+
)
17281744
self.node_pubsub_mapping[node.name] = pubsub
17291745
return pubsub
17301746

redis/connection.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,33 @@ def _error_message(self, exception):
276276
def on_connect(self):
277277
"Initialize the connection, authenticate and select a database"
278278
self._parser.on_connect(self)
279+
parser = self._parser
279280

281+
auth_args = None
280282
# if credential provider or username and/or password are set, authenticate
281283
if self.credential_provider or (self.username or self.password):
282284
cred_provider = (
283285
self.credential_provider
284286
or UsernamePasswordCredentialProvider(self.username, self.password)
285287
)
286288
auth_args = cred_provider.get_credentials()
289+
# if resp version is specified and we have auth args,
290+
# we need to send them via HELLO
291+
if auth_args and self.protocol != 2:
292+
if isinstance(self._parser, _RESP2Parser):
293+
self.set_parser(_RESP3Parser)
294+
# update cluster exception classes
295+
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
296+
self._parser.on_connect(self)
297+
if len(auth_args) == 1:
298+
auth_args = ["default", auth_args[0]]
299+
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
300+
response = self.read_response()
301+
if response.get(b"proto") != int(self.protocol) and response.get(
302+
"proto"
303+
) != int(self.protocol):
304+
raise ConnectionError("Invalid RESP version")
305+
elif auth_args:
287306
# avoid checking health here -- PING will fail if we try
288307
# to check the health prior to the AUTH
289308
self.send_command("AUTH", *auth_args, check_health=False)
@@ -302,9 +321,11 @@ def on_connect(self):
302321
raise AuthenticationError("Invalid Username or Password")
303322

304323
# if resp version is specified, switch to it
305-
if self.protocol != 2:
324+
elif self.protocol != 2:
306325
if isinstance(self._parser, _RESP2Parser):
307326
self.set_parser(_RESP3Parser)
327+
# update cluster exception classes
328+
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
308329
self._parser.on_connect(self)
309330
self.send_command("HELLO", self.protocol)
310331
response = self.read_response()

redis/parsers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import BaseParser
1+
from .base import BaseParser, _AsyncRESPBase
22
from .commands import AsyncCommandsParser, CommandsParser
33
from .encoders import Encoder
44
from .hiredis import _AsyncHiredisParser, _HiredisParser
@@ -8,6 +8,7 @@
88
__all__ = [
99
"AsyncCommandsParser",
1010
"_AsyncHiredisParser",
11+
"_AsyncRESPBase",
1112
"_AsyncRESP2Parser",
1213
"_AsyncRESP3Parser",
1314
"CommandsParser",

redis/parsers/resp3.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,12 @@ def _read_response(self, disable_decoding=False, push_request=False):
6969
# bool value
7070
elif byte == b"#":
7171
return response == b"t"
72-
# bulk response and verbatim strings
73-
elif byte in (b"$", b"="):
72+
# bulk response
73+
elif byte == b"$":
7474
response = self._buffer.read(int(response))
75+
# verbatim string response
76+
elif byte == b"=":
77+
response = self._buffer.read(int(response))[4:]
7578
# array response
7679
elif byte == b"*":
7780
response = [
@@ -195,9 +198,12 @@ async def _read_response(
195198
# bool value
196199
elif byte == b"#":
197200
return response == b"t"
198-
# bulk response and verbatim strings
199-
elif byte in (b"$", b"="):
201+
# bulk response
202+
elif byte == b"$":
200203
response = await self._read(int(response))
204+
# verbatim string response
205+
elif byte == b"=":
206+
response = (await self._read(int(response)))[4:]
201207
# array response
202208
elif byte == b"*":
203209
response = [

tests/conftest.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,31 @@ def wait_for_command(client, monitor, command, key=None):
475475

476476

477477
def is_resp2_connection(r):
478-
if isinstance(r, redis.Redis):
478+
if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis):
479479
protocol = r.connection_pool.connection_kwargs.get("protocol")
480-
elif isinstance(r, redis.RedisCluster):
480+
elif isinstance(r, redis.cluster.AbstractRedisCluster):
481481
protocol = r.nodes_manager.connection_kwargs.get("protocol")
482482
return protocol in ["2", 2, None]
483+
484+
485+
def get_protocol_version(r):
486+
if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis):
487+
return r.connection_pool.connection_kwargs.get("protocol")
488+
elif isinstance(r, redis.cluster.AbstractRedisCluster):
489+
return r.nodes_manager.connection_kwargs.get("protocol")
490+
491+
492+
def assert_resp_response(r, response, resp2_expected, resp3_expected):
493+
protocol = get_protocol_version(r)
494+
if protocol in [2, "2", None]:
495+
assert response == resp2_expected
496+
else:
497+
assert response == resp3_expected
498+
499+
500+
def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
501+
protocol = get_protocol_version(r)
502+
if protocol in [2, "2", None]:
503+
assert response in resp2_expected
504+
else:
505+
assert response in resp3_expected

tests/test_asyncio/conftest.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -236,29 +236,6 @@ async def wait_for_command(
236236
return None
237237

238238

239-
def get_protocol_version(r):
240-
if isinstance(r, redis.Redis):
241-
return r.connection_pool.connection_kwargs.get("protocol")
242-
elif isinstance(r, redis.RedisCluster):
243-
return r.nodes_manager.connection_kwargs.get("protocol")
244-
245-
246-
def assert_resp_response(r, response, resp2_expected, resp3_expected):
247-
protocol = get_protocol_version(r)
248-
if protocol in [2, "2", None]:
249-
assert response == resp2_expected
250-
else:
251-
assert response == resp3_expected
252-
253-
254-
def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
255-
protocol = get_protocol_version(r)
256-
if protocol in [2, "2", None]:
257-
assert response in resp2_expected
258-
else:
259-
assert response in resp3_expected
260-
261-
262239
# python 3.6 doesn't have the asynccontextmanager decorator. Provide it here.
263240
class AsyncContextManager:
264241
def __init__(self, async_generator):

0 commit comments

Comments
 (0)