diff --git a/dev_requirements.txt b/dev_requirements.txt index 8285b0456f..8ffb1e944f 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -15,4 +15,5 @@ pytest-cov>=4.0.0 vulture>=2.3.0 ujson>=4.2.0 wheel>=0.30.0 +urllib3<2 uvloop diff --git a/redis/client.py b/redis/client.py index c303dbde38..ef327b5922 100755 --- a/redis/client.py +++ b/redis/client.py @@ -833,6 +833,7 @@ class AbstractRedis: "QUIT": bool_ok, "STRALGO": parse_stralgo, "PUBSUB NUMSUB": parse_pubsub_numsub, + "PUBSUB SHARDNUMSUB": parse_pubsub_numsub, "RANDOMKEY": lambda r: r and r or None, "RESET": str_if_bytes, "SCAN": parse_scan, @@ -1440,8 +1441,8 @@ class PubSub: will be returned and it's safe to start listening again. """ - PUBLISH_MESSAGE_TYPES = ("message", "pmessage") - UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe") HEALTH_CHECK_MESSAGE = "redis-py-health-check" def __init__( @@ -1493,9 +1494,11 @@ def reset(self): self.connection.clear_connect_callbacks() self.connection_pool.release(self.connection) self.connection = None - self.channels = {} self.health_check_response_counter = 0 + self.channels = {} self.pending_unsubscribe_channels = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() @@ -1510,16 +1513,23 @@ def on_connect(self, connection): # before passing them to [p]subscribe. self.pending_unsubscribe_channels.clear() self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() if self.channels: - channels = {} - for k, v in self.channels.items(): - channels[self.encoder.decode(k, force=True)] = v + channels = { + self.encoder.decode(k, force=True): v for k, v in self.channels.items() + } self.subscribe(**channels) if self.patterns: - patterns = {} - for k, v in self.patterns.items(): - patterns[self.encoder.decode(k, force=True)] = v + patterns = { + self.encoder.decode(k, force=True): v for k, v in self.patterns.items() + } self.psubscribe(**patterns) + if self.shard_channels: + shard_channels = { + self.encoder.decode(k, force=True): v + for k, v in self.shard_channels.items() + } + self.ssubscribe(**shard_channels) @property def subscribed(self): @@ -1728,6 +1738,45 @@ def unsubscribe(self, *args): self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *args) + def ssubscribe(self, *args, target_node=None, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + return self.execute_command("SUNSUBSCRIBE", *args) + def listen(self): "Listen for messages on channels this client has been subscribed to" while self.subscribed: @@ -1762,6 +1811,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): return self.handle_message(response, ignore_subscribe_messages) return None + get_sharded_message = get_message + def ping(self, message=None): """ Ping the Redis server @@ -1809,12 +1860,17 @@ def handle_message(self, response, ignore_subscribe_messages=False): if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) else: channel = response[1] if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) - if not self.channels and not self.patterns: + if not self.channels and not self.patterns and not self.shard_channels: # There are no subscriptions anymore, set subscribed_event flag # to false self.subscribed_event.clear() @@ -1823,6 +1879,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): # if there's a message handler, invoke it if message_type == "pmessage": handler = self.patterns.get(message["pattern"], None) + elif message_type == "smessage": + handler = self.shard_channels.get(message["channel"], None) else: handler = self.channels.get(message["channel"], None) if handler: @@ -1843,6 +1901,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): for pattern, handler in self.patterns.items(): if handler is None: raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + for s_channel, handler in self.shard_channels.items(): + if handler is None: + raise PubSubError( + f"Shard Channel: '{s_channel}' has no handler registered" + ) thread = PubSubWorkerThread( self, sleep_time, daemon=daemon, exception_handler=exception_handler diff --git a/redis/cluster.py b/redis/cluster.py index 182ec6d733..d3956e45f5 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,6 +9,7 @@ from redis.backoff import default_backoff from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.commands.helpers import list_or_args from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( @@ -222,6 +223,8 @@ class AbstractRedisCluster: "PUBSUB CHANNELS", "PUBSUB NUMPAT", "PUBSUB NUMSUB", + "PUBSUB SHARDCHANNELS", + "PUBSUB SHARDNUMSUB", "PING", "INFO", "SHUTDOWN", @@ -346,11 +349,13 @@ class AbstractRedisCluster: } RESULT_CALLBACKS = dict_merge( - list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), + list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub), list_keys_to_dict( ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) ), - list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), + list_keys_to_dict( + ["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result + ), list_keys_to_dict( [ "PING", @@ -1625,6 +1630,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): else redis_cluster.get_redis_connection(self.node).connection_pool ) self.cluster = redis_cluster + self.node_pubsub_mapping = {} + self._pubsubs_generator = self._pubsubs_generator() super().__init__( **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder ) @@ -1678,9 +1685,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port): f"Node {host}:{port} doesn't exist in the cluster" ) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args): """ - Execute a publish/subscribe command. + Execute a subscribe/unsubscribe command. Taken code from redis-py and tweak to make it work within a cluster. """ @@ -1713,6 +1720,87 @@ def execute_command(self, *args, **kwargs): connection = self.connection self._execute(connection, connection.send_command, *args) + def _get_node_pubsub(self, node): + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + pubsub = node.redis_connection.pubsub() + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + def _sharded_message_generator(self): + for _ in range(len(self.node_pubsub_mapping)): + pubsub = next(self._pubsubs_generator) + message = pubsub.get_message() + if message is not None: + return message + return None + + def _pubsubs_generator(self): + while True: + for pubsub in self.node_pubsub_mapping.values(): + yield pubsub + + def get_sharded_message( + self, ignore_subscribe_messages=False, timeout=0.0, target_node=None + ): + if target_node: + message = self.node_pubsub_mapping[target_node.name].get_message( + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + else: + message = self._sharded_message_generator() + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if self.node_pubsub_mapping[node.name].subscribed is False: + self.node_pubsub_mapping.pop(node.name) + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None + return message + + def ssubscribe(self, *args, **kwargs): + if args: + args = list_or_args(args[0], args[1:]) + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + for s_channel, handler in s_channels.items(): + node = self.cluster.get_node_from_key(s_channel) + pubsub = self._get_node_pubsub(node) + if handler: + pubsub.ssubscribe(**{s_channel: handler}) + else: + pubsub.ssubscribe(s_channel) + self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) + if pubsub.subscribed and not self.subscribed: + self.subscribed_event.set() + self.health_check_response_counter = 0 + + def sunsubscribe(self, *args): + if args: + args = list_or_args(args[0], args[1:]) + else: + args = self.shard_channels + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + p = self._get_node_pubsub(node) + p.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + p.pending_unsubscribe_shard_channels + ) + def get_redis_connection(self): """ Get the Redis connection of the pubsub connected node. @@ -1720,6 +1808,15 @@ def get_redis_connection(self): if self.node is not None: return self.node.redis_connection + def disconnect(self): + """ + Disconnect the pubsub connection. + """ + if self.connection: + self.connection.disconnect() + for pubsub in self.node_pubsub_mapping.values(): + pubsub.connection.disconnect() + class ClusterPipeline(RedisCluster): """ diff --git a/redis/commands/core.py b/redis/commands/core.py index e2cabb85fa..6676ea8d71 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5103,6 +5103,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT """ return self.execute_command("PUBLISH", channel, message, **kwargs) + def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: + """ + Posts a message to the given shard channel. + Returns the number of clients that received the message + + For more information see https://redis.io/commands/spublish + """ + return self.execute_command("SPUBLISH", shard_channel, message) + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a list of channels that have at least one subscriber @@ -5111,6 +5120,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) + def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + """ + Return a list of shard_channels that have at least one subscriber + + For more information see https://redis.io/commands/pubsub-shardchannels + """ + return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs) + def pubsub_numpat(self, **kwargs) -> ResponseT: """ Returns the number of subscriptions to patterns @@ -5128,6 +5145,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) + def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT: + """ + Return a list of (shard_channel, number of subscribers) tuples + for each channel given in ``*args`` + + For more information see https://redis.io/commands/pubsub-shardnumsub + """ + return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs) + AsyncPubSubCommands = PubSubCommands diff --git a/redis/parsers/commands.py b/redis/parsers/commands.py index 2ea29a75ae..d3b4a99ed3 100644 --- a/redis/parsers/commands.py +++ b/redis/parsers/commands.py @@ -155,13 +155,13 @@ def _get_pubsub_keys(self, *args): # the second argument is a part of the command name, e.g. # ['PUBSUB', 'NUMSUB', 'foo']. pubsub_type = args[1].upper() - if pubsub_type in ["CHANNELS", "NUMSUB"]: + if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]: keys = args[2:] elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: # format example: # SUBSCRIBE channel [channel ...] keys = list(args[1:]) - elif command == "PUBLISH": + elif command in ["PUBLISH", "SPUBLISH"]: # format example: # PUBLISH channel message keys = [args[1]] diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8cd5cf6fba..412398f37b 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -675,18 +675,15 @@ async def loop(): nonlocal interrupt await pubsub.subscribe("foo") while True: - # print("loop") try: try: await pubsub.connect() await loop_step() - # print("succ") except redis.ConnectionError: await asyncio.sleep(0.1) except asyncio.CancelledError: # we use a cancel to interrupt the "listen" # when we perform a disconnect - # print("cancel", interrupt) if interrupt: interrupt = False else: @@ -919,7 +916,6 @@ async def loop(self): try: if self.state == 4: break - # print("state a ", self.state) got_msg = await self.get_message() assert got_msg if self.state in (1, 2): @@ -937,7 +933,6 @@ async def loop(self): async def loop_step_get_message(self): # get a single message via get_message message = await self.pubsub.get_message(timeout=0.1) - # print(message) if message is not None: await self.messages.put(message) return True diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index e1e4311511..2f6b4bad80 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -3,6 +3,7 @@ import socket import threading import time +from collections import defaultdict from unittest import mock from unittest.mock import patch @@ -20,13 +21,22 @@ ) -def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False): +def wait_for_message( + pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None, func=None +): now = time.time() timeout = now + timeout while now < timeout: - message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages - ) + if node: + message = pubsub.get_sharded_message( + ignore_subscribe_messages=ignore_subscribe_messages, target_node=node + ) + elif func: + message = func(ignore_subscribe_messages=ignore_subscribe_messages) + else: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) if message is not None: return message time.sleep(0.01) @@ -53,6 +63,15 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.unsubscribe, "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } + elif type == "shard_channel": + return { + "p": pubsub, + "sub_type": "ssubscribe", + "unsub_type": "sunsubscribe", + "sub_func": pubsub.ssubscribe, + "unsub_func": pubsub.sunsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } elif type == "pattern": return { "p": pubsub, @@ -93,6 +112,44 @@ def test_pattern_subscribe_unsubscribe(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe_cluster(self, r): + node_channels = defaultdict(int) + p = r.pubsub() + keys = { + "foo": r.get_node_from_key("foo"), + "bar": r.get_node_from_key("bar"), + "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), + } + + for key, node in keys.items(): + assert p.ssubscribe(key) is None + + # should be a message for each shard_channel we just subscribed to + for key, node in keys.items(): + node_channels[node.name] += 1 + assert wait_for_message(p, node=node) == make_message( + "ssubscribe", key, node_channels[node.name] + ) + + for key in keys.keys(): + assert p.sunsubscribe(key) is None + + # should be a message for each shard_channel we just unsubscribed + # from + for key, node in keys.items(): + node_channels[node.name] -= 1 + assert wait_for_message(p, node=node) == make_message( + "sunsubscribe", key, node_channels[node.name] + ) + def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -136,6 +193,12 @@ def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_resubscribe_on_reconnection(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_resubscribe_to_shard_channels_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_resubscribe_on_reconnection(**kwargs) + def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -192,38 +255,111 @@ def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribed_property(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribed_property(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels_cluster(self, r): + p = r.pubsub() + keys = ["foo", "bar", "uni" + chr(4456) + "code"] + nodes = [r.get_node_from_key(key) for key in keys] + assert p.subscribed is False + p.ssubscribe(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all shard_channels + p.sunsubscribe() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + p.ssubscribe(keys[0]) + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + + # unsubscribe again + p.sunsubscribe() + assert p.subscribed is True + # subscribe to another shard_channel before reading the unsubscribe response + p.ssubscribe(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert wait_for_message(p, node=nodes[1]) == make_message( + "ssubscribe", keys[1], 1 + ) + p.sunsubscribe() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert wait_for_message(p, node=nodes[1]) == make_message( + "sunsubscribe", keys[1], 0 + ) + # now we're finally unsubscribed + assert p.subscribed is False + + @skip_if_server_version_lt("7.0.0") def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - assert wait_for_message(p) is None + assert wait_for_message(p, func=get_func) is None assert p.subscribed is False + @skip_if_server_version_lt("7.0.0") def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - message = wait_for_message(p, ignore_subscribe_messages=True) + message = wait_for_message(p, ignore_subscribe_messages=True, func=get_func) assert message is None assert p.subscribed is False @@ -236,6 +372,12 @@ def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_resub(**kwargs) + def _test_sub_unsub_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -250,6 +392,26 @@ def _test_sub_unsub_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe(key) + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + def test_sub_unsub_all_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_sub_unsub_all_resub(**kwargs) @@ -258,6 +420,12 @@ def test_sub_unsub_all_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_all_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_all_resub(**kwargs) + def _test_sub_unsub_all_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -272,6 +440,26 @@ def _test_sub_unsub_all_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe() + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + class TestPubSubMessages: def setup_method(self, method): @@ -290,6 +478,32 @@ def test_published_message_to_channel(self, r): assert isinstance(message, dict) assert message == make_message("message", "foo", "test message") + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p) == make_message("ssubscribe", "foo", 1) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel_cluster(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", "foo", 1 + ) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p, func=p.get_sharded_message) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + def test_published_message_to_pattern(self, r): p = r.pubsub() p.subscribe("foo") @@ -321,6 +535,15 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(foo=self.message_handler) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == make_message("smessage", "foo", "test message") + @pytest.mark.onlynoncluster def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -342,6 +565,17 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") + @skip_if_server_version_lt("7.0.0") + def test_unicode_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + p.ssubscribe(**channels) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert r.spublish(channel, "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == make_message("smessage", channel, "test message") + @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html # #known-limitations-with-pubsub @@ -411,6 +645,19 @@ def test_pattern_subscribe_unsubscribe(self, r): p.punsubscribe(self.pattern) assert wait_for_message(p) == self.make_message("punsubscribe", self.pattern, 0) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.ssubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "ssubscribe", self.channel, 1 + ) + + p.sunsubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "sunsubscribe", self.channel, 0 + ) + def test_channel_publish(self, r): p = r.pubsub() p.subscribe(self.channel) @@ -430,6 +677,18 @@ def test_pattern_publish(self, r): "pmessage", self.channel, self.data, pattern=self.pattern ) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_publish(self, r): + p = r.pubsub() + p.ssubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "ssubscribe", self.channel, 1 + ) + r.spublish(self.channel, self.data) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "smessage", self.channel, self.data + ) + def test_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.subscribe(**{self.channel: self.message_handler}) @@ -468,6 +727,30 @@ def test_pattern_message_handler(self, r): "pmessage", self.channel, new_data, pattern=self.pattern ) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(**{self.channel: self.message_handler}) + assert wait_for_message(p, func=p.get_sharded_message) is None + r.spublish(self.channel, self.data) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == self.make_message("smessage", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + try: + # cluster mode + p.disconnect() + except AttributeError: + # standalone mode + p.connection.disconnect() + # should reconnect + assert wait_for_message(p, func=p.get_sharded_message) is None + new_data = self.data + "new data" + r.spublish(self.channel, new_data) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == self.make_message("smessage", self.channel, new_data) + def test_context_manager(self, r): with r.pubsub() as pubsub: pubsub.subscribe("foo") @@ -497,6 +780,38 @@ def test_pubsub_channels(self, r): expected = [b"bar", b"baz", b"foo", b"quux"] assert all([channel in r.pubsub_channels() for channel in expected]) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardchannels(self, r): + p = r.pubsub() + p.ssubscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert wait_for_message(p)["type"] == "ssubscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in r.pubsub_shardchannels() for channel in expected]) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardchannels_cluster(self, r): + channels = { + b"foo": r.get_node_from_key("foo"), + b"bar": r.get_node_from_key("bar"), + b"baz": r.get_node_from_key("baz"), + b"quux": r.get_node_from_key("quux"), + } + p = r.pubsub() + p.ssubscribe("foo", "bar", "baz", "quux") + for node in channels.values(): + assert wait_for_message(p, node=node)["type"] == "ssubscribe" + for channel, node in channels.items(): + assert channel in r.pubsub_shardchannels(target_nodes=node) + assert all( + [ + channel in r.pubsub_shardchannels(target_nodes="all") + for channel in channels.keys() + ] + ) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.0") def test_pubsub_numsub(self, r): @@ -523,6 +838,32 @@ def test_pubsub_numpat(self, r): assert wait_for_message(p)["type"] == "psubscribe" assert r.pubsub_numpat() == 3 + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardnumsub(self, r): + channels = { + b"foo": r.get_node_from_key("foo"), + b"bar": r.get_node_from_key("bar"), + b"baz": r.get_node_from_key("baz"), + } + p1 = r.pubsub() + p1.ssubscribe(*channels.keys()) + for node in channels.values(): + assert wait_for_message(p1, node=node)["type"] == "ssubscribe" + p2 = r.pubsub() + p2.ssubscribe("bar", "baz") + for i in range(2): + assert ( + wait_for_message(p2, func=p2.get_sharded_message)["type"] + == "ssubscribe" + ) + p3 = r.pubsub() + p3.ssubscribe("baz") + assert wait_for_message(p3, node=channels[b"baz"])["type"] == "ssubscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels + class TestPubSubPings: @skip_if_server_version_lt("3.0.0")