Skip to content

Commit 772079f

Browse files
KMilhanchayim
andauthored
Enable AsyncIO cluster mode lock (#2446)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
1 parent 1cdba63 commit 772079f

File tree

5 files changed

+90
-7
lines changed

5 files changed

+90
-7
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
* ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225)
2525
* Remove compatibility code for old versions of Hiredis, drop Packaging dependency
2626
* The `deprecated` library is no longer a dependency
27+
* Enable Lock for asyncio cluster mode
2728

2829
* 4.1.3 (Feb 8, 2022)
2930
* Fix flushdb and flushall (#1926)

redis/asyncio/cluster.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
SSLConnection,
2525
parse_url,
2626
)
27+
from redis.asyncio.lock import Lock
2728
from redis.asyncio.parser import CommandsParser
2829
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
2930
from redis.cluster import (
@@ -764,6 +765,72 @@ def pipeline(
764765

765766
return ClusterPipeline(self)
766767

768+
def lock(
769+
self,
770+
name: KeyT,
771+
timeout: Optional[float] = None,
772+
sleep: float = 0.1,
773+
blocking_timeout: Optional[float] = None,
774+
lock_class: Optional[Type[Lock]] = None,
775+
thread_local: bool = True,
776+
) -> Lock:
777+
"""
778+
Return a new Lock object using key ``name`` that mimics
779+
the behavior of threading.Lock.
780+
781+
If specified, ``timeout`` indicates a maximum life for the lock.
782+
By default, it will remain locked until release() is called.
783+
784+
``sleep`` indicates the amount of time to sleep per loop iteration
785+
when the lock is in blocking mode and another client is currently
786+
holding the lock.
787+
788+
``blocking_timeout`` indicates the maximum amount of time in seconds to
789+
spend trying to acquire the lock. A value of ``None`` indicates
790+
continue trying forever. ``blocking_timeout`` can be specified as a
791+
float or integer, both representing the number of seconds to wait.
792+
793+
``lock_class`` forces the specified lock implementation. Note that as
794+
of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
795+
a Lua-based lock). So, it's unlikely you'll need this parameter, unless
796+
you have created your own custom lock class.
797+
798+
``thread_local`` indicates whether the lock token is placed in
799+
thread-local storage. By default, the token is placed in thread local
800+
storage so that a thread only sees its token, not a token set by
801+
another thread. Consider the following timeline:
802+
803+
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
804+
thread-1 sets the token to "abc"
805+
time: 1, thread-2 blocks trying to acquire `my-lock` using the
806+
Lock instance.
807+
time: 5, thread-1 has not yet completed. redis expires the lock
808+
key.
809+
time: 5, thread-2 acquired `my-lock` now that it's available.
810+
thread-2 sets the token to "xyz"
811+
time: 6, thread-1 finishes its work and calls release(). if the
812+
token is *not* stored in thread local storage, then
813+
thread-1 would see the token value as "xyz" and would be
814+
able to successfully release the thread-2's lock.
815+
816+
In some use cases it's necessary to disable thread local storage. For
817+
example, if you have code where one thread acquires a lock and passes
818+
that lock instance to a worker thread to release later. If thread
819+
local storage isn't disabled in this case, the worker thread won't see
820+
the token set by the thread that acquired the lock. Our assumption
821+
is that these cases aren't common and as such default to using
822+
thread local storage."""
823+
if lock_class is None:
824+
lock_class = Lock
825+
return lock_class(
826+
self,
827+
name,
828+
timeout=timeout,
829+
sleep=sleep,
830+
blocking_timeout=blocking_timeout,
831+
thread_local=thread_local,
832+
)
833+
767834

768835
class ClusterNode:
769836
"""

redis/asyncio/lock.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from redis.exceptions import LockError, LockNotOwnedError
88

99
if TYPE_CHECKING:
10-
from redis.asyncio import Redis
10+
from redis.asyncio import Redis, RedisCluster
1111

1212

1313
class Lock:
@@ -77,7 +77,7 @@ class Lock:
7777

7878
def __init__(
7979
self,
80-
redis: "Redis",
80+
redis: Union["Redis", "RedisCluster"],
8181
name: Union[str, bytes, memoryview],
8282
timeout: Optional[float] = None,
8383
sleep: float = 0.1,
@@ -189,7 +189,11 @@ async def acquire(
189189
if token is None:
190190
token = uuid.uuid1().hex.encode()
191191
else:
192-
encoder = self.redis.connection_pool.get_encoder()
192+
try:
193+
encoder = self.redis.connection_pool.get_encoder()
194+
except AttributeError:
195+
# Cluster
196+
encoder = self.redis.get_encoder()
193197
token = encoder.encode(token)
194198
if blocking is None:
195199
blocking = self.blocking
@@ -233,7 +237,11 @@ async def owned(self) -> bool:
233237
# need to always compare bytes to bytes
234238
# TODO: this can be simplified when the context manager is finished
235239
if stored_token and not isinstance(stored_token, bytes):
236-
encoder = self.redis.connection_pool.get_encoder()
240+
try:
241+
encoder = self.redis.connection_pool.get_encoder()
242+
except AttributeError:
243+
# Cluster
244+
encoder = self.redis.get_encoder()
237245
stored_token = encoder.encode(stored_token)
238246
return self.local.token is not None and stored_token == self.local.token
239247

redis/commands/core.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4930,7 +4930,11 @@ def __init__(self, registered_client: "Redis", script: ScriptTextT):
49304930
if isinstance(script, str):
49314931
# We need the encoding from the client in order to generate an
49324932
# accurate byte representation of the script
4933-
encoder = registered_client.connection_pool.get_encoder()
4933+
try:
4934+
encoder = registered_client.connection_pool.get_encoder()
4935+
except AttributeError:
4936+
# Cluster
4937+
encoder = registered_client.get_encoder()
49344938
script = encoder.encode(script)
49354939
self.sha = hashlib.sha1(script).hexdigest()
49364940

@@ -4975,7 +4979,11 @@ def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT):
49754979
if isinstance(script, str):
49764980
# We need the encoding from the client in order to generate an
49774981
# accurate byte representation of the script
4978-
encoder = registered_client.connection_pool.get_encoder()
4982+
try:
4983+
encoder = registered_client.connection_pool.get_encoder()
4984+
except AttributeError:
4985+
# Cluster
4986+
encoder = registered_client.get_encoder()
49794987
script = encoder.encode(script)
49804988
self.sha = hashlib.sha1(script).hexdigest()
49814989

tests/test_asyncio/test_lock.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from redis.exceptions import LockError, LockNotOwnedError
88

99

10-
@pytest.mark.onlynoncluster
1110
class TestLock:
1211
@pytest_asyncio.fixture()
1312
async def r_decoded(self, create_redis):

0 commit comments

Comments
 (0)