Skip to content

Commit 948a142

Browse files
authored
Merge branch 'master' into usr/aksinha334/redis-py#issue2598
2 parents 39bfa18 + a372ba4 commit 948a142

File tree

8 files changed

+104
-56
lines changed

8 files changed

+104
-56
lines changed

docs/examples/connection_examples.ipynb

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,23 @@
222222
"import json\n",
223223
"import cachetools.func\n",
224224
"\n",
225-
"sm_client = boto3.client('secretsmanager')\n",
226-
" \n",
227-
"def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
228-
" @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
229-
" def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
230-
" secret = sm_client.get_secret_value(secret_id, version_id)\n",
231-
" return json.loads(secret['SecretString'])\n",
232-
" creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n",
233-
" return creds['username'], creds['password']\n",
225+
"class SecretsManagerProvider(redis.CredentialProvider):\n",
226+
" def __init__(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
227+
" self.sm_client = boto3.client('secretsmanager')\n",
228+
" self.secret_id = secret_id\n",
229+
" self.version_id = version_id\n",
230+
" self.version_stage = version_stage\n",
234231
"\n",
235-
"secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
236-
"creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n",
232+
" def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
233+
" @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
234+
" def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
235+
" secret = self.sm_client.get_secret_value(secret_id, version_id)\n",
236+
" return json.loads(secret['SecretString'])\n",
237+
" creds = get_sm_user_credentials(self.secret_id, self.version_id, self.version_stage)\n",
238+
" return creds['username'], creds['password']\n",
239+
"\n",
240+
"my_secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
241+
"creds_provider = SecretsManagerProvider(secret_id=my_secret_id)\n",
237242
"user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
238243
"user_connection.ping()"
239244
]
@@ -266,19 +271,24 @@
266271
"import boto3\n",
267272
"import cachetools.func\n",
268273
"\n",
269-
"ec_client = boto3.client('elasticache')\n",
274+
"class ElastiCacheIAMProvider(redis.CredentialProvider):\n",
275+
" def __init__(self, user, endpoint, port=6379, region=\"us-east-1\"):\n",
276+
" self.ec_client = boto3.client('elasticache')\n",
277+
" self.user = user\n",
278+
" self.endpoint = endpoint\n",
279+
" self.port = port\n",
280+
" self.region = region\n",
270281
"\n",
271-
"def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n",
272-
" @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
273-
" def get_iam_auth_token(user, endpoint, port, region):\n",
274-
" return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
275-
" iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n",
276-
" return iam_auth_token\n",
282+
" def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
283+
" @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
284+
" def get_iam_auth_token(user, endpoint, port, region):\n",
285+
" return self.ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
286+
" iam_auth_token = get_iam_auth_token(self.endpoint, self.port, self.user, self.region)\n",
287+
" return iam_auth_token\n",
277288
"\n",
278289
"username = \"barshaul\"\n",
279290
"endpoint = \"test-001.use1.cache.amazonaws.com\"\n",
280-
"creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n",
281-
" endpoint=endpoint)\n",
291+
"creds_provider = ElastiCacheIAMProvider(user=username, endpoint=endpoint)\n",
282292
"user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n",
283293
"user_connection.ping()"
284294
]

docs/examples/ssl_connection_examples.ipynb

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,27 @@
5555
"url_connection.ping()"
5656
]
5757
},
58+
{
59+
"cell_type": "markdown",
60+
"id": "04e70233",
61+
"metadata": {},
62+
"source": [
63+
"## Connecting to a Redis instance using ConnectionPool"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": null,
69+
"id": "2903de26",
70+
"metadata": {},
71+
"outputs": [],
72+
"source": [
73+
"import redis\n",
74+
"redis_pool = redis.ConnectionPool(host=\"localhost\", port=6666, connection_class=redis.SSLConnection)\n",
75+
"ssl_connection = redis.StrictRedis(connection_pool=redis_pool) \n",
76+
"ssl_connection.ping()"
77+
]
78+
},
5879
{
5980
"cell_type": "markdown",
6081
"metadata": {},

redis/asyncio/connection.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,6 @@ async def _read_response(
267267
response: Any
268268
byte, response = raw[:1], raw[1:]
269269

270-
if byte not in (b"-", b"+", b":", b"$", b"*"):
271-
raise InvalidResponse(f"Protocol Error: {raw!r}")
272-
273270
# server returned an error
274271
if byte == b"-":
275272
response = response.decode("utf-8", errors="replace")
@@ -289,22 +286,24 @@ async def _read_response(
289286
pass
290287
# int value
291288
elif byte == b":":
292-
response = int(response)
289+
return int(response)
293290
# bulk response
291+
elif byte == b"$" and response == b"-1":
292+
return None
294293
elif byte == b"$":
295-
length = int(response)
296-
if length == -1:
297-
return None
298-
response = await self._read(length)
294+
response = await self._read(int(response))
299295
# multi-bulk response
296+
elif byte == b"*" and response == b"-1":
297+
return None
300298
elif byte == b"*":
301-
length = int(response)
302-
if length == -1:
303-
return None
304299
response = [
305-
(await self._read_response(disable_decoding)) for _ in range(length)
300+
(await self._read_response(disable_decoding))
301+
for _ in range(int(response)) # noqa
306302
]
307-
if isinstance(response, bytes) and disable_decoding is False:
303+
else:
304+
raise InvalidResponse(f"Protocol Error: {raw!r}")
305+
306+
if disable_decoding is False:
308307
response = self.encoder.decode(response)
309308
return response
310309

redis/commands/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3357,10 +3357,15 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]:
33573357

33583358
def smismember(
33593359
self, name: str, values: List, *args: List
3360-
) -> Union[Awaitable[List[bool]], List[bool]]:
3360+
) -> Union[
3361+
Awaitable[List[Union[Literal[0], Literal[1]]]],
3362+
List[Union[Literal[0], Literal[1]]],
3363+
]:
33613364
"""
33623365
Return whether each value in ``values`` is a member of the set ``name``
3363-
as a list of ``bool`` in the order of ``values``
3366+
as a list of ``int`` in the order of ``values``:
3367+
- 1 if the value is a member of the set.
3368+
- 0 if the value is not a member of the set or if key does not exist.
33643369
33653370
For more information see https://redis.io/commands/smismember
33663371
"""

redis/commands/json/commands.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def arrindex(
3131
name: str,
3232
path: str,
3333
scalar: int,
34-
start: Optional[int] = 0,
35-
stop: Optional[int] = -1,
34+
start: Optional[int] = None,
35+
stop: Optional[int] = None,
3636
) -> List[Union[int, None]]:
3737
"""
3838
Return the index of ``scalar`` in the JSON array under ``path`` at key
@@ -43,9 +43,13 @@ def arrindex(
4343
4444
For more information see `JSON.ARRINDEX <https://redis.io/commands/json.arrindex>`_.
4545
""" # noqa
46-
return self.execute_command(
47-
"JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop
48-
)
46+
pieces = [name, str(path), self._encode(scalar)]
47+
if start is not None:
48+
pieces.append(start)
49+
if stop is not None:
50+
pieces.append(stop)
51+
52+
return self.execute_command("JSON.ARRINDEX", *pieces)
4953

5054
def arrinsert(
5155
self, name: str, path: str, index: int, *args: List[JsonType]

redis/connection.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,6 @@ def _read_response(self, disable_decoding=False):
358358

359359
byte, response = raw[:1], raw[1:]
360360

361-
if byte not in (b"-", b"+", b":", b"$", b"*"):
362-
raise InvalidResponse(f"Protocol Error: {raw!r}")
363-
364361
# server returned an error
365362
if byte == b"-":
366363
response = response.decode("utf-8", errors="replace")
@@ -379,23 +376,24 @@ def _read_response(self, disable_decoding=False):
379376
pass
380377
# int value
381378
elif byte == b":":
382-
response = int(response)
379+
return int(response)
383380
# bulk response
381+
elif byte == b"$" and response == b"-1":
382+
return None
384383
elif byte == b"$":
385-
length = int(response)
386-
if length == -1:
387-
return None
388-
response = self._buffer.read(length)
384+
response = self._buffer.read(int(response))
389385
# multi-bulk response
386+
elif byte == b"*" and response == b"-1":
387+
return None
390388
elif byte == b"*":
391-
length = int(response)
392-
if length == -1:
393-
return None
394389
response = [
395390
self._read_response(disable_decoding=disable_decoding)
396-
for i in range(length)
391+
for i in range(int(response))
397392
]
398-
if isinstance(response, bytes) and disable_decoding is False:
393+
else:
394+
raise InvalidResponse(f"Protocol Error: {raw!r}")
395+
396+
if disable_decoding is False:
399397
response = self.encoder.decode(response)
400398
return response
401399

tests/test_asyncio/test_json.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,15 @@ async def test_arrappend(modclient: redis.Redis):
145145

146146
@pytest.mark.redismod
147147
async def test_arrindex(modclient: redis.Redis):
148-
await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
149-
assert 1 == await modclient.json().arrindex("arr", Path.root_path(), 1)
150-
assert -1 == await modclient.json().arrindex("arr", Path.root_path(), 1, 2)
148+
r_path = Path.root_path()
149+
await modclient.json().set("arr", r_path, [0, 1, 2, 3, 4])
150+
assert 1 == await modclient.json().arrindex("arr", r_path, 1)
151+
assert -1 == await modclient.json().arrindex("arr", r_path, 1, 2)
152+
assert 4 == await modclient.json().arrindex("arr", r_path, 4)
153+
assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0)
154+
assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=5000)
155+
assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=-1)
156+
assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=1, stop=3)
151157

152158

153159
@pytest.mark.redismod

tests/test_json.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ def test_arrindex(client):
166166
client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
167167
assert 1 == client.json().arrindex("arr", Path.root_path(), 1)
168168
assert -1 == client.json().arrindex("arr", Path.root_path(), 1, 2)
169+
assert 4 == client.json().arrindex("arr", Path.root_path(), 4)
170+
assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0)
171+
assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=5000)
172+
assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=-1)
173+
assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=1, stop=3)
169174

170175

171176
@pytest.mark.redismod

0 commit comments

Comments
 (0)