From fc9fbc5d3a72f61acf197f6d9e40243155002ebe Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 30 Apr 2023 16:45:02 +0300 Subject: [PATCH 01/10] start cleaning --- redis/client.py | 306 ++++++++++++++++++++++++------------------------ 1 file changed, 155 insertions(+), 151 deletions(-) diff --git a/redis/client.py b/redis/client.py index 71048f548f..8b5be625d7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -696,161 +696,165 @@ def parse_set_result(response, **options): class AbstractRedis: RESPONSE_CALLBACKS = { - **string_keys_to_dict( - "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " - "HEXISTS HMSET MOVE MSETNX PERSIST " - "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", - bool, - ), - **string_keys_to_dict( - "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " - "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " - "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " - "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " - "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", - int, - ), - **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), - **string_keys_to_dict( - # these return OK, or int if redis-server is >=1.3.4 - "LPUSH RPUSH", - lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", - ), - **string_keys_to_dict("SORT", sort_return_tuples), - **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), - **string_keys_to_dict( - "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " - "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", - bool_ok, - ), - **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), - **string_keys_to_dict( - "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - ), - **string_keys_to_dict( - "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " - "ZREVRANGE ZREVRANGEBYSCORE", - zset_score_pairs, - ), - **string_keys_to_dict( - "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None - ), - **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), - **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), - **string_keys_to_dict("XREAD XREADGROUP", parse_xread), - **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - "ACL CAT": lambda r: list(map(str_if_bytes, r)), - "ACL DELUSER": int, - "ACL GENPASS": str_if_bytes, - "ACL GETUSER": parse_acl_getuser, - "ACL HELP": lambda r: list(map(str_if_bytes, r)), - "ACL LIST": lambda r: list(map(str_if_bytes, r)), - "ACL LOAD": bool_ok, - "ACL LOG": parse_acl_log, - "ACL SAVE": bool_ok, - "ACL SETUSER": bool_ok, - "ACL USERS": lambda r: list(map(str_if_bytes, r)), - "ACL WHOAMI": str_if_bytes, - "CLIENT GETNAME": str_if_bytes, - "CLIENT ID": int, - "CLIENT KILL": parse_client_kill, - "CLIENT LIST": parse_client_list, - "CLIENT INFO": parse_client_info, - "CLIENT SETNAME": bool_ok, - "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, - "CLIENT PAUSE": bool_ok, - "CLIENT GETREDIR": int, - "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), - "CLUSTER ADDSLOTS": bool_ok, - "CLUSTER ADDSLOTSRANGE": bool_ok, - "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), - "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), + **string_keys_to_dict("EXISTS", int), "CLUSTER DELSLOTS": bool_ok, - "CLUSTER DELSLOTSRANGE": bool_ok, - "CLUSTER FAILOVER": bool_ok, - "CLUSTER FORGET": bool_ok, - "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), - "CLUSTER INFO": parse_cluster_info, - "CLUSTER KEYSLOT": lambda x: int(x), - "CLUSTER MEET": bool_ok, - "CLUSTER NODES": 
parse_cluster_nodes, - "CLUSTER REPLICAS": parse_cluster_nodes, - "CLUSTER REPLICATE": bool_ok, - "CLUSTER RESET": bool_ok, - "CLUSTER SAVECONFIG": bool_ok, - "CLUSTER SET-CONFIG-EPOCH": bool_ok, - "CLUSTER SETSLOT": bool_ok, - "CLUSTER SLAVES": parse_cluster_nodes, + "CLUSTER ADDSLOTS": bool_ok, "COMMAND": parse_command, - "COMMAND COUNT": int, - "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), - "CONFIG GET": parse_config_get, - "CONFIG RESETSTAT": bool_ok, - "CONFIG SET": bool_ok, - "DEBUG OBJECT": parse_debug_object, - "FUNCTION DELETE": bool_ok, - "FUNCTION FLUSH": bool_ok, - "FUNCTION RESTORE": bool_ok, - "GEOHASH": lambda r: list(map(str_if_bytes, r)), - "GEOPOS": lambda r: list( - map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) - ), - "GEOSEARCH": parse_geosearch_generic, - "GEORADIUS": parse_geosearch_generic, - "GEORADIUSBYMEMBER": parse_geosearch_generic, - "HGETALL": lambda r: r and pairs_to_dict(r) or {}, - "HSCAN": parse_hscan, "INFO": parse_info, - "LASTSAVE": timestamp_to_datetime, - "MEMORY PURGE": bool_ok, - "MEMORY STATS": parse_memory_stats, - "MEMORY USAGE": int_or_none, - "MODULE LOAD": parse_module_result, - "MODULE UNLOAD": parse_module_result, - "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], - "OBJECT": parse_object, - "PING": lambda r: str_if_bytes(r) == "PONG", - "QUIT": bool_ok, - "STRALGO": parse_stralgo, - "PUBSUB NUMSUB": parse_pubsub_numsub, - "RANDOMKEY": lambda r: r and r or None, - "RESET": str_if_bytes, - "SCAN": parse_scan, - "SCRIPT EXISTS": lambda r: list(map(bool, r)), - "SCRIPT FLUSH": bool_ok, - "SCRIPT KILL": bool_ok, - "SCRIPT LOAD": str_if_bytes, - "SENTINEL CKQUORUM": bool_ok, - "SENTINEL FAILOVER": bool_ok, - "SENTINEL FLUSHCONFIG": bool_ok, - "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, - "SENTINEL MASTER": parse_sentinel_master, - "SENTINEL MASTERS": parse_sentinel_masters, - "SENTINEL MONITOR": bool_ok, - "SENTINEL RESET": bool_ok, - "SENTINEL REMOVE": bool_ok, - "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, - "SENTINEL SET": bool_ok, - "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, "SET": parse_set_result, - "SLOWLOG GET": parse_slowlog_get, - "SLOWLOG LEN": int, - "SLOWLOG RESET": bool_ok, - "SSCAN": parse_scan, - "TIME": lambda x: (int(x[0]), int(x[1])), - "XCLAIM": parse_xclaim, - "XAUTOCLAIM": parse_xautoclaim, - "XGROUP CREATE": bool_ok, - "XGROUP DELCONSUMER": int, - "XGROUP DESTROY": bool, - "XGROUP SETID": bool_ok, - "XINFO CONSUMERS": parse_list_of_dicts, - "XINFO GROUPS": parse_list_of_dicts, - "XINFO STREAM": parse_xinfo_stream, - "XPENDING": parse_xpending, - "ZADD": parse_zadd, - "ZSCAN": parse_zscan, - "ZMSCORE": parse_zmscore, + } + + RESP2_RESPONSE_CALLBACKS = { + # **string_keys_to_dict( + # "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " + # "HEXISTS HMSET MOVE MSETNX PERSIST " + # "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", + # bool, + # ), + # **string_keys_to_dict( + # "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " + # "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " + # "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " + # "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " + # "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", + # int, + # ), + # **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), + # **string_keys_to_dict( + # # these return OK, or int if redis-server is >=1.3.4 + # "LPUSH RPUSH", + # lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", + # ), + # 
**string_keys_to_dict("SORT", sort_return_tuples), + # **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), + # **string_keys_to_dict( + # "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " + # "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", + # bool_ok, + # ), + # **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + # **string_keys_to_dict( + # "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + # ), + # **string_keys_to_dict( + # "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " + # "ZREVRANGE ZREVRANGEBYSCORE", + # zset_score_pairs, + # ), + # **string_keys_to_dict( + # "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None + # ), + # **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), + # **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), + # **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + # **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + # "ACL CAT": lambda r: list(map(str_if_bytes, r)), + # "ACL DELUSER": int, + # "ACL GENPASS": str_if_bytes, + # "ACL GETUSER": parse_acl_getuser, + # "ACL HELP": lambda r: list(map(str_if_bytes, r)), + # "ACL LIST": lambda r: list(map(str_if_bytes, r)), + # "ACL LOAD": bool_ok, + # "ACL LOG": parse_acl_log, + # "ACL SAVE": bool_ok, + # "ACL SETUSER": bool_ok, + # "ACL USERS": lambda r: list(map(str_if_bytes, r)), + # "ACL WHOAMI": str_if_bytes, + # "CLIENT GETNAME": str_if_bytes, + # "CLIENT ID": int, + # "CLIENT KILL": parse_client_kill, + # "CLIENT LIST": parse_client_list, + # "CLIENT INFO": parse_client_info, + # "CLIENT SETNAME": bool_ok, + # "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, + # "CLIENT PAUSE": bool_ok, + # "CLIENT GETREDIR": int, + # "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + # "CLUSTER ADDSLOTSRANGE": bool_ok, + # "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), + # "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), + # "CLUSTER DELSLOTSRANGE": bool_ok, + # "CLUSTER FAILOVER": bool_ok, + # "CLUSTER FORGET": bool_ok, + # "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), + # "CLUSTER INFO": parse_cluster_info, + # "CLUSTER KEYSLOT": lambda x: int(x), + # "CLUSTER MEET": bool_ok, + # "CLUSTER NODES": parse_cluster_nodes, + # "CLUSTER REPLICAS": parse_cluster_nodes, + # "CLUSTER REPLICATE": bool_ok, + # "CLUSTER RESET": bool_ok, + # "CLUSTER SAVECONFIG": bool_ok, + # "CLUSTER SET-CONFIG-EPOCH": bool_ok, + # "CLUSTER SETSLOT": bool_ok, + # "CLUSTER SLAVES": parse_cluster_nodes, + # "COMMAND COUNT": int, + # "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + # "CONFIG GET": parse_config_get, + # "CONFIG RESETSTAT": bool_ok, + # "CONFIG SET": bool_ok, + # "DEBUG OBJECT": parse_debug_object, + # "FUNCTION DELETE": bool_ok, + # "FUNCTION FLUSH": bool_ok, + # "FUNCTION RESTORE": bool_ok, + # "GEOHASH": lambda r: list(map(str_if_bytes, r)), + # "GEOPOS": lambda r: list( + # map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + # ), + # "GEOSEARCH": parse_geosearch_generic, + # "GEORADIUS": parse_geosearch_generic, + # "GEORADIUSBYMEMBER": parse_geosearch_generic, + # "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + # "HSCAN": parse_hscan, + # "LASTSAVE": timestamp_to_datetime, + # "MEMORY PURGE": bool_ok, + # "MEMORY STATS": parse_memory_stats, + # "MEMORY USAGE": int_or_none, + # "MODULE LOAD": parse_module_result, + # "MODULE UNLOAD": parse_module_result, + # "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + # 
"OBJECT": parse_object, + # "PING": lambda r: str_if_bytes(r) == "PONG", + # "QUIT": bool_ok, + # "STRALGO": parse_stralgo, + # "PUBSUB NUMSUB": parse_pubsub_numsub, + # "RANDOMKEY": lambda r: r and r or None, + # "RESET": str_if_bytes, + # "SCAN": parse_scan, + # "SCRIPT EXISTS": lambda r: list(map(bool, r)), + # "SCRIPT FLUSH": bool_ok, + # "SCRIPT KILL": bool_ok, + # "SCRIPT LOAD": str_if_bytes, + # "SENTINEL CKQUORUM": bool_ok, + # "SENTINEL FAILOVER": bool_ok, + # "SENTINEL FLUSHCONFIG": bool_ok, + # "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + # "SENTINEL MASTER": parse_sentinel_master, + # "SENTINEL MASTERS": parse_sentinel_masters, + # "SENTINEL MONITOR": bool_ok, + # "SENTINEL RESET": bool_ok, + # "SENTINEL REMOVE": bool_ok, + # "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + # "SENTINEL SET": bool_ok, + # "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + # "SLOWLOG GET": parse_slowlog_get, + # "SLOWLOG LEN": int, + # "SLOWLOG RESET": bool_ok, + # "SSCAN": parse_scan, + # "TIME": lambda x: (int(x[0]), int(x[1])), + # "XCLAIM": parse_xclaim, + # "XAUTOCLAIM": parse_xautoclaim, + # "XGROUP CREATE": bool_ok, + # "XGROUP DELCONSUMER": int, + # "XGROUP DESTROY": bool, + # "XGROUP SETID": bool_ok, + # "XINFO CONSUMERS": parse_list_of_dicts, + # "XINFO GROUPS": parse_list_of_dicts, + # "XINFO STREAM": parse_xinfo_stream, + # "XPENDING": parse_xpending, + # "ZADD": parse_zadd, + # "ZSCAN": parse_zscan, + # "ZMSCORE": parse_zmscore, } RESP3_RESPONSE_CALLBACKS = { From 0faf16b54e8084dd4b297f3eec25651e1151b062 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 1 Jun 2023 15:03:41 +0300 Subject: [PATCH 02/10] clean sone callbacks --- redis/client.py | 109 +++++++++++++++++++++-------------------- tests/test_commands.py | 97 ++++++++++++++++++------------------ 2 files changed, 105 insertions(+), 101 deletions(-) diff --git a/redis/client.py b/redis/client.py index 8b5be625d7..491946e2bd 100755 --- a/redis/client.py +++ b/redis/client.py @@ -696,17 +696,70 @@ def parse_set_result(response, **options): class AbstractRedis: RESPONSE_CALLBACKS = { + **string_keys_to_dict("EXPIRE EXPIREAT PEXPIRE PEXPIREAT", bool), **string_keys_to_dict("EXISTS", int), + **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), + **string_keys_to_dict("READONLY", bool_ok), "CLUSTER DELSLOTS": bool_ok, "CLUSTER ADDSLOTS": bool_ok, "COMMAND": parse_command, "INFO": parse_info, "SET": parse_set_result, + "CLIENT ID": int, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT INFO": parse_client_info, + "CLIENT SETNAME": bool_ok, + "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + "LASTSAVE": timestamp_to_datetime, + "RESET": str_if_bytes, + "SLOWLOG GET": parse_slowlog_get, + "TIME": lambda x: (int(x[0]), int(x[1])), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + "SCAN": parse_scan, + "CLIENT GETNAME": str_if_bytes, + "SSCAN": parse_scan, + "ACL LOG": parse_acl_log, + "ACL WHOAMI": str_if_bytes, + "ACL GENPASS": str_if_bytes, + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "HSCAN": parse_hscan, + "ZSCAN": parse_zscan, + **string_keys_to_dict( + "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), + "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER KEYSLOT": lambda x: int(x), + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": 
parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + **string_keys_to_dict("GEODIST", float_or_none), + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "GEOSEARCH": parse_geosearch_generic, + "GEORADIUS": parse_geosearch_generic, + "GEORADIUSBYMEMBER": parse_geosearch_generic, + "XAUTOCLAIM": parse_xautoclaim, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + **string_keys_to_dict("SORT", sort_return_tuples), } RESP2_RESPONSE_CALLBACKS = { # **string_keys_to_dict( - # "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " + # "AUTH COPY " # "HEXISTS HMSET MOVE MSETNX PERSIST " # "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", # bool, @@ -719,20 +772,17 @@ class AbstractRedis: # "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", # int, # ), - # **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), # **string_keys_to_dict( # # these return OK, or int if redis-server is >=1.3.4 # "LPUSH RPUSH", # lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", # ), - # **string_keys_to_dict("SORT", sort_return_tuples), - # **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), + # **string_keys_to_dict("ZSCORE ZINCRBY", float_or_none), # **string_keys_to_dict( - # "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " + # "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READWRITE " # "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", # bool_ok, # ), - # **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), # **string_keys_to_dict( # "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() # ), @@ -741,55 +791,26 @@ class AbstractRedis: # "ZREVRANGE ZREVRANGEBYSCORE", # zset_score_pairs, # ), - # **string_keys_to_dict( - # "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None - # ), # **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), # **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), - # **string_keys_to_dict("XREAD XREADGROUP", parse_xread), # **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - # "ACL CAT": lambda r: list(map(str_if_bytes, r)), # "ACL DELUSER": int, - # "ACL GENPASS": str_if_bytes, # "ACL GETUSER": parse_acl_getuser, # "ACL HELP": lambda r: list(map(str_if_bytes, r)), # "ACL LIST": lambda r: list(map(str_if_bytes, r)), # "ACL LOAD": bool_ok, - # "ACL LOG": parse_acl_log, # "ACL SAVE": bool_ok, # "ACL SETUSER": bool_ok, # "ACL USERS": lambda r: list(map(str_if_bytes, r)), - # "ACL WHOAMI": str_if_bytes, - # "CLIENT GETNAME": str_if_bytes, - # "CLIENT ID": int, - # "CLIENT KILL": parse_client_kill, - # "CLIENT LIST": parse_client_list, - # "CLIENT INFO": parse_client_info, - # "CLIENT SETNAME": bool_ok, # "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, # "CLIENT PAUSE": bool_ok, # "CLIENT GETREDIR": int, - # "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), # "CLUSTER ADDSLOTSRANGE": bool_ok, - # "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), - # "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), # "CLUSTER DELSLOTSRANGE": bool_ok, - # "CLUSTER FAILOVER": bool_ok, - # "CLUSTER FORGET": bool_ok, # "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), - # "CLUSTER INFO": 
parse_cluster_info, - # "CLUSTER KEYSLOT": lambda x: int(x), - # "CLUSTER MEET": bool_ok, - # "CLUSTER NODES": parse_cluster_nodes, # "CLUSTER REPLICAS": parse_cluster_nodes, - # "CLUSTER REPLICATE": bool_ok, - # "CLUSTER RESET": bool_ok, - # "CLUSTER SAVECONFIG": bool_ok, # "CLUSTER SET-CONFIG-EPOCH": bool_ok, - # "CLUSTER SETSLOT": bool_ok, - # "CLUSTER SLAVES": parse_cluster_nodes, # "COMMAND COUNT": int, - # "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), # "CONFIG GET": parse_config_get, # "CONFIG RESETSTAT": bool_ok, # "CONFIG SET": bool_ok, @@ -797,16 +818,7 @@ class AbstractRedis: # "FUNCTION DELETE": bool_ok, # "FUNCTION FLUSH": bool_ok, # "FUNCTION RESTORE": bool_ok, - # "GEOHASH": lambda r: list(map(str_if_bytes, r)), - # "GEOPOS": lambda r: list( - # map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) - # ), - # "GEOSEARCH": parse_geosearch_generic, - # "GEORADIUS": parse_geosearch_generic, - # "GEORADIUSBYMEMBER": parse_geosearch_generic, # "HGETALL": lambda r: r and pairs_to_dict(r) or {}, - # "HSCAN": parse_hscan, - # "LASTSAVE": timestamp_to_datetime, # "MEMORY PURGE": bool_ok, # "MEMORY STATS": parse_memory_stats, # "MEMORY USAGE": int_or_none, @@ -819,8 +831,6 @@ class AbstractRedis: # "STRALGO": parse_stralgo, # "PUBSUB NUMSUB": parse_pubsub_numsub, # "RANDOMKEY": lambda r: r and r or None, - # "RESET": str_if_bytes, - # "SCAN": parse_scan, # "SCRIPT EXISTS": lambda r: list(map(bool, r)), # "SCRIPT FLUSH": bool_ok, # "SCRIPT KILL": bool_ok, @@ -837,23 +847,16 @@ class AbstractRedis: # "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, # "SENTINEL SET": bool_ok, # "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, - # "SLOWLOG GET": parse_slowlog_get, # "SLOWLOG LEN": int, # "SLOWLOG RESET": bool_ok, - # "SSCAN": parse_scan, - # "TIME": lambda x: (int(x[0]), int(x[1])), # "XCLAIM": parse_xclaim, - # "XAUTOCLAIM": parse_xautoclaim, # "XGROUP CREATE": bool_ok, # "XGROUP DELCONSUMER": int, # "XGROUP DESTROY": bool, # "XGROUP SETID": bool_ok, # "XINFO CONSUMERS": parse_list_of_dicts, # "XINFO GROUPS": parse_list_of_dicts, - # "XINFO STREAM": parse_xinfo_stream, - # "XPENDING": parse_xpending, # "ZADD": parse_zadd, - # "ZSCAN": parse_zscan, # "ZMSCORE": parse_zmscore, } diff --git a/tests/test_commands.py b/tests/test_commands.py index 1af69c83c0..1da4bd3a9f 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -68,54 +68,54 @@ def test_case_insensitive_command_names(self, r): class TestRedisCommands: @skip_if_redis_enterprise() - def test_auth(self, r, request): - # sending an AUTH command before setting a user/password on the - # server should return an AuthenticationError - with pytest.raises(exceptions.AuthenticationError): - r.auth("some_password") - - with pytest.raises(exceptions.AuthenticationError): - r.auth("some_password", "some_user") - - # first, test for default user (`username` is supposed to be optional) - default_username = "default" - temp_pass = "temp_pass" - r.config_set("requirepass", temp_pass) - - assert r.auth(temp_pass, default_username) is True - assert r.auth(temp_pass) is True - - # test for other users - username = "redis-py-auth" - - def teardown(): - try: - # this is needed because after an AuthenticationError the connection - # is closed, and if we send an AUTH command a new connection is - # created, but in this case we'd get an "Authentication required" - # error when switching to the db 9 because we're not authenticated yet - # setting the password on the connection itself triggers 
the - # authentication in the connection's `on_connect` method - r.connection.password = temp_pass - except AttributeError: - # connection field is not set in Redis Cluster, but that's ok - # because the problem discussed above does not apply to Redis Cluster - pass - - r.auth(temp_pass) - r.config_set("requirepass", "") - r.acl_deluser(username) - - request.addfinalizer(teardown) - - assert r.acl_setuser( - username, enabled=True, passwords=["+strong_password"], commands=["+acl"] - ) - - assert r.auth(username=username, password="strong_password") is True - - with pytest.raises(exceptions.AuthenticationError): - r.auth(username=username, password="wrong_password") + # def test_auth(self, r, request): + # # sending an AUTH command before setting a user/password on the + # # server should return an AuthenticationError + # with pytest.raises(exceptions.AuthenticationError): + # r.auth("some_password") + + # with pytest.raises(exceptions.AuthenticationError): + # r.auth("some_password", "some_user") + + # # first, test for default user (`username` is supposed to be optional) + # default_username = "default" + # temp_pass = "temp_pass" + # r.config_set("requirepass", temp_pass) + + # assert r.auth(temp_pass, default_username) is True + # assert r.auth(temp_pass) is True + + # # test for other users + # username = "redis-py-auth" + + # def teardown(): + # try: + # # this is needed because after an AuthenticationError the connection + # # is closed, and if we send an AUTH command a new connection is + # # created, but in this case we'd get an "Authentication required" + # # error when switching to the db 9 because we're not authenticated yet + # # setting the password on the connection itself triggers the + # # authentication in the connection's `on_connect` method + # r.connection.password = temp_pass + # except AttributeError: + # # connection field is not set in Redis Cluster, but that's ok + # # because the problem discussed above does not apply to Redis Cluster + # pass + + # r.auth(temp_pass) + # r.config_set("requirepass", "") + # r.acl_deluser(username) + + # request.addfinalizer(teardown) + + # assert r.acl_setuser( + # username, enabled=True, passwords=["+strong_password"], commands=["+acl"] + # ) + + # assert r.auth(username=username, password="strong_password") is True + + # with pytest.raises(exceptions.AuthenticationError): + # r.auth(username=username, password="wrong_password") def test_command_on_invalid_key_type(self, r): r.lpush("a", "1") @@ -4522,6 +4522,7 @@ def test_xreadgroup(self, r): ] # xread starting at 0 returns both messages + breakpoint() res = r.xreadgroup(group, consumer, streams={stream: ">"}) if is_resp2_connection(r): assert res == [[strem_name, expected_entries]] From c08d0acd3c58df1be7089582be60e0849e0d422a Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 14 Jun 2023 11:10:37 +0300 Subject: [PATCH 03/10] response callbacks --- redis/asyncio/connection.py | 2 +- redis/client.py | 69 +++++++++++++++++-------------------- redis/connection.py | 4 +-- tests/conftest.py | 2 +- tests/test_commands.py | 21 ++++++----- 5 files changed, 45 insertions(+), 53 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index b51e4fd8ce..364e811a59 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -379,7 +379,7 @@ async def on_connect(self) -> None: raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - elif self.protocol != 2: + elif self.protocol not in [2, "2"]: if 
isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) # update cluster exception classes diff --git a/redis/client.py b/redis/client.py index ddb2ec6a9b..d4bdfbd46a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -726,7 +726,7 @@ def parse_set_result(response, **options): class AbstractRedis: RESPONSE_CALLBACKS = { - **string_keys_to_dict("EXPIRE EXPIREAT PEXPIRE PEXPIREAT", bool), + **string_keys_to_dict("EXPIRE EXPIREAT PEXPIRE PEXPIREAT AUTH", bool), **string_keys_to_dict("EXISTS", int), **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), **string_keys_to_dict("READONLY", bool_ok), @@ -785,17 +785,42 @@ class AbstractRedis: **string_keys_to_dict("XREAD XREADGROUP", parse_xread), "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), **string_keys_to_dict("SORT", sort_return_tuples), + "PING": lambda r: str_if_bytes(r) == "PONG", + "ACL SETUSER": bool_ok, + "PUBSUB NUMSUB": parse_pubsub_numsub, + "SCRIPT FLUSH": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "ACL GETUSER": parse_acl_getuser, + "CONFIG SET": bool_ok, + **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), + "XCLAIM": parse_xclaim, + } RESP2_RESPONSE_CALLBACKS = { + "CONFIG GET": parse_config_get, + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " + "ZREVRANGE ZREVRANGEBYSCORE", + zset_score_pairs, + ), + **string_keys_to_dict("ZSCORE ZINCRBY", float_or_none), + "ZADD": parse_zadd, + "ZMSCORE": parse_zmscore, + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "MEMORY STATS": parse_memory_stats, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + # **string_keys_to_dict( - # "AUTH COPY " + # "COPY " # "HEXISTS HMSET MOVE MSETNX PERSIST " # "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", # bool, # ), # **string_keys_to_dict( - # "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " # "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " # "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " # "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " @@ -803,68 +828,39 @@ class AbstractRedis: # int, # ), # **string_keys_to_dict( - # # these return OK, or int if redis-server is >=1.3.4 - # "LPUSH RPUSH", - # lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", - # ), - # **string_keys_to_dict("ZSCORE ZINCRBY", float_or_none), - # **string_keys_to_dict( # "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READWRITE " # "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", # bool_ok, # ), - # **string_keys_to_dict( - # "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - # ), - # **string_keys_to_dict( - # "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " - # "ZREVRANGE ZREVRANGEBYSCORE", - # zset_score_pairs, - # ), # **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), - # **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), # **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - # "ACL DELUSER": int, - # "ACL GETUSER": parse_acl_getuser, # "ACL HELP": lambda r: list(map(str_if_bytes, r)), # "ACL LIST": lambda r: list(map(str_if_bytes, r)), # "ACL LOAD": bool_ok, # "ACL SAVE": bool_ok, - # "ACL SETUSER": bool_ok, # "ACL USERS": lambda r: list(map(str_if_bytes, r)), # "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, # "CLIENT PAUSE": bool_ok, - # "CLIENT GETREDIR": int, # "CLUSTER ADDSLOTSRANGE": bool_ok, # "CLUSTER DELSLOTSRANGE": bool_ok, # "CLUSTER 
GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), # "CLUSTER REPLICAS": parse_cluster_nodes, # "CLUSTER SET-CONFIG-EPOCH": bool_ok, - # "COMMAND COUNT": int, - # "CONFIG GET": parse_config_get, # "CONFIG RESETSTAT": bool_ok, - # "CONFIG SET": bool_ok, # "DEBUG OBJECT": parse_debug_object, # "FUNCTION DELETE": bool_ok, # "FUNCTION FLUSH": bool_ok, # "FUNCTION RESTORE": bool_ok, - # "HGETALL": lambda r: r and pairs_to_dict(r) or {}, # "MEMORY PURGE": bool_ok, - # "MEMORY STATS": parse_memory_stats, # "MEMORY USAGE": int_or_none, # "MODULE LOAD": parse_module_result, # "MODULE UNLOAD": parse_module_result, - # "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], # "OBJECT": parse_object, - # "PING": lambda r: str_if_bytes(r) == "PONG", # "QUIT": bool_ok, # "STRALGO": parse_stralgo, - # "PUBSUB NUMSUB": parse_pubsub_numsub, # "RANDOMKEY": lambda r: r and r or None, # "SCRIPT EXISTS": lambda r: list(map(bool, r)), - # "SCRIPT FLUSH": bool_ok, # "SCRIPT KILL": bool_ok, - # "SCRIPT LOAD": str_if_bytes, # "SENTINEL CKQUORUM": bool_ok, # "SENTINEL FAILOVER": bool_ok, # "SENTINEL FLUSHCONFIG": bool_ok, @@ -877,17 +873,12 @@ class AbstractRedis: # "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, # "SENTINEL SET": bool_ok, # "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, - # "SLOWLOG LEN": int, # "SLOWLOG RESET": bool_ok, - # "XCLAIM": parse_xclaim, # "XGROUP CREATE": bool_ok, - # "XGROUP DELCONSUMER": int, # "XGROUP DESTROY": bool, # "XGROUP SETID": bool_ok, # "XINFO CONSUMERS": parse_list_of_dicts, - # "XINFO GROUPS": parse_list_of_dicts, - # "ZADD": parse_zadd, - # "ZMSCORE": parse_zmscore, + "XINFO GROUPS": parse_list_of_dicts, } RESP3_RESPONSE_CALLBACKS = { @@ -1128,6 +1119,8 @@ def __init__( if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + else: + self.response_callbacks.update(self.__class__.RESP2_RESPONSE_CALLBACKS) def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" diff --git a/redis/connection.py b/redis/connection.py index ee3bece11c..fef31b72d6 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -288,7 +288,7 @@ def on_connect(self): auth_args = cred_provider.get_credentials() # if resp version is specified and we have auth args, # we need to send them via HELLO - if auth_args and self.protocol != 2: + if auth_args and self.protocol not in [2, "2"]: if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) # update cluster exception classes @@ -321,7 +321,7 @@ def on_connect(self): raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - elif self.protocol != 2: + elif self.protocol not in [2, "2"]: if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) # update cluster exception classes diff --git a/tests/conftest.py b/tests/conftest.py index 6454750353..187be1189e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ REDIS_INFO = {} default_redis_url = "redis://localhost:6379/0" -default_redismod_url = "redis://localhost:36379" +default_redismod_url = "redis://localhost:6379" default_redis_unstable_url = "redis://localhost:6378" # default ssl client ignores verification for the purpose of testing diff --git a/tests/test_commands.py b/tests/test_commands.py index 97fbb34925..0bbdcb27db 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -68,7 +68,7 @@ def test_response_callbacks(self, r): assert r["a"] == "static" 
def test_case_insensitive_command_names(self, r): - assert r.response_callbacks["del"] == r.response_callbacks["DEL"] + assert r.response_callbacks["ping"] == r.response_callbacks["PING"] class TestRedisCommands: @@ -152,9 +152,8 @@ def teardown(): r.acl_setuser(username, keys=["*"], commands=["+set"]) assert r.acl_dryrun(username, "set", "key", "value") == b"OK" - assert r.acl_dryrun(username, "get", "key").startswith( - b"This user has no permissions to run the" - ) + no_permissions_message = b"user has no permissions to run the" + assert no_permissions_message in r.acl_dryrun(username, "get", "key") @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -232,12 +231,12 @@ def teardown(): enabled=True, reset=True, passwords=["+pass1", "+pass2"], - categories=["+set", "+@hash", "-geo"], + categories=["+set", "+@hash", "-@geo"], commands=["+get", "+mget", "-hset"], keys=["cache:*", "objects:*"], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] @@ -315,7 +314,7 @@ def teardown(): selectors=[("+set", "%W~app*")], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] @@ -325,7 +324,7 @@ def teardown(): assert_resp_response( r, acl["selectors"], - ["commands", "-@all +set", "keys", "%W~app*", "channels", ""], + [["commands", "-@all +set", "keys", "%W~app*", "channels", ""]], [{"commands": "-@all +set", "keys": "%W~app*", "channels": ""}], ) @@ -4214,7 +4213,7 @@ def test_xgroup_setid(self, r): ] assert r.xinfo_groups(stream) == expected - @skip_if_server_version_lt("5.0.0") + @skip_if_server_version_lt("7.2.0") def test_xinfo_consumers(self, r): stream = "stream" group = "group" @@ -4230,8 +4229,8 @@ def test_xinfo_consumers(self, r): info = r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ - {"name": consumer1.encode(), "pending": 1}, - {"name": consumer2.encode(), "pending": 2}, + {"name": consumer1.encode(), "pending": 1, "inactive": 2}, + {"name": consumer2.encode(), "pending": 2, "inactive": 2}, ] # we can't determine the idle time, so just make sure it's an int From 9553dc0eef40c82fbd3416f13430868eeef188dc Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 14 Jun 2023 11:13:16 +0300 Subject: [PATCH 04/10] modules --- redis/commands/bf/__init__.py | 55 +- redis/commands/bf/commands.py | 3 - redis/commands/bf/info.py | 33 ++ redis/commands/json/__init__.py | 43 +- redis/commands/search/__init__.py | 23 +- redis/commands/search/commands.py | 173 +++--- redis/commands/timeseries/__init__.py | 29 +- redis/commands/timeseries/info.py | 9 + tests/test_bloom.py | 85 ++- tests/test_json.py | 301 +++++++---- tests/test_search.py | 726 +++++++++++++++++--------- tests/test_timeseries.py | 573 +++++++++++++------- 12 files changed, 1403 insertions(+), 650 deletions(-) diff --git a/redis/commands/bf/__init__.py b/redis/commands/bf/__init__.py index 4da060e995..63d866353e 100644 --- a/redis/commands/bf/__init__.py +++ b/redis/commands/bf/__init__.py @@ -97,13 +97,22 @@ def __init__(self, client, **kwargs): # CMS_INCRBY: spaceHolder, # CMS_QUERY: spaceHolder, CMS_MERGE: bool_ok, + } + + RESP2_MODULE_CALLBACKS = { 
CMS_INFO: CMSInfo, } + RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CMSCommands self.execute_command = client.execute_command + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + for k, v in MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -114,18 +123,27 @@ def __init__(self, client, **kwargs): # Set the module commands' callbacks MODULE_CALLBACKS = { TOPK_RESERVE: bool_ok, - TOPK_ADD: parse_to_list, - TOPK_INCRBY: parse_to_list, # TOPK_QUERY: spaceHolder, # TOPK_COUNT: spaceHolder, + } + + RESP2_MODULE_CALLBACKS = { + TOPK_ADD: parse_to_list, + TOPK_INCRBY: parse_to_list, TOPK_LIST: parse_to_list, TOPK_INFO: TopKInfo, } + RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TOPKCommands self.execute_command = client.execute_command + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + for k, v in MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -145,13 +163,22 @@ def __init__(self, client, **kwargs): # CF_COUNT: spaceHolder, # CF_SCANDUMP: spaceHolder, # CF_LOADCHUNK: spaceHolder, + } + + RESP2_MODULE_CALLBACKS = { CF_INFO: CFInfo, } + RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CFCommands self.execute_command = client.execute_command + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + for k, v in MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -165,22 +192,29 @@ def __init__(self, client, **kwargs): # TDIGEST_RESET: bool_ok, # TDIGEST_ADD: spaceHolder, # TDIGEST_MERGE: spaceHolder, + } + + RESP2_MODULE_CALLBACKS = { + TDIGEST_BYRANK: parse_to_list, + TDIGEST_BYREVRANK: parse_to_list, TDIGEST_CDF: parse_to_list, TDIGEST_QUANTILE: parse_to_list, TDIGEST_MIN: float, TDIGEST_MAX: float, TDIGEST_TRIMMED_MEAN: float, TDIGEST_INFO: TDigestInfo, - TDIGEST_RANK: parse_to_list, - TDIGEST_REVRANK: parse_to_list, - TDIGEST_BYRANK: parse_to_list, - TDIGEST_BYREVRANK: parse_to_list, } + RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TDigestCommands self.execute_command = client.execute_command + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + for k, v in MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -199,12 +233,21 @@ def __init__(self, client, **kwargs): # BF_SCANDUMP: spaceHolder, # BF_LOADCHUNK: spaceHolder, # BF_CARD: spaceHolder, + } + + RESP2_MODULE_CALLBACKS = { BF_INFO: BFInfo, } + RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = BFCommands self.execute_command = client.execute_command + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + for k, v in MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) diff --git a/redis/commands/bf/commands.py b/redis/commands/bf/commands.py index c45523c99b..447f844508 100644 --- a/redis/commands/bf/commands.py +++ b/redis/commands/bf/commands.py @@ -60,7 +60,6 @@ class BFCommands: 
"""Bloom Filter commands.""" - # region Bloom Filter Functions def create(self, key, errorRate, capacity, expansion=None, noScale=None): """ Create a new Bloom Filter `key` with desired probability of false positives @@ -178,7 +177,6 @@ def card(self, key): class CFCommands: """Cuckoo Filter commands.""" - # region Cuckoo Filter Functions def create( self, key, capacity, expansion=None, bucket_size=None, max_iterations=None ): @@ -488,7 +486,6 @@ def byrevrank(self, key, rank, *ranks): class CMSCommands: """Count-Min Sketch Commands""" - # region Count-Min Sketch Functions def initbydim(self, key, width, depth): """ Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user. diff --git a/redis/commands/bf/info.py b/redis/commands/bf/info.py index c526e6ca4c..e1f0208609 100644 --- a/redis/commands/bf/info.py +++ b/redis/commands/bf/info.py @@ -16,6 +16,15 @@ def __init__(self, args): self.insertedNum = response["Number of items inserted"] self.expansionRate = response["Expansion rate"] + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + class CFInfo(object): size = None @@ -38,6 +47,15 @@ def __init__(self, args): self.expansionRate = response["Expansion rate"] self.maxIteration = response["Max iterations"] + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + class CMSInfo(object): width = None @@ -50,6 +68,9 @@ def __init__(self, args): self.depth = response["depth"] self.count = response["count"] + def __getitem__(self, item): + return getattr(self, item) + class TopKInfo(object): k = None @@ -64,6 +85,9 @@ def __init__(self, args): self.depth = response["depth"] self.decay = response["decay"] + def __getitem__(self, item): + return getattr(self, item) + class TDigestInfo(object): compression = None @@ -85,3 +109,12 @@ def __init__(self, args): self.unmerged_weight = response["Unmerged weight"] self.total_compressions = response["Total compressions"] self.memory_usage = response["Memory usage"] + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index 7d55023e1e..a9e91fe74d 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -32,33 +32,50 @@ def __init__( """ # Set the module commands' callbacks self.MODULE_CALLBACKS = { - "JSON.CLEAR": int, - "JSON.DEL": int, - "JSON.FORGET": int, - "JSON.GET": self._decode, + "JSON.ARRPOP": self._decode, "JSON.MGET": bulk_of_jsons(self._decode), "JSON.SET": lambda r: r and nativestr(r) == "OK", - "JSON.NUMINCRBY": self._decode, - "JSON.NUMMULTBY": self._decode, + "JSON.DEBUG": self._decode, "JSON.TOGGLE": self._decode, - "JSON.STRAPPEND": self._decode, - "JSON.STRLEN": self._decode, + "JSON.RESP": self._decode, + } + + RESP2_MODULE_CALLBACKS = { + "JSON.ARRTRIM": self._decode, + "JSON.OBJLEN": self._decode, "JSON.ARRAPPEND": self._decode, "JSON.ARRINDEX": self._decode, "JSON.ARRINSERT": self._decode, + "JSON.TOGGLE": self._decode, + "JSON.STRAPPEND": self._decode, + "JSON.STRLEN": self._decode, "JSON.ARRLEN": self._decode, - "JSON.ARRPOP": self._decode, - "JSON.ARRTRIM": self._decode, - "JSON.OBJLEN": self._decode, + "JSON.CLEAR": int, + "JSON.DEL": int, + "JSON.FORGET": int, + 
"JSON.NUMINCRBY": self._decode, + "JSON.NUMMULTBY": self._decode, "JSON.OBJKEYS": self._decode, - "JSON.RESP": self._decode, - "JSON.DEBUG": self._decode, + "JSON.GET": self._decode, + } + + RESP3_MODULE_CALLBACKS = { + "JSON.GET": lambda response: [ + [self._decode(r) for r in res] for res in response + ] + if response + else response } self.client = client self.execute_command = client.execute_command self.MODULE_VERSION = version + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + self.MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + for key, value in self.MODULE_CALLBACKS.items(): self.client.set_response_callback(key, value) diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 70e9c279e5..228b742035 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -1,7 +1,18 @@ import redis from ...asyncio.client import Pipeline as AsyncioPipeline -from .commands import AsyncSearchCommands, SearchCommands +from .commands import ( + INFO_CMD, + SEARCH_CMD, + AGGREGATE_CMD, + PROFILE_CMD, + SPELLCHECK_CMD, + CONFIG_CMD, + SUGGET_COMMAND, + SYNDUMP_CMD, + AsyncSearchCommands, + SearchCommands, +) class Search(SearchCommands): @@ -90,6 +101,16 @@ def __init__(self, client, index_name="idx"): self.index_name = index_name self.execute_command = client.execute_command self._pipeline = client.pipeline + self.RESP2_MODULE_CALLBACKS = { + INFO_CMD: self._parse_info, + SEARCH_CMD: self._parse_search, + AGGREGATE_CMD: self._parse_aggregate, + PROFILE_CMD: self._parse_profile, + SPELLCHECK_CMD: self._parse_spellcheck, + CONFIG_CMD: self._parse_config_get, + SUGGET_COMMAND: self._parse_sugget, + SYNDUMP_CMD: self._parse_syndump, + } def pipeline(self, transaction=True, shard_hint=None): """Creates a pipeline for the SEARCH module, that can be used for executing diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 3bd7d47aa8..f448d1d84a 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -63,6 +63,94 @@ class SearchCommands: """Search commands.""" + def _parse_results(self, cmd, res, **kwargs): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + return res + else: + return self.RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) + + def _parse_info(self, res, **kwargs): + it = map(to_string, res) + return dict(zip(it, it)) + + def _parse_search(self, res, **kwargs): + return Result( + res, + not kwargs["query"]._no_content, + duration=kwargs["duration"], + has_payload=kwargs["query"]._with_payloads, + with_scores=kwargs["query"]._with_scores, + ) + + def _parse_aggregate(self, res, **kwargs): + return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) + + def _parse_profile(self, res, **kwargs): + query = kwargs["query"] + if isinstance(query, AggregateRequest): + result = self._get_aggregate_result(res[0], query, query._cursor) + else: + result = Result( + res[0], + not query._no_content, + duration=kwargs["duration"], + has_payload=query._with_payloads, + with_scores=query._with_scores, + ) + + return result, parse_to_dict(res[1]) + + def _parse_spellcheck(self, res, **kwargs): + corrections = {} + if res == 0: + return corrections + + for _correction in res: + if isinstance(_correction, int) and _correction == 0: + continue + + if len(_correction) != 3: + continue + if not _correction[2]: + continue + if not _correction[2][0]: + 
continue + + # For spellcheck output + # 1) 1) "TERM" + # 2) "{term1}" + # 3) 1) 1) "{score1}" + # 2) "{suggestion1}" + # 2) 1) "{score2}" + # 2) "{suggestion2}" + # + # Following dictionary will be made + # corrections = { + # '{term1}': [ + # {'score': '{score1}', 'suggestion': '{suggestion1}'}, + # {'score': '{score2}', 'suggestion': '{suggestion2}'} + # ] + # } + corrections[_correction[1]] = [ + {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] + ] + + return corrections + + def _parse_config_get(self, res, **kwargs): + return {kvs[0]: kvs[1] for kvs in res} if res else {} + + def _parse_sugget(self, res, **kwargs): + results = [] + if not res: + return results + + parser = SuggestionParser(kwargs["with_scores"], kwargs["with_payloads"], res) + return [s for s in parser] + + def _parse_syndump(self, res, **kwargs): + return {res[i]: res[i + 1] for i in range(0, len(res), 2)} + def batch_indexer(self, chunk_size=100): """ Create a new batch indexer from the client with a given chunk size @@ -368,8 +456,7 @@ def info(self): """ res = self.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) + return self._parse_results(INFO_CMD, res) def get_params_args( self, query_params: Union[Dict[str, Union[str, int, float]], None] @@ -422,13 +509,7 @@ def search( if isinstance(res, Pipeline): return res - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, - ) + return self._parse_results(SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0) def explain( self, @@ -473,7 +554,7 @@ def aggregate( cmd += self.get_params_args(query_params) raw = self.execute_command(*cmd) - return self._get_aggregate_result(raw, query, has_cursor) + return self._parse_results(AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor) def _get_aggregate_result(self, raw, query, has_cursor): if has_cursor: @@ -531,18 +612,7 @@ def profile( res = self.execute_command(*cmd) - if isinstance(query, AggregateRequest): - result = self._get_aggregate_result(res[0], query, query._cursor) - else: - result = Result( - res[0], - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, - ) - - return result, parse_to_dict(res[1]) + return self._parse_results(PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0) def spellcheck(self, query, distance=None, include=None, exclude=None): """ @@ -568,43 +638,9 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): if exclude: cmd.extend(["TERMS", "EXCLUDE", exclude]) - raw = self.execute_command(*cmd) - - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - # For spellcheck output - # 1) 1) "TERM" - # 2) "{term1}" - # 3) 1) 1) "{score1}" - # 2) "{suggestion1}" - # 2) 1) "{score2}" - # 2) "{suggestion2}" - # - # Following dictionary will be made - # corrections = { - # '{term1}': [ - # {'score': '{score1}', 'suggestion': '{suggestion1}'}, - # {'score': '{score2}', 'suggestion': '{suggestion2}'} - # ] - # } - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] + res = self.execute_command(*cmd) - return corrections + return 
self._parse_results(SPELLCHECK_CMD, res) def dict_add(self, name, *terms): """Adds terms to a dictionary. @@ -670,12 +706,8 @@ def config_get(self, option): For more information see `FT.CONFIG GET `_. """ # noqa cmd = [CONFIG_CMD, "GET", option] - res = {} - raw = self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res + res = self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) def tagvals(self, tagfield): """ @@ -810,13 +842,8 @@ def sugget( if with_payloads: args.append(WITHPAYLOADS) - ret = self.execute_command(*args) - results = [] - if not ret: - return results - - parser = SuggestionParser(with_scores, with_payloads, ret) - return [s for s in parser] + res = self.execute_command(*args) + return self._parse_results(SUGGET_COMMAND, res, with_scores=with_scores, with_payloads=with_payloads) def synupdate(self, groupid, skipinitial=False, *terms): """ @@ -851,8 +878,8 @@ def syndump(self): For more information see `FT.SYNDUMP `_. """ # noqa - raw = self.execute_command(SYNDUMP_CMD, self.index_name) - return {raw[i]: raw[i + 1] for i in range(0, len(raw), 2)} + res = self.execute_command(SYNDUMP_CMD, self.index_name) + return self._parse_results(SYNDUMP_CMD, res) class AsyncSearchCommands(SearchCommands): diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 4a6886f237..5b8a02466d 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -1,4 +1,5 @@ import redis +from redis.client import bool_ok from ..helpers import parse_to_list from .commands import ( @@ -33,26 +34,36 @@ def __init__(self, client=None, **kwargs): """Create a new RedisTimeSeries client.""" # Set the module commands' callbacks self.MODULE_CALLBACKS = { - CREATE_CMD: redis.client.bool_ok, - ALTER_CMD: redis.client.bool_ok, - CREATERULE_CMD: redis.client.bool_ok, + CREATE_CMD: bool_ok, + ALTER_CMD: bool_ok, + CREATERULE_CMD: bool_ok, + DELETERULE_CMD: bool_ok, + } + + RESP2_MODULE_CALLBACKS = { DEL_CMD: int, - DELETERULE_CMD: redis.client.bool_ok, + GET_CMD: parse_get, + QUERYINDEX_CMD: parse_to_list, RANGE_CMD: parse_range, REVRANGE_CMD: parse_range, + MGET_CMD: parse_m_get, MRANGE_CMD: parse_m_range, MREVRANGE_CMD: parse_m_range, - GET_CMD: parse_get, - MGET_CMD: parse_m_get, INFO_CMD: TSInfo, - QUERYINDEX_CMD: parse_to_list, + } + RESP3_MODULE_CALLBACKS = {} self.client = client self.execute_command = client.execute_command - for key, value in self.MODULE_CALLBACKS.items(): - self.client.set_response_callback(key, value) + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + else: + self.MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + + for k, v in self.MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) def pipeline(self, transaction=True, shard_hint=None): """Creates a pipeline for the TimeSeries module, that can be used diff --git a/redis/commands/timeseries/info.py b/redis/commands/timeseries/info.py index 65f3baacd0..3a384dc049 100644 --- a/redis/commands/timeseries/info.py +++ b/redis/commands/timeseries/info.py @@ -80,3 +80,12 @@ def __init__(self, args): self.duplicate_policy = response["duplicatePolicy"] if type(self.duplicate_policy) == bytes: self.duplicate_policy = self.duplicate_policy.decode() + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git 
a/tests/test_bloom.py b/tests/test_bloom.py index 30d3219404..4ee8ba29d2 100644 --- a/tests/test_bloom.py +++ b/tests/test_bloom.py @@ -6,7 +6,7 @@ from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt def intlist(obj): @@ -61,7 +61,6 @@ def test_tdigest_create(client): assert client.tdigest().create("tDigest", 100) -# region Test Bloom Filter @pytest.mark.redismod def test_bf_add(client): assert client.bf().create("bloom", 0.01, 1000) @@ -86,9 +85,24 @@ def test_bf_insert(client): assert 0 == client.bf().exists("bloom", "noexist") assert [1, 0] == intlist(client.bf().mexists("bloom", "foo", "noexist")) info = client.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum + assert_resp_response( + client, + 2, + info.get("insertedNum"), + info.get("Number of items inserted"), + ) + assert_resp_response( + client, + 1000, + info.get("capacity"), + info.get("Capacity"), + ) + assert_resp_response( + client, + 1, + info.get("filterNum"), + info.get("Number of filters"), + ) @pytest.mark.redismod @@ -149,11 +163,21 @@ def test_bf_info(client): # Store a filter client.bf().create("nonscaling", "0.0001", "1000", noScale=True) info = client.bf().info("nonscaling") - assert info.expansionRate is None + assert_resp_response( + client, + None, + info.get("expansionRate"), + info.get("Expansion rate"), + ) client.bf().create("expanding", "0.0001", "1000", expansion=expansion) info = client.bf().info("expanding") - assert info.expansionRate == 4 + assert_resp_response( + client, + 4, + info.get("expansionRate"), + info.get("Expansion rate"), + ) try: # noScale mean no expansion @@ -180,7 +204,6 @@ def test_bf_card(client): client.bf().card("setKey") -# region Test Cuckoo Filter @pytest.mark.redismod def test_cf_add_and_insert(client): assert client.cf().create("cuckoo", 1000) @@ -196,9 +219,15 @@ def test_cf_add_and_insert(client): assert [1] == client.cf().insert("empty1", ["foo"], capacity=1000) assert [1] == client.cf().insertnx("empty2", ["bar"], capacity=1000) info = client.cf().info("captest") - assert 5 == info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum + assert_resp_response( + client, 5, info.get("insertedNum"), info.get("Number of items inserted") + ) + assert_resp_response( + client, 0, info.get("deletedNum"), info.get("Number of items deleted") + ) + assert_resp_response( + client, 1, info.get("filterNum"), info.get("Number of filters") + ) @pytest.mark.redismod @@ -214,7 +243,6 @@ def test_cf_exists_and_del(client): assert 0 == client.cf().count("cuckoo", "filter") -# region Test Count-Min Sketch @pytest.mark.redismod def test_cms(client): assert client.cms().initbydim("dim", 1000, 5) @@ -225,9 +253,10 @@ def test_cms(client): assert [10, 15] == client.cms().incrby("dim", ["foo", "bar"], [5, 15]) assert [10, 15] == client.cms().query("dim", "foo", "bar") info = client.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count + assert info["width"] + assert 1000 == info["width"] + assert 5 == info["depth"] + assert 25 == info["count"] @pytest.mark.redismod @@ -248,10 +277,6 @@ def test_cms_merge(client): assert [16, 15, 21] == client.cms().query("C", "foo", "bar", "baz") -# endregion - - -# region Test Top-K @pytest.mark.redismod def test_topk(client): # test list with empty buckets @@ -326,10 +351,10 
@@ def test_topk(client): assert ["A", "B", "E"] == client.topk().list("topklist") assert ["A", 4, "B", 3, "E", 3] == client.topk().list("topklist", withcount=True) info = client.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) + assert 3 == info["k"] + assert 50 == info["width"] + assert 3 == info["depth"] + assert 0.9 == round(float(info["decay"]), 1) @pytest.mark.redismod @@ -346,7 +371,6 @@ def test_topk_incrby(client): ) -# region Test T-Digest @pytest.mark.redismod @pytest.mark.experimental def test_tdigest_reset(client): @@ -357,8 +381,11 @@ def test_tdigest_reset(client): assert client.tdigest().add("tDigest", list(range(10))) assert client.tdigest().reset("tDigest") - # assert we have 0 unmerged nodes - assert 0 == client.tdigest().info("tDigest").unmerged_nodes + # assert we have 0 unmerged + info = client.tdigest().info("tDigest") + assert_resp_response( + client, 0, info.get("unmerged_nodes"), info.get("Unmerged nodes") + ) @pytest.mark.redismod @@ -373,8 +400,10 @@ def test_tdigest_merge(client): assert client.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram info = client.tdigest().info("to-tDigest") - total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) - assert 20 == total_weight_to + if is_resp2_connection(client): + assert 20 == float(info["merged_weight"]) + float(info["unmerged_weight"]) + else: + assert 20 == float(info["Merged weight"]) + float(info["Unmerged weight"]) # test override assert client.tdigest().create("from-override", 10) assert client.tdigest().create("from-override-2", 10) diff --git a/tests/test_json.py b/tests/test_json.py index 8e8da05609..84232b20d1 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -5,7 +5,7 @@ from redis.commands.json.decoders import decode_list, unstring from redis.commands.json.path import Path -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, skip_ifmodversion_lt @pytest.fixture @@ -25,7 +25,7 @@ def test_json_setbinarykey(client): @pytest.mark.redismod def test_json_setgetdeleteforget(client): assert client.json().set("foo", Path.root_path(), "bar") - assert client.json().get("foo") == "bar" + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) assert client.json().get("baz") is None assert client.json().delete("foo") == 1 assert client.json().forget("foo") == 0 # second delete @@ -35,13 +35,13 @@ def test_json_setgetdeleteforget(client): @pytest.mark.redismod def test_jsonget(client): client.json().set("foo", Path.root_path(), "bar") - assert client.json().get("foo") == "bar" + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) @pytest.mark.redismod def test_json_get_jset(client): assert client.json().set("foo", Path.root_path(), "bar") - assert "bar" == client.json().get("foo") + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) assert client.json().get("baz") is None assert 1 == client.json().delete("foo") assert client.exists("foo") == 0 @@ -50,7 +50,10 @@ def test_json_get_jset(client): @pytest.mark.redismod def test_nonascii_setgetdelete(client): assert client.json().set("notascii", Path.root_path(), "hyvää-élève") - assert "hyvää-élève" == client.json().get("notascii", no_escape=True) + res = "hyvää-élève" + assert_resp_response( + client, client.json().get("notascii", no_escape=True), res, [[res]] + ) assert 1 == client.json().delete("notascii") 
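# assert_resp_response (imported from tests/conftest.py at the top of this
# file's diff) lets one test body carry both expected shapes. A plausible
# sketch of the helper, assuming it can read the protocol from the client's
# connection kwargs; the real conftest implementation may differ:
#
#     def assert_resp_response(client, response, resp2_expected, resp3_expected):
#         protocol = client.connection_pool.connection_kwargs.get("protocol")
#         if protocol in ("3", 3):
#             assert response == resp3_expected
#         else:
#             assert response == resp2_expected
#
# Hence the paired expectations in this hunk: RESP3 JSON.GET replies come
# back as a nested list per queried path ("bar" under RESP2, [["bar"]]
# under RESP3, matching the RESP3 "JSON.GET" callback added in PATCH 04).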
     assert client.exists("notascii") == 0
@@ -87,22 +90,30 @@ def test_mgetshouldsucceed(client):
 def test_clear(client):
     client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
     assert 1 == client.json().clear("arr", Path.root_path())
-    assert [] == client.json().get("arr")
+    assert_resp_response(client, client.json().get("arr"), [], [[[]]])


 @pytest.mark.redismod
 def test_type(client):
     client.json().set("1", Path.root_path(), 1)
-    assert "integer" == client.json().type("1", Path.root_path())
-    assert "integer" == client.json().type("1")
+    assert_resp_response(
+        client, client.json().type("1", Path.root_path()), "integer", ["integer"]
+    )
+    assert_resp_response(client, client.json().type("1"), "integer", ["integer"])


 @pytest.mark.redismod
 def test_numincrby(client):
     client.json().set("num", Path.root_path(), 1)
-    assert 2 == client.json().numincrby("num", Path.root_path(), 1)
-    assert 2.5 == client.json().numincrby("num", Path.root_path(), 0.5)
-    assert 1.25 == client.json().numincrby("num", Path.root_path(), -1.25)
+    assert_resp_response(
+        client, client.json().numincrby("num", Path.root_path(), 1), 2, [2]
+    )
+    assert_resp_response(
+        client, client.json().numincrby("num", Path.root_path(), 0.5), 2.5, [2.5]
+    )
+    assert_resp_response(
+        client, client.json().numincrby("num", Path.root_path(), -1.25), 1.25, [1.25]
+    )


 @pytest.mark.redismod
@@ -110,9 +121,15 @@ def test_nummultby(client):
     client.json().set("num", Path.root_path(), 1)

     with pytest.deprecated_call():
-        assert 2 == client.json().nummultby("num", Path.root_path(), 2)
-        assert 5 == client.json().nummultby("num", Path.root_path(), 2.5)
-        assert 2.5 == client.json().nummultby("num", Path.root_path(), 0.5)
+        assert_resp_response(
+            client, client.json().nummultby("num", Path.root_path(), 2), 2, [2]
+        )
+        assert_resp_response(
+            client, client.json().nummultby("num", Path.root_path(), 2.5), 5, [5]
+        )
+        assert_resp_response(
+            client, client.json().nummultby("num", Path.root_path(), 0.5), 2.5, [2.5]
+        )


 @pytest.mark.redismod
@@ -131,7 +148,9 @@ def test_toggle(client):
 def test_strappend(client):
     client.json().set("jsonkey", Path.root_path(), "foo")
     assert 6 == client.json().strappend("jsonkey", "bar")
-    assert "foobar" == client.json().get("jsonkey", Path.root_path())
+    assert_resp_response(
+        client, client.json().get("jsonkey", Path.root_path()), "foobar", [["foobar"]]
+    )


 # @pytest.mark.redismod
@@ -177,12 +196,14 @@ def test_arrindex(client):
 def test_arrinsert(client):
     client.json().set("arr", Path.root_path(), [0, 4])
     assert 5 - -client.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3])
-    assert [0, 1, 2, 3, 4] == client.json().get("arr")
+    res = [0, 1, 2, 3, 4]
+    assert_resp_response(client, client.json().get("arr"), res, [[res]])

     # test prepends
     client.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9])
     client.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"])
-    assert client.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9]
+    res = [["some", "thing"], 5, 6, 7, 8, 9]
+    assert_resp_response(client, client.json().get("val2"), res, [[res]])


 @pytest.mark.redismod
@@ -200,7 +221,7 @@ def test_arrpop(client):
     assert 3 == client.json().arrpop("arr", Path.root_path(), -1)
     assert 2 == client.json().arrpop("arr", Path.root_path())
     assert 0 == client.json().arrpop("arr", Path.root_path(), 0)
-    assert [1] == client.json().get("arr")
+    assert_resp_response(client, client.json().get("arr"), [1], [[[1]]])

     # test out of bounds
     client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
@@ -215,7 +236,7 @@ def test_arrpop(client):
 def test_arrtrim(client):
     client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
     assert 3 == client.json().arrtrim("arr", Path.root_path(), 1, 3)
-    assert [1, 2, 3] == client.json().get("arr")
+    assert_resp_response(client, client.json().get("arr"), [1, 2, 3], [[[1, 2, 3]]])

     # <0 test, should be 0 equivalent
     client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
@@ -277,7 +298,7 @@ def test_json_commands_in_pipeline(client):
         p.set("foo", Path.root_path(), "bar")
         p.get("foo")
         p.delete("foo")
-        assert [True, "bar", 1] == p.execute()
+        assert_resp_response(client, p.execute(), [True, "bar", 1], [True, [["bar"]], 1])
     assert client.keys() == []
     assert client.get("foo") is None

@@ -290,7 +311,7 @@ def test_json_commands_in_pipeline(client):
     p.jsonget("foo")
     p.exists("notarealkey")
     p.delete("foo")
-    assert [True, d, 0, 1] == p.execute()
+    assert_resp_response(client, p.execute(), [True, d, 0, 1], [True, [[d]], 0, 1])
     assert client.keys() == []
     assert client.get("foo") is None

@@ -300,14 +321,14 @@ def test_json_delete_with_dollar(client):
     doc1 = {"a": 1, "nested": {"a": 2, "b": 3}}
     assert client.json().set("doc1", "$", doc1)
     assert client.json().delete("doc1", "$..a") == 2
-    r = client.json().get("doc1", "$")
-    assert r == [{"nested": {"b": 3}}]
+    res = [{"nested": {"b": 3}}]
+    assert_resp_response(client, client.json().get("doc1", "$"), res, [res])

     doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}}
     assert client.json().set("doc2", "$", doc2)
     assert client.json().delete("doc2", "$..a") == 1
-    res = client.json().get("doc2", "$")
-    assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}]
+    res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}]
+    assert_resp_response(client, client.json().get("doc2", "$"), res, [res])

     doc3 = [
         {
@@ -338,8 +359,7 @@ def test_json_delete_with_dollar(client):
             }
         ]
     ]
-    res = client.json().get("doc3", "$")
-    assert res == doc3val
+    assert_resp_response(client, client.json().get("doc3", "$"), doc3val, [doc3val])

     # Test default path
     assert client.json().delete("doc3") == 1
@@ -353,14 +373,14 @@ def test_json_forget_with_dollar(client):
     doc1 = {"a": 1, "nested": {"a": 2, "b": 3}}
     assert client.json().set("doc1", "$", doc1)
     assert client.json().forget("doc1", "$..a") == 2
-    r = client.json().get("doc1", "$")
-    assert r == [{"nested": {"b": 3}}]
+    res = [{"nested": {"b": 3}}]
+    assert_resp_response(client, client.json().get("doc1", "$"), res, [res])

     doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}}
     assert client.json().set("doc2", "$", doc2)
     assert client.json().forget("doc2", "$..a") == 1
-    res = client.json().get("doc2", "$")
-    assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}]
+    res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}]
+    assert_resp_response(client, client.json().get("doc2", "$"), res, [res])

     doc3 = [
         {
@@ -391,8 +411,7 @@ def test_json_forget_with_dollar(client):
             }
         ]
     ]
-    res = client.json().get("doc3", "$")
-    assert res == doc3val
+    assert_resp_response(client, client.json().get("doc3", "$"), doc3val, [doc3val])

     # Test default path
     assert client.json().forget("doc3") == 1
@@ -415,8 +434,10 @@ def test_json_mget_dollar(client):
         {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}},
     )
     # Compare also to single JSON.GET
-    assert client.json().get("doc1", "$..a") == [1, 3, None]
-    assert client.json().get("doc2", "$..a") == [4, 6, [None]]
+    res = [1, 3, None]
+    assert_resp_response(client, client.json().get("doc1", "$..a"), res, [res])
client.json().get("doc1", "$..a"), res, [res]) + res = [4, 6, [None]] + assert_resp_response(client, client.json().get("doc2", "$..a"), res, [res]) # Test mget with single path client.json().mget("doc1", "$..a") == [1, 3, None] @@ -483,15 +504,14 @@ def test_strappend_dollar(client): # Test multi client.json().strappend("doc1", "bar", "$..a") == [6, 8, None] - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single client.json().strappend("doc1", "baz", "$.nested1.a") == [11] - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -499,9 +519,8 @@ def test_strappend_dollar(client): # Test multi client.json().strappend("doc1", "bar", ".*.a") == 8 - client.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing path with pytest.raises(exceptions.ResponseError): @@ -543,23 +562,25 @@ def test_arrappend_dollar(client): ) # Test multi client.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test single assert client.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -578,22 +599,25 @@ def test_arrappend_dollar(client): # Test multi (all paths are updated, but return result of last path) assert client.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -614,22 +638,25 @@ def test_arrinsert_dollar(client): # Test multi assert client.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # 
Test single assert client.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -701,9 +728,8 @@ def test_arrpop_dollar(client): # # # Test multi assert client.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -721,9 +747,8 @@ def test_arrpop_dollar(client): ) # Test multi (all paths are updated, but return result of last path) client.json().arrpop("doc1", "..a", "1") is None - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # # Test missing key with pytest.raises(exceptions.ResponseError): @@ -744,19 +769,17 @@ def test_arrtrim_dollar(client): ) # Test multi assert client.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) assert client.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": []}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -778,9 +801,8 @@ def test_arrtrim_dollar(client): # Test single assert client.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -878,13 +900,17 @@ def test_type_dollar(client): jdata, jtypes = load_types_data("a") client.json().set("doc1", "$", jdata) # Test multi - assert client.json().type("doc1", "$..a") == jtypes + assert_resp_response(client, client.json().type("doc1", "$..a"), jtypes, [jtypes]) # Test single - assert client.json().type("doc1", "$.nested2.a") == [jtypes[1]] + assert_resp_response( + client, client.json().type("doc1", "$.nested2.a"), [jtypes[1]], [[jtypes[1]]] + ) # Test missing key - assert client.json().type("non_existing_doc", "..a") is None + 
+    assert_resp_response(
+        client, client.json().type("non_existing_doc", "..a"), None, [None]
+    )


 @pytest.mark.redismod
@@ -902,9 +928,10 @@ def test_clear_dollar(client):

     # Test multi
     assert client.json().clear("doc1", "$..a") == 3
-    assert client.json().get("doc1", "$") == [
+    res = [
         {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}}
     ]
+    assert_resp_response(client, client.json().get("doc1", "$"), res, [res])

     # Test single
     client.json().set(
@@ -918,7 +945,7 @@ def test_clear_dollar(client):
         },
     )
     assert client.json().clear("doc1", "$.nested1.a") == 1
-    assert client.json().get("doc1", "$") == [
+    res = [
         {
             "nested1": {"a": {}},
             "a": ["foo"],
@@ -926,10 +953,11 @@ def test_clear_dollar(client):
             "nested3": {"a": {"baz": 50}},
         }
     ]
+    assert_resp_response(client, client.json().get("doc1", "$"), res, [res])

     # Test missing path (defaults to root)
     assert client.json().clear("doc1") == 1
-    assert client.json().get("doc1", "$") == [{}]
+    assert_resp_response(client, client.json().get("doc1", "$"), [{}], [[{}]])

     # Test missing key
     with pytest.raises(exceptions.ResponseError):
@@ -950,7 +978,7 @@ def test_toggle_dollar(client):
     )
     # Test multi
     assert client.json().toggle("doc1", "$..a") == [None, 1, None, 0]
-    assert client.json().get("doc1", "$") == [
+    res = [
         {
             "a": ["foo"],
             "nested1": {"a": True},
@@ -958,6 +986,7 @@ def test_toggle_dollar(client):
             "nested3": {"a": False},
         }
     ]
+    assert_resp_response(client, client.json().get("doc1", "$"), res, [res])

     # Test missing key
     with pytest.raises(exceptions.ResponseError):
@@ -1033,7 +1062,7 @@ def test_resp_dollar(client):
     client.json().set("doc1", "$", data)
     # Test multi
     res = client.json().resp("doc1", "$..a")
-    assert res == [
+    resp2 = [
         [
             "{",
             "A1_B1",
             10,
             "A1_B2",
             "false",
             "A1_B3",
             [
                 "{",
                 "A1_B3_C1",
                 None,
                 "A1_B3_C2",
                 [
                     "[",
                     "A1_B3_C2_D1_1",
                     "A1_B3_C2_D1_2",
                     -19.5,
                     "A1_B3_C2_D1_4",
                     "A1_B3_C2_D1_5",
                     ["{", "A1_B3_C2_D1_6_E1", "true"],
                 ],
                 "A1_B3_C3",
                 ["[", 1],
             ],
             "A1_B4",
             ["{", "A1_B4_C1", "foo"],
         ],
         [
             "{",
             "A2_B1",
             20,
             "A2_B2",
             "false",
             "A2_B3",
             [
                 "{",
                 "A2_B3_C1",
                 None,
                 "A2_B3_C2",
                 [
                     "[",
                     "A2_B3_C2_D1_1",
                     "A2_B3_C2_D1_2",
                     -37.5,
                     "A2_B3_C2_D1_4",
                     "A2_B3_C2_D1_5",
                     ["{", "A2_B3_C2_D1_6_E1", "false"],
                 ],
                 "A2_B3_C3",
                 ["[", 2],
             ],
             "A2_B4",
             ["{", "A2_B4_C1", "bar"],
         ],
     ]
@@ -1089,10 +1118,67 @@ def test_resp_dollar(client):
             ["{", "A2_B4_C1", "bar"],
         ],
     ]
+    resp3 = [
+        [
+            "{",
+            "A1_B1",
+            10,
+            "A1_B2",
+            "false",
+            "A1_B3",
+            [
+                "{",
+                "A1_B3_C1",
+                None,
+                "A1_B3_C2",
+                [
+                    "[",
+                    "A1_B3_C2_D1_1",
+                    "A1_B3_C2_D1_2",
+                    -19.5,
+                    "A1_B3_C2_D1_4",
+                    "A1_B3_C2_D1_5",
+                    ["{", "A1_B3_C2_D1_6_E1", "true"],
+                ],
+                "A1_B3_C3",
+                ["[", 1],
+            ],
+            "A1_B4",
+            ["{", "A1_B4_C1", "foo"],
+        ],
+        [
+            "{",
+            "A2_B1",
+            20,
+            "A2_B2",
+            "false",
+            "A2_B3",
+            [
+                "{",
+                "A2_B3_C1",
+                None,
+                "A2_B3_C2",
+                [
+                    "[",
+                    "A2_B3_C2_D1_1",
+                    "A2_B3_C2_D1_2",
+                    -37.5,
+                    "A2_B3_C2_D1_4",
+                    "A2_B3_C2_D1_5",
+                    ["{", "A2_B3_C2_D1_6_E1", "false"],
+                ],
+                "A2_B3_C3",
+                ["[", 2],
+            ],
+            "A2_B4",
+            ["{", "A2_B4_C1", "bar"],
+        ],
+    ]
+    assert_resp_response(client, res, resp2, resp3)

     # Test single
-    resSingle = client.json().resp("doc1", "$.L1.a")
-    assert resSingle == [
+    res = client.json().resp("doc1", "$.L1.a")
+    resp2 = [
         [
             "{",
             "A1_B1",
             10,
             "A1_B2",
             "false",
             "A1_B3",
             [
                 "{",
                 "A1_B3_C1",
                 None,
                 "A1_B3_C2",
                 [
                     "[",
                     "A1_B3_C2_D1_1",
                     "A1_B3_C2_D1_2",
                     -19.5,
                     "A1_B3_C2_D1_4",
                     "A1_B3_C2_D1_5",
                     ["{", "A1_B3_C2_D1_6_E1", "true"],
                 ],
                 "A1_B3_C3",
                 ["[", 1],
             ],
             "A1_B4",
             ["{", "A1_B4_C1", "foo"],
         ]
     ]
+    resp3 = [
+        [
+            "{",
+            "A1_B1",
+            10,
+            "A1_B2",
+            "false",
+            "A1_B3",
+            [
+                "{",
+                "A1_B3_C1",
+                None,
+                "A1_B3_C2",
+                [
+                    "[",
+                    "A1_B3_C2_D1_1",
+                    "A1_B3_C2_D1_2",
+                    -19.5,
+                    "A1_B3_C2_D1_4",
+                    "A1_B3_C2_D1_5",
+                    ["{", "A1_B3_C2_D1_6_E1", "true"],
+                ],
+                "A1_B3_C3",
+                ["[", 1],
+            ],
+            "A1_B4",
+            ["{", "A1_B4_C1", "foo"],
+        ]
+    ]
+    assert_resp_response(client, res, resp2, resp3)

     # Test missing path
     client.json().resp("doc1", "$.nowhere")
@@ -1175,10 +1291,13 @@ def test_arrindex_dollar(client):
         },
     )

-    assert client.json().get("store", "$.store.book[?(@.price<10)].size") == [
-        [10, 20, 30, 40],
-        [5, 10, 20, 30],
-    ]
+    assert_resp_response(
+        client,
+        client.json().get("store", "$.store.book[?(@.price<10)].size"),
+        [[10, 20, 30, 40], [5, 10, 20, 30]],
+        [[[10, 20, 30, 40], [5, 10, 20, 30]]],
+    )
+
     assert client.json().arrindex(
         "store", "$.store.book[?(@.price<10)].size", "20"
     ) == [-1, -1]
@@ -1199,13 +1318,14 @@ def test_arrindex_dollar(client):
         ],
     )

-    assert client.json().get("test_num", "$..arr") == [
+    res = [
         [0, 1, 3.0, 3, 2, 1, 0, 3],
         [5, 4, 3, 2, 1, 0, 1, 2, 3.0, 2, 4, 5],
         [2, 4, 6],
         "3",
         [],
     ]
+    assert_resp_response(client, client.json().get("test_num", "$..arr"), res, [res])

     assert client.json().arrindex("test_num", "$..arr", 3) == [3, 2, -1, None, -1]
@@ -1231,13 +1351,14 @@ def test_arrindex_dollar(client):
         ],
     )
-    assert client.json().get("test_string", "$..arr") == [
+    res = [
         ["bazzz", "bar", 2, "baz", 2, "ba", "baz", 3],
         [None, "baz2", "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5],
         ["baz2", 4, 6],
         "3",
         [],
     ]
+    assert_resp_response(client, client.json().get("test_string", "$..arr"), res, [res])

     assert client.json().arrindex("test_string", "$..arr", "baz") == [
         3,
@@ -1323,13 +1444,14 @@ def test_arrindex_dollar(client):
         ],
     )
-    assert client.json().get("test_None", "$..arr") == [
+    res = [
         ["bazzz", "None", 2, None, 2, "ba", "baz", 3],
         ["zaz", "baz2", "buzz", 2, 1, 0, 1, "2", None, 2, 4, 5],
         ["None", 4, 6],
         None,
         [],
     ]
+    assert_resp_response(client, client.json().get("test_None", "$..arr"), res, [res])

     # Test with none-scalar value
     assert client.json().arrindex(
@@ -1370,7 +1492,7 @@ def test_custom_decoder(client):
     cj = client.json(encoder=ujson, decoder=ujson)

     assert cj.set("foo", Path.root_path(), "bar")
-    assert "bar" == cj.get("foo")
+    assert_resp_response(client, cj.get("foo"), "bar", [["bar"]])
     assert cj.get("baz") is None
     assert 1 == cj.delete("foo")
     assert client.exists("foo") == 0
@@ -1392,7 +1514,7 @@ def test_set_file(client):
     nojsonfile.write(b"Hello World")

     assert client.json().set_file("test", Path.root_path(), jsonfile.name)
-    assert client.json().get("test") == obj
+    assert_resp_response(client, client.json().get("test"), obj, [[obj]])
     with pytest.raises(json.JSONDecodeError):
         client.json().set_file("test2", Path.root_path(), nojsonfile.name)
@@ -1414,4 +1536,7 @@ def test_set_path(client):

     result = {jsonfile: True, nojsonfile: False}
     assert client.json().set_path(Path.root_path(), root) == result
-    assert client.json().get(jsonfile.rsplit(".")[0]) == {"hello": "world"}
+    res = {"hello": "world"}
+    assert_resp_response(
+        client, client.json().get(jsonfile.rsplit(".")[0]), res, [[res]]
+    )
diff --git a/tests/test_search.py b/tests/test_search.py
index 7a2428151e..99bb327d23 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -24,7 +24,7 @@
 from redis.commands.search.result import Result
 from redis.commands.search.suggestion import Suggestion

-from .conftest import skip_if_redis_enterprise, skip_ifmodversion_lt
+from .conftest import assert_resp_response, skip_if_redis_enterprise, skip_ifmodversion_lt, is_resp2_connection

 WILL_PLAY_TEXT = os.path.abspath(
     os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2")
 )
@@ -40,12 +40,16 @@ def waitForIndex(env, idx, timeout=None):
     while True:
         res = env.execute_command("FT.INFO", idx)
         try:
-            res.index("indexing")
+            if int(res[res.index("indexing") + 1]) == 0:
+                break
         except ValueError:
             break
-
-        if int(res[res.index("indexing") + 1]) == 0:
-            break
+        except AttributeError:
+            try:
+                if int(res["indexing"]) == 0:
+                    break
+            except ValueError:
+                break

         time.sleep(delay)
         if timeout is not None:
@@ -223,12 +227,16 @@ def test_scores(client):

     q = Query("foo ~bar").with_scores()
     res = client.ft().search(q)
-    assert 2 == res.total
assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] @pytest.mark.redismod @@ -241,8 +249,12 @@ def test_stopwords(client): q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total + if is_resp2_connection(client): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] @pytest.mark.redismod @@ -262,25 +274,41 @@ def test_filters(client): .no_content() ) res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id + if is_resp2_connection(client): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] # Test geo filter q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id + if is_resp2_connection(client): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id + + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] + + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res @pytest.mark.redismod @@ -295,14 +323,24 @@ def test_sort_by(client): q2 = Query("foo").sort_by("num", asc=False).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id + if is_resp2_connection(client): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert 
"doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] @pytest.mark.redismod @@ -417,27 +455,50 @@ def test_no_index(client): ) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total + if is_resp2_connection(client): + res = client.ft().search(Query("@text:aa*")) + assert 0 == res.total - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = client.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = client.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] + + res = client.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] + + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] # Ensure exception is raised for non-indexable, non-sortable fields with pytest.raises(Exception): @@ -472,21 +533,38 @@ def test_summarize(client): q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + if is_resp2_connection(client): + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. 
+            == doc.txt
+        )

-    q = Query("king henry").paging(0, 1).summarize().highlight()
+        q = Query("king henry").paging(0, 1).summarize().highlight()

-    doc = sorted(client.ft().search(q).docs)[0]
-    assert "<b>Henry</b> ... " == doc.play
-    assert (
-        "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
-        == doc.txt
-    )
+        doc = sorted(client.ft().search(q).docs)[0]
+        assert "<b>Henry</b> ... " == doc.play
+        assert (
+            "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
+            == doc.txt
+        )
+    else:
+        doc = sorted(client.ft().search(q)["results"])[0]
+        assert "<b>Henry</b> IV" == doc["fields"]["play"]
+        assert (
+            "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
+            == doc["fields"]["txt"]
+        )
+
+        q = Query("king henry").paging(0, 1).summarize().highlight()
+
+        doc = sorted(client.ft().search(q)["results"])[0]
+        assert "<b>Henry</b> ... " == doc["fields"]["play"]
+        assert (
+            "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
+            == doc["fields"]["txt"]
+        )


 @pytest.mark.redismod
@@ -506,25 +584,46 @@ def test_alias(client):
     index1.hset("index1:lonestar", mapping={"name": "lonestar"})
     index2.hset("index2:yogurt", mapping={"name": "yogurt"})

-    res = ftindex1.search("*").docs[0]
-    assert "index1:lonestar" == res.id
+    if is_resp2_connection(client):
+        res = ftindex1.search("*").docs[0]
+        assert "index1:lonestar" == res.id

-    # create alias and check for results
-    ftindex1.aliasadd("spaceballs")
-    alias_client = getClient(client).ft("spaceballs")
-    res = alias_client.search("*").docs[0]
-    assert "index1:lonestar" == res.id
+        # create alias and check for results
+        ftindex1.aliasadd("spaceballs")
+        alias_client = getClient(client).ft("spaceballs")
+        res = alias_client.search("*").docs[0]
+        assert "index1:lonestar" == res.id

-    # Throw an exception when trying to add an alias that already exists
-    with pytest.raises(Exception):
-        ftindex2.aliasadd("spaceballs")
+        # Throw an exception when trying to add an alias that already exists
+        with pytest.raises(Exception):
+            ftindex2.aliasadd("spaceballs")
+
+        # update alias and ensure new results
+        ftindex2.aliasupdate("spaceballs")
+        alias_client2 = getClient(client).ft("spaceballs")
+
+        res = alias_client2.search("*").docs[0]
+        assert "index2:yogurt" == res.id
+    else:
+        res = ftindex1.search("*")["results"][0]
+        assert "index1:lonestar" == res["id"]
+
+        # create alias and check for results
+        ftindex1.aliasadd("spaceballs")
+        alias_client = getClient(client).ft("spaceballs")
+        res = alias_client.search("*")["results"][0]
+        assert "index1:lonestar" == res["id"]
+
+        # Throw an exception when trying to add an alias that already exists
+        with pytest.raises(Exception):
+            ftindex2.aliasadd("spaceballs")

-    # update alias and ensure new results
-    ftindex2.aliasupdate("spaceballs")
-    alias_client2 = getClient(client).ft("spaceballs")
+        # update alias and ensure new results
+        ftindex2.aliasupdate("spaceballs")
+        alias_client2 = getClient(client).ft("spaceballs")

-    res = alias_client2.search("*").docs[0]
-    assert "index2:yogurt" == res.id
+        res = alias_client2.search("*")["results"][0]
+        assert "index2:yogurt" == res["id"]

     ftindex2.aliasdel("spaceballs")
     with pytest.raises(Exception):
@@ -547,18 +646,32 @@ def test_alias_basic(client):
     # add the actual alias and check
     index1.aliasadd("myalias")
     alias_client = getClient(client).ft("myalias")
-    res = sorted(alias_client.search("*").docs, key=lambda x: x.id)
-    assert "doc1" == res[0].id
-
-    # Throw an exception when trying to add an alias that already exists
-    with pytest.raises(Exception):
-        index2.aliasadd("myalias")
-
-    # update the alias and ensure we get doc2
-    index2.aliasupdate("myalias")
-    alias_client2 = getClient(client).ft("myalias")
-    res = sorted(alias_client2.search("*").docs, key=lambda x: x.id)
-    assert "doc1" == res[0].id
+    if is_resp2_connection(client):
+        res = sorted(alias_client.search("*").docs, key=lambda x: x.id)
+        assert "doc1" == res[0].id
+
+        # Throw an exception when trying to add an alias that already exists
+        with pytest.raises(Exception):
+            index2.aliasadd("myalias")
+
+        # update the alias and ensure we get doc2
+        index2.aliasupdate("myalias")
+        alias_client2 = getClient(client).ft("myalias")
+        res = sorted(alias_client2.search("*").docs, key=lambda x: x.id)
+        assert "doc1" == res[0].id
+    else:
+        res = sorted(alias_client.search("*")["results"], key=lambda x: x["id"])
+        assert "doc1" == res[0]["id"]
+
+        # Throw an exception when trying to add an alias that already exists
+        with pytest.raises(Exception):
+            index2.aliasadd("myalias")
+
+        # update the alias and ensure we get doc2
+        index2.aliasupdate("myalias")
+        alias_client2 = getClient(client).ft("myalias")
+        res = sorted(alias_client2.search("*")["results"], key=lambda x: x["id"])
+        assert "doc1" == res[0]["id"]

     # delete the alias and expect an error if we try to query again
     index2.aliasdel("myalias")
@@ -573,8 +686,12 @@ def test_textfield_sortable_nostem(client):

     # Now get the index info to confirm its contents
     response = client.ft().info()
-    assert "SORTABLE" in response["attributes"][0]
-    assert "NOSTEM" in response["attributes"][0]
+    if is_resp2_connection(client):
+        assert "SORTABLE" in response["attributes"][0]
+        assert "NOSTEM" in response["attributes"][0]
+    else:
+        assert "SORTABLE" in response["attributes"][0]["flags"]
+        assert "NOSTEM" in response["attributes"][0]["flags"]


 @pytest.mark.redismod
@@ -595,7 +712,10 @@ def test_alter_schema_add(client):

     # Ensure we find the result searching on the added body field
     res = client.ft().search(q)
-    assert 1 == res.total
+    if is_resp2_connection(client):
+        assert 1 == res.total
+    else:
+        assert 1 == res["total_results"]


 @pytest.mark.redismod
@@ -650,7 +770,7 @@ def test_dict_operations(client):

     # Dump dict and inspect content
     res = client.ft().dict_dump("custom_dict")
-    assert ["item1", "item3"] == res
+    assert_resp_response(client, res, ["item1", "item3"], {"item1", "item3"})

     # Remove rest of the items before reload
     client.ft().dict_del("custom_dict", *res)
@@ -663,8 +783,12 @@ def test_phonetic_matcher(client):
     client.hset("doc2", mapping={"name": "John"})

     res = client.ft().search(Query("Jon"))
-    assert 1 == len(res.docs)
-    assert "Jon" == res.docs[0].name
+    if is_resp2_connection(client):
+        assert 1 == len(res.docs)
+        assert "Jon" == res.docs[0].name
+    else:
+        assert 1 == res["total_results"]
+        assert "Jon" == res["results"][0]["fields"]["name"]

     # Drop and create index with phonetic matcher
     client.flushdb()
@@ -674,8 +798,12 @@ def test_phonetic_matcher(client):
     client.hset("doc2", mapping={"name": "John"})

     res = client.ft().search(Query("Jon"))
-    assert 2 == len(res.docs)
-    assert ["John", "Jon"] == sorted(d.name for d in res.docs)
+    if is_resp2_connection(client):
+        assert 2 == len(res.docs)
+        assert ["John", "Jon"] == sorted(d.name for d in res.docs)
+    else:
+        assert 2 == res["total_results"]
+        assert ["John", "Jon"] == sorted(d["fields"]["name"] for d in res["results"])


 @pytest.mark.redismod
@@ -694,20 +822,36 @@ def test_scorer(client):
     )

     # default scorer is TFIDF
-    res = client.ft().search(Query("quick").with_scores())
-    assert 1.0 == res.docs[0].score
-    res = client.ft().search(Query("quick").scorer("TFIDF").with_scores())
-    assert 1.0 == res.docs[0].score
-    res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores())
-    assert 0.1111111111111111 == res.docs[0].score
-    res = client.ft().search(Query("quick").scorer("BM25").with_scores())
client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + if is_resp2_connection(client): + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + else: + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res["results"][0]["score"] @pytest.mark.redismod @@ -1060,7 +1204,11 @@ def test_skip_initial_scan(client): q = Query("@foo:bar") client.ft().create_index((TextField("foo"),), skip_initial_scan=True) - assert 0 == client.ft().search(q).total + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 @pytest.mark.redismod @@ -1148,10 +1296,16 @@ def test_create_client_definition_json(client): client.json().set("king:2", Path.root_path(), {"name": "james"}) res = client.ft().search("henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 + if is_resp2_connection(client): + assert res.docs[0].id == "king:1" + assert res.docs[0].payload is None + assert res.docs[0].json == '{"name":"henry"}' + assert res.total == 1 + else: + assert res["results"][0]["id"] == "king:1" + # assert res["results"][0]["payload"] is None + # assert res["results"][0]["json"] == '{"name":"henry"}' + assert res["total_results"] == 1 @pytest.mark.redismod @@ -1169,11 +1323,17 @@ def test_fields_as_name(client): res = client.json().set("doc:1", Path.root_path(), {"name": "Jon", "age": 25}) assert res - total = client.ft().search(Query("Jon").return_fields("name", "just_a_number")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "Jon" == total[0].name - assert "25" == total[0].just_a_number + res = 
client.ft().search(Query("Jon").return_fields("name", "just_a_number")) + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "doc:1" == res.docs[0].id + assert "Jon" == res.docs[0].name + assert "25" == res.docs[0].just_a_number + else: + assert 1 == len(res["results"]) + assert "doc:1" == res["results"][0]["id"] + assert "Jon" == res["results"][0]["fields"]["name"] + assert "25" == res["results"][0]["fields"]["just_a_number"] @pytest.mark.redismod @@ -1184,11 +1344,16 @@ def test_casesensitive(client): client.ft().client.hset("1", "t", "HELLO") client.ft().client.hset("2", "t", "hello") - res = client.ft().search("@t:{HELLO}").docs + res = client.ft().search("@t:{HELLO}") - assert 2 == len(res) - assert "1" == res[0].id - assert "2" == res[1].id + if is_resp2_connection(client): + assert 2 == len(res.docs) + assert "1" == res.docs[0].id + assert "2" == res.docs[1].id + else: + assert 2 == len(res["results"]) + assert "1" == res["results"][0]["id"] + assert "2" == res["results"][1]["id"] # create casesensitive index client.ft().dropindex() @@ -1196,9 +1361,13 @@ def test_casesensitive(client): client.ft().create_index(SCHEMA) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search("@t:{HELLO}").docs - assert 1 == len(res) - assert "1" == res[0].id + res = client.ft().search("@t:{HELLO}") + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "1" == res.docs[0].id + else: + assert 1 == len(res["results"]) + assert "1" == res["results"][0]["id"] @pytest.mark.redismod @@ -1217,15 +1386,26 @@ def test_search_return_fields(client): client.ft().create_index(SCHEMA, definition=definition) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt + if is_resp2_connection(client): + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "riceratops" == total[0].txt - total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "telmatosaurus" == total[0].txt + else: + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "riceratops" == total["results"][0]["fields"]["txt"] + + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "telmatosaurus" == total["results"][0]["fields"]["txt"] @pytest.mark.redismod @@ -1242,9 +1422,14 @@ def test_synupdate(client): client.hset("doc2", mapping={"title": "he is another baby", "body": "another test"}) res = client.ft().search(Query("child").expander("SYNONYM")) - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" + if is_resp2_connection(client): + assert res.docs[0].id == "doc2" + assert res.docs[0].title == "he is another baby" + assert res.docs[0].body == "another test" + else: + assert res["results"][0]["id"] == "doc2" + assert 
res["results"][0]["fields"]["title"] == "he is another baby" + assert res["results"][0]["fields"]["body"] == "another test" @pytest.mark.redismod @@ -1284,15 +1469,26 @@ def test_create_json_with_alias(client): client.json().set("king:1", Path.root_path(), {"name": "henry", "num": 42}) client.json().set("king:2", Path.root_path(), {"name": "james", "num": 3.14}) - res = client.ft().search("@name:henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","num":42}' - assert res.total == 1 - - res = client.ft().search("@num:[0 10]") - assert res.docs[0].id == "king:2" - assert res.docs[0].json == '{"name":"james","num":3.14}' - assert res.total == 1 + if is_resp2_connection(client): + res = client.ft().search("@name:henry") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","num":42}' + assert res.total == 1 + + res = client.ft().search("@num:[0 10]") + assert res.docs[0].id == "king:2" + assert res.docs[0].json == '{"name":"james","num":3.14}' + assert res.total == 1 + else: + res = client.ft().search("@name:henry") + assert res["results"][0]["id"] == "king:1" + assert res["results"][0]["fields"]["$"] == '{"name":"henry","num":42}' + assert res["total_results"] == 1 + + res = client.ft().search("@num:[0 10]") + assert res["results"][0]["id"] == "king:2" + assert res["results"][0]["fields"]["$"] == '{"name":"james","num":3.14}' + assert res["total_results"] == 1 # Tests returns an error if path contain special characters (user should # use an alias) @@ -1316,15 +1512,32 @@ def test_json_with_multipath(client): "king:1", Path.root_path(), {"name": "henry", "country": {"name": "england"}} ) - res = client.ft().search("@name:{henry}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 + if is_resp2_connection(client): + res = client.ft().search("@name:{henry}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + + res = client.ft().search("@name:{england}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + else: + res = client.ft().search("@name:{henry}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["fields"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 - res = client.ft().search("@name:{england}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 + res = client.ft().search("@name:{england}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["fields"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 @pytest.mark.redismod @@ -1341,98 +1554,116 @@ def test_json_with_jsonpath(client): client.json().set("doc:1", Path.root_path(), {"prod:name": "RediSearch"}) - # query for a supported field succeeds - res = client.ft().search(Query("@name:RediSearch")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].json == '{"prod:name":"RediSearch"}' - - # query for an unsupported field - res = client.ft().search("@name_unsupported:RediSearch") - assert res.total == 1 - - # return of a supported field succeeds - res = client.ft().search(Query("@name:RediSearch").return_field("name")) - assert res.total == 1 - assert 
res.docs[0].id == "doc:1" - assert res.docs[0].name == "RediSearch" - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_redis_enterprise() -def test_profile(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "world") - - # check using Query - q = Query("hello|world").no_content() - res, det = client.ft().profile(q) - assert det["Iterators profile"]["Counter"] == 2.0 - assert len(det["Iterators profile"]["Child iterators"]) == 2 - assert det["Iterators profile"]["Type"] == "UNION" - assert det["Parsing time"] < 0.5 - assert len(res.docs) == 2 # check also the search result - - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res, det = client.ft().profile(req) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "WILDCARD" - assert isinstance(det["Parsing time"], float) - assert len(res.rows) == 2 # check also the search result - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -def test_profile_limited(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "hell") - client.ft().client.hset("3", "t", "help") - client.ft().client.hset("4", "t", "helowa") - - q = Query("%hell% hel*") - res, det = client.ft().profile(q, limited=True) - assert ( - det["Iterators profile"]["Child iterators"][0]["Child iterators"] - == "The number of iterators in the union is 3" - ) - assert ( - det["Iterators profile"]["Child iterators"][1]["Child iterators"] - == "The number of iterators in the union is 4" - ) - assert det["Iterators profile"]["Type"] == "INTERSECT" - assert len(res.docs) == 3 # check also the search result - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_profile_query_params(modclient: redis.Redis): - modclient.flushdb() - modclient.ft().create_index( - ( - VectorField( - "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} - ), - ) - ) - modclient.hset("a", "v", "aaaaaaaa") - modclient.hset("b", "v", "aaaabaaa") - modclient.hset("c", "v", "aaaaabaa") - query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) - res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "VECTOR" - assert res.total == 2 - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") + if is_resp2_connection(client): + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].json == '{"prod:name":"RediSearch"}' + + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res.total == 1 + + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].name == "RediSearch" + else: + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert res["results"][0]["fields"]["$"] == '{"prod:name":"RediSearch"}' + + # query for an unsupported field + res = 
client.ft().search("@name_unsupported:RediSearch") + assert res["total_results"] == 1 + + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert res["results"][0]["fields"]["name"] == "RediSearch" + + + +# @pytest.mark.redismod +# @pytest.mark.onlynoncluster +# @skip_if_redis_enterprise() +# def test_profile(client): +# client.ft().create_index((TextField("t"),)) +# client.ft().client.hset("1", "t", "hello") +# client.ft().client.hset("2", "t", "world") + +# # check using Query +# q = Query("hello|world").no_content() +# res, det = client.ft().profile(q) +# assert det["Iterators profile"]["Counter"] == 2.0 +# assert len(det["Iterators profile"]["Child iterators"]) == 2 +# assert det["Iterators profile"]["Type"] == "UNION" +# assert det["Parsing time"] < 0.5 +# assert len(res.docs) == 2 # check also the search result + +# # check using AggregateRequest +# req = ( +# aggregations.AggregateRequest("*") +# .load("t") +# .apply(prefix="startswith(@t, 'hel')") +# ) +# res, det = client.ft().profile(req) +# assert det["Iterators profile"]["Counter"] == 2.0 +# assert det["Iterators profile"]["Type"] == "WILDCARD" +# assert isinstance(det["Parsing time"], float) +# assert len(res.rows) == 2 # check also the search result + + +# @pytest.mark.redismod +# @pytest.mark.onlynoncluster +# def test_profile_limited(client): +# client.ft().create_index((TextField("t"),)) +# client.ft().client.hset("1", "t", "hello") +# client.ft().client.hset("2", "t", "hell") +# client.ft().client.hset("3", "t", "help") +# client.ft().client.hset("4", "t", "helowa") + +# q = Query("%hell% hel*") +# res, det = client.ft().profile(q, limited=True) +# assert ( +# det["Iterators profile"]["Child iterators"][0]["Child iterators"] +# == "The number of iterators in the union is 3" +# ) +# assert ( +# det["Iterators profile"]["Child iterators"][1]["Child iterators"] +# == "The number of iterators in the union is 4" +# ) +# assert det["Iterators profile"]["Type"] == "INTERSECT" +# assert len(res.docs) == 3 # check also the search result + + +# @pytest.mark.redismod +# @skip_ifmodversion_lt("2.4.3", "search") +# def test_profile_query_params(modclient: redis.Redis): +# modclient.flushdb() +# modclient.ft().create_index( +# ( +# VectorField( +# "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} +# ), +# ) +# ) +# modclient.hset("a", "v", "aaaaaaaa") +# modclient.hset("b", "v", "aaaabaaa") +# modclient.hset("c", "v", "aaaaabaa") +# query = "*=>[KNN 2 @v $vec]" +# q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) +# res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"}) +# assert det["Iterators profile"]["Counter"] == 2.0 +# assert det["Iterators profile"]["Type"] == "VECTOR" +# assert res.total == 2 +# assert "a" == res.docs[0].id +# assert "0" == res.docs[0].__getattribute__("__v_score") @pytest.mark.redismod @@ -1553,6 +1784,7 @@ def test_dialect_config(modclient: redis.Redis): assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "1"} assert modclient.ft().config_set("DEFAULT_DIALECT", 2) assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} + assert modclient.ft().config_set("DEFAULT_DIALECT", 1) with pytest.raises(redis.ResponseError): modclient.ft().config_set("DEFAULT_DIALECT", 0) diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 6ced5359f7..4603161315 
100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -6,7 +6,7 @@ import redis -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt @pytest.fixture @@ -22,13 +22,15 @@ def test_create(client): assert client.ts().create(3, labels={"Redis": "Labs"}) assert client.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) info = client.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] + assert_resp_response( + client, 20, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Series" == info["labels"]["Time"] # Test for a chunk size of 128 Bytes assert client.ts().create("time-serie-1", chunk_size=128) info = client.ts().info("time-serie-1") - assert 128, info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -39,19 +41,33 @@ def test_create_duplicate_policy(client): ts_name = f"time-serie-ooo-{duplicate_policy}" assert client.ts().create(ts_name, duplicate_policy=duplicate_policy) info = client.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy + assert_resp_response( + client, + duplicate_policy, + info.get("duplicate_policy"), + info.get("duplicatePolicy"), + ) @pytest.mark.redismod def test_alter(client): assert client.ts().create(1) - assert 0 == client.ts().info(1).retention_msecs + info = client.ts().info(1) + assert_resp_response( + client, 0, info.get("retention_msecs"), info.get("retentionTime") + ) assert client.ts().alter(1, retention_msecs=10) - assert {} == client.ts().info(1).labels - assert 10, client.ts().info(1).retention_msecs + assert {} == client.ts().info(1)["labels"] + info = client.ts().info(1) + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) assert client.ts().alter(1, labels={"Time": "Series"}) - assert "Series" == client.ts().info(1).labels["Time"] - assert 10 == client.ts().info(1).retention_msecs + assert "Series" == client.ts().info(1)["labels"]["Time"] + info = client.ts().info(1) + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) @pytest.mark.redismod @@ -59,10 +75,14 @@ def test_alter(client): def test_alter_diplicate_policy(client): assert client.ts().create(1) info = client.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + client, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) assert client.ts().alter(1, duplicate_policy="min") info = client.ts().info(1) - assert "min" == info.duplicate_policy + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @@ -77,13 +97,15 @@ def test_add(client): assert abs(time.time() - float(client.ts().add(5, "*", 1)) / 1000) < 1.0 info = client.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Labs" == info["labels"]["Redis"] # Test for a chunk size of 128 Bytes on TS.ADD assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -142,21 +164,21 @@ def test_incrby_decrby(client): assert 0 == client.ts().get(1)[1] assert client.ts().incrby(2, 1.5, 
-    assert (5, 1.5) == client.ts().get(2)
+    assert_resp_response(client, client.ts().get(2), (5, 1.5), [5, 1.5])
     assert client.ts().incrby(2, 2.25, timestamp=7)
-    assert (7, 3.75) == client.ts().get(2)
+    assert_resp_response(client, client.ts().get(2), (7, 3.75), [7, 3.75])
     assert client.ts().decrby(2, 1.5, timestamp=15)
-    assert (15, 2.25) == client.ts().get(2)
+    assert_resp_response(client, client.ts().get(2), (15, 2.25), [15, 2.25])

     # Test for a chunk size of 128 Bytes on TS.INCRBY
     assert client.ts().incrby("time-serie-1", 10, chunk_size=128)
     info = client.ts().info("time-serie-1")
-    assert 128 == info.chunk_size
+    assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize"))

     # Test for a chunk size of 128 Bytes on TS.DECRBY
     assert client.ts().decrby("time-serie-2", 10, chunk_size=128)
     info = client.ts().info("time-serie-2")
-    assert 128 == info.chunk_size
+    assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize"))


 @pytest.mark.redismod
@@ -172,12 +194,15 @@ def test_create_and_delete_rule(client):
         client.ts().add(1, time * 2, 1.5)
     assert round(client.ts().get(2)[1], 5) == 1.5
     info = client.ts().info(1)
-    assert info.rules[0][1] == 100
+    if is_resp2_connection(client):
+        assert info.rules[0][1] == 100
+    else:
+        assert info["rules"]["2"][0] == 100

     # test rule deletion
     client.ts().deleterule(1, 2)
     info = client.ts().info(1)
-    assert not info.rules
+    assert not info["rules"]


 @pytest.mark.redismod
@@ -192,7 +217,7 @@ def test_del_range(client):
         client.ts().add(1, i, i % 7)
     assert 22 == client.ts().delete(1, 0, 21)
     assert [] == client.ts().range(1, 0, 21)
-    assert [(22, 1.0)] == client.ts().range(1, 22, 22)
+    assert_resp_response(client, client.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]])


 @pytest.mark.redismod
@@ -227,15 +252,18 @@ def test_range_advanced(client):
             filter_by_max_value=2,
         )
     )
-    assert [(0, 10.0), (10, 1.0)] == client.ts().range(
+    res = client.ts().range(
         1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+"
     )
-    assert [(0, 5.0), (5, 6.0)] == client.ts().range(
+    assert_resp_response(client, res, [(0, 10.0), (10, 1.0)], [[0, 10.0], [10, 1.0]])
+    res = client.ts().range(
         1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5
     )
-    assert [(0, 2.55), (10, 3.0)] == client.ts().range(
+    assert_resp_response(client, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]])
+    res = client.ts().range(
         1, 0, 10, aggregation_type="twa", bucket_size_msec=10
     )
+    assert_resp_response(client, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]])


 @pytest.mark.redismod
@@ -249,14 +277,22 @@ def test_range_latest(client: redis.Redis):
     timeseries.add("t1", 2, 3)
     timeseries.add("t1", 11, 7)
     timeseries.add("t1", 13, 1)
-    res = timeseries.range("t1", 0, 20)
-    assert res == [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)]
-    res = timeseries.range("t2", 0, 10)
-    assert res == [(0, 4.0)]
+    assert_resp_response(
+        client,
+        timeseries.range("t1", 0, 20),
+        [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)],
+        [[1, 1.0], [2, 3.0], [11, 7.0], [13, 1.0]],
+    )
+    assert_resp_response(
+        client, timeseries.range("t2", 0, 10), [(0, 4.0)], [[0, 4.0]]
+    )
     res = timeseries.range("t2", 0, 10, latest=True)
-    assert res == [(0, 4.0), (10, 8.0)]
-    res = timeseries.range("t2", 0, 9, latest=True)
-    assert res == [(0, 4.0)]
+    assert_resp_response(
+        client, res, [(0, 4.0), (10, 8.0)], [[0, 4.0], [10, 8.0]]
+    )
+    assert_resp_response(
+        client, timeseries.range("t2", 0, 9, latest=True), [(0, 4.0)], [[0, 4.0]]
+    )


 @pytest.mark.redismod
@@ -269,17 +305,27 @@ def test_range_bucket_timestamp(client: redis.Redis):
     timeseries.add("t1", 51, 3)
     timeseries.add("t1", 73, 5)
     timeseries.add("t1", 75, 3)
-    assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range(
-        "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10
-    )
-    assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range(
-        "t1",
-        0,
-        100,
-        align=0,
-        aggregation_type="max",
-        bucket_size_msec=10,
-        bucket_timestamp="+",
+    assert_resp_response(
+        client,
+        timeseries.range(
+            "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10
+        ),
+        [(10, 4.0), (50, 3.0), (70, 5.0)],
+        [[10, 4.0], [50, 3.0], [70, 5.0]],
+    )
+    assert_resp_response(
+        client,
+        timeseries.range(
+            "t1",
+            0,
+            100,
+            align=0,
+            aggregation_type="max",
+            bucket_size_msec=10,
+            bucket_timestamp="+",
+        ),
+        [(20, 4.0), (60, 3.0), (80, 5.0)],
+        [[20, 4.0], [60, 3.0], [80, 5.0]],
     )


@@ -293,8 +339,13 @@ def test_range_empty(client: redis.Redis):
     timeseries.add("t1", 51, 3)
     timeseries.add("t1", 73, 5)
     timeseries.add("t1", 75, 3)
-    assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range(
-        "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10
+    assert_resp_response(
+        client,
+        timeseries.range(
+            "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10
+        ),
+        [(10, 4.0), (50, 3.0), (70, 5.0)],
+        [[10, 4.0], [50, 3.0], [70, 5.0]],
     )
     res = timeseries.range(
         "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True
@@ -302,15 +353,13 @@ def test_range_empty(client: redis.Redis):
     for i in range(len(res)):
         if math.isnan(res[i][1]):
             res[i] = (res[i][0], None)
-    assert [
-        (10, 4.0),
-        (20, None),
-        (30, None),
-        (40, None),
-        (50, 3.0),
-        (60, None),
-        (70, 5.0),
-    ] == res
+    resp2_expected = [
+        (10, 4.0), (20, None), (30, None), (40, None), (50, 3.0), (60, None), (70, 5.0)
+    ]
+    resp3_expected = [
+        [10, 4.0], (20, None), (30, None), (40, None), [50, 3.0], (60, None), [70, 5.0]
+    ]
+    assert_resp_response(client, res, resp2_expected, resp3_expected)


 @pytest.mark.redismod
@@ -337,14 +386,29 @@ def test_rev_range(client):
             filter_by_max_value=2,
         )
    )
-    assert [(10, 1.0), (0, 10.0)] == client.ts().revrange(
-        1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+"
+    assert_resp_response(
+        client,
+        client.ts().revrange(
+            1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+"
+        ),
+        [(10, 1.0), (0, 10.0)],
+        [[10, 1.0], [0, 10.0]],
     )
-    assert [(1, 10.0), (0, 1.0)] == client.ts().revrange(
-        1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1
+    assert_resp_response(
+        client,
+        client.ts().revrange(
+            1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1
+        ),
+        [(1, 10.0), (0, 1.0)],
+        [[1, 10.0], [0, 1.0]],
     )
-    assert [(10, 3.0), (0, 2.55)] == client.ts().revrange(
-        1, 0, 10, aggregation_type="twa", bucket_size_msec=10
+    assert_resp_response(
+        client,
+        client.ts().revrange(
+            1, 0, 10, aggregation_type="twa", bucket_size_msec=10
+        ),
+        [(10, 3.0), (0, 2.55)],
+        [[10, 3.0], [0, 2.55]],
     )


@@ -360,11 +424,11 @@ def test_revrange_latest(client: redis.Redis):
     timeseries.add("t1", 11, 7)
     timeseries.add("t1", 13, 1)
     res = timeseries.revrange("t2", 0, 10)
-    assert res == [(0, 4.0)]
+    assert_resp_response(client, res, [(0, 4.0)], [[0, 4.0]])
     res = timeseries.revrange("t2", 0, 10, latest=True)
-    assert res == [(10, 8.0), (0, 4.0)]
+    assert_resp_response(client, res, [(10, 8.0), (0, 4.0)], [[10, 8.0], [0, 4.0]])
     res = timeseries.revrange("t2", 0, 9, latest=True)
-    assert res == [(0, 4.0)]
+    assert_resp_response(client, res, 
[(0, 4.0)], [[0, 4.0]]) @pytest.mark.redismod @@ -377,17 +441,21 @@ def test_revrange_bucket_timestamp(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + assert_resp_response( + client, + timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(70, 5.0), (50, 3.0), (10, 4.0)], + [[70, 5.0], [50, 3.0], [10, 4.0]], ) - assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( - "t1", - 0, - 100, - align=0, - aggregation_type="max", - bucket_size_msec=10, - bucket_timestamp="+", + assert_resp_response( + client, + timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, bucket_timestamp="+" + ), + [(20, 4.0), (60, 3.0), (80, 5.0)], + [[20, 4.0], [60, 3.0], [80, 5.0]], ) @@ -401,8 +469,13 @@ def test_revrange_empty(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + assert_resp_response( + client, + timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(70, 5.0), (50, 3.0), (10, 4.0)], + [[70, 5.0], [50, 3.0], [10, 4.0]], ) res = timeseries.revrange( "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True @@ -410,15 +483,13 @@ def test_revrange_empty(client: redis.Redis): for i in range(len(res)): if math.isnan(res[i][1]): res[i] = (res[i][0], None) - assert [ - (70, 5.0), - (60, None), - (50, 3.0), - (40, None), - (30, None), - (20, None), - (10, 4.0), - ] == res + resp2_expected = [ + (70, 5.0), (60, None), (50, 3.0), (40, None), (30, None), (20, None), (10, 4.0) + ] + resp3_expected = [ + [70, 5.0], (60, None), [50, 3.0], (40, None), (30, None), (20, None), [10, 4.0] + ] + assert_resp_response(client, res, resp2_expected, resp3_expected) @pytest.mark.redismod @@ -432,23 +503,42 @@ def test_mrange(client): res = client.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 100 == len(res[0]["1"][1]) - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert 100 == len(res["1"][2]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == 
len(res["1"][2]) - # test withlabels - assert {} == res[0]["1"][0] - res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + # test withlabels + assert {} == res["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res["1"][0] @pytest.mark.redismod @@ -463,49 +553,94 @@ def test_multi_range_advanced(client): # test with selected labels res = client.ts().mrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test with filterby - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test groupby - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + + # test groupby + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] - # test align - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[15, 1.0], [16, 2.0]] == res["1"][2] + + # test groupby + res = client.ts().mrange(0, 
3, filters=["Test=This"], groupby="Test", reduce="sum") + assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] + res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") + assert 2 == len(res) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] + + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[0, 10.0], [10, 1.0]] == res["1"][2] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [[0, 5.0], [5, 6.0]] == res["1"][2] @pytest.mark.redismod @@ -527,10 +662,15 @@ def test_mrange_latest(client: redis.Redis): timeseries.add("t3", 2, 3) timeseries.add("t3", 11, 7) timeseries.add("t3", 13, 1) - assert client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True) == [ - {"t2": [{}, [(0, 4.0), (10, 8.0)]]}, - {"t4": [{}, [(0, 4.0), (10, 8.0)]]}, - ] + assert_resp_response( + client, + client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True), + [{"t2": [{}, [(0, 4.0), (10, 8.0)]]}, {"t4": [{}, [(0, 4.0), (10, 8.0)]]}], + { + 't2': [{}, {'aggregators': []}, [[0, 4.0], [10, 8.0]]], + 't4': [{}, {'aggregators': []}, [[0, 4.0], [10, 8.0]]], + } + ) @pytest.mark.redismod @@ -545,10 +685,16 @@ def test_multi_reverse_range(client): res = client.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 100 == len(res[0]["1"][1]) + else: + assert 100 == len(res["1"][2]) res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 10 == len(res[0]["1"][1]) + else: + assert 10 == len(res["1"][2]) for i in range(100): client.ts().add(1, i + 200, i % 7) @@ -556,17 +702,28 @@ def test_multi_reverse_range(client): 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] + if is_resp2_connection(client): + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] + else: + assert 20 == len(res["1"][2]) + assert {} == res["1"][0] # test withlabels res = client.ts().mrevrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + if is_resp2_connection(client): + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert {"Test": "This", "team": "ny"} == res["1"][0] # test with selected labels res = client.ts().mrevrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(client): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] # test filterby res = client.ts().mrevrange( @@ -577,23 +734,36 @@ def test_multi_reverse_range(client): filter_by_min_value=1, filter_by_max_value=2, ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + else: + assert [[16, 2.0], [15, 1.0]] 
== res["1"][2] # test groupby res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + if is_resp2_connection(client): + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + else: + assert [[3, 6.0], [2, 4.0], [1, 2.0], [0, 0.0]] == res["Test=This"][3] res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="Test", reduce="max" ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + if is_resp2_connection(client): + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + else: + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["Test=This"][3] res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="team", reduce="min" ) assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + if is_resp2_connection(client): + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + else: + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=ny"][3] + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=sf"][3] # test align res = client.ts().mrevrange( @@ -604,7 +774,10 @@ def test_multi_reverse_range(client): bucket_size_msec=10, align="-", ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + else: + assert [[10, 1.0], [0, 10.0]] == res["1"][2] res = client.ts().mrevrange( 0, 10, @@ -613,7 +786,10 @@ def test_multi_reverse_range(client): bucket_size_msec=10, align=1, ) - assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + else: + assert [[1, 10.0], [0, 1.0]] == res["1"][2] @pytest.mark.redismod @@ -635,16 +811,22 @@ def test_mrevrange_latest(client: redis.Redis): timeseries.add("t3", 2, 3) timeseries.add("t3", 11, 7) timeseries.add("t3", 13, 1) - assert client.ts().mrevrange( - 0, 10, filters=["is_compaction=true"], latest=True - ) == [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}] + assert_resp_response( + client, + client.ts().mrevrange(0, 10, filters=["is_compaction=true"], latest=True), + [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}], + { + 't2': [{}, {'aggregators': []}, [[10, 8.0], [0, 4.0]]], + 't4': [{}, {'aggregators': []}, [[10, 8.0], [0, 4.0]]] + }, + ) @pytest.mark.redismod def test_get(client): name = "test" client.ts().create(name) - assert client.ts().get(name) is None + assert not client.ts().get(name) client.ts().add(name, 2, 3) assert 2 == client.ts().get(name)[0] client.ts().add(name, 3, 4) @@ -662,8 +844,8 @@ def test_get_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - assert (0, 4.0) == timeseries.get("t2") - assert (10, 8.0) == timeseries.get("t2", latest=True) + assert_resp_response(client, timeseries.get("t2"), (0, 4.0), [0, 4.0]) + assert_resp_response(client, timeseries.get("t2", latest=True), (10, 8.0), [10, 8.0]) @pytest.mark.redismod @@ -673,19 +855,33 @@ def test_mget(client): client.ts().create(2, labels={"Test": "This", "Taste": "That"}) act_res = client.ts().mget(["Test=This"]) exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res + 
exp_res_resp3 = {"1": [{}, []], "2": [{}, []]} + assert_resp_response(client, act_res, exp_res, exp_res_resp3) client.ts().add(1, "*", 15) client.ts().add(2, "*", 25) res = client.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] + if is_resp2_connection(client): + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + else: + assert 15 == res["1"][1][1] + assert 25 == res["2"][1][1] res = client.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] + if is_resp2_connection(client): + assert 25 == res[0]["2"][2] + else: + assert 25 == res["2"][1][1] # test with_labels - assert {} == res[0]["2"][0] + if is_resp2_connection(client): + assert {} == res[0]["2"][0] + else: + assert {} == res["2"][0] res = client.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + if is_resp2_connection(client): + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + else: + assert {"Taste": "That", "Test": "This"} == res["2"][0] @pytest.mark.redismod @@ -700,18 +896,20 @@ def test_mget_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - assert timeseries.mget(filters=["is_compaction=true"]) == [{"t2": [{}, 0, 4.0]}] - assert [{"t2": [{}, 10, 8.0]}] == timeseries.mget( - filters=["is_compaction=true"], latest=True - ) + res = timeseries.mget(filters=["is_compaction=true"]) + assert_resp_response(client, res, [{"t2": [{}, 0, 4.0]}], {'t2': [{}, [0, 4.0]]}) + res = timeseries.mget(filters=["is_compaction=true"], latest=True) + assert_resp_response(client, res, [{"t2": [{}, 10, 8.0]}], {'t2': [{}, [10, 8.0]]}) @pytest.mark.redismod def test_info(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" + assert_resp_response( + client, 5, info.get("retention_msecs"), info.get("retentionTime") + ) + assert info["labels"]["currentLabel"] == "currentData" @pytest.mark.redismod @@ -719,11 +917,15 @@ def test_info(client): def testInfoDuplicatePolicy(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + client, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) client.ts().create("time-serie-2", duplicate_policy="min") info = client.ts().info("time-serie-2") - assert "min" == info.duplicate_policy + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @@ -733,7 +935,7 @@ def test_query_index(client): client.ts().create(2, labels={"Test": "This", "Taste": "That"}) assert 2 == len(client.ts().queryindex(["Test=This"])) assert 1 == len(client.ts().queryindex(["Taste=That"])) - assert [2] == client.ts().queryindex(["Taste=That"]) + assert_resp_response(client, client.ts().queryindex(["Taste=That"]), [2], {"2"}) @pytest.mark.redismod @@ -745,8 +947,12 @@ def test_pipeline(client): pipeline.execute() info = client.ts().info("with_pipeline") - assert info.last_timestamp == 99 - assert info.total_samples == 100 + assert_resp_response( + client, 99, info.get("last_timestamp"), info.get("lastTimestamp") + ) + assert_resp_response( + client, 100, info.get("total_samples"), info.get("totalSamples") + ) assert client.ts().get("with_pipeline")[1] == 99 * 1.1 @@ -756,4 +962,7 @@ def test_uncompressed(client): 
client.ts().create("uncompressed", uncompressed=True) compressed_info = client.ts().info("compressed") uncompressed_info = client.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage + if is_resp2_connection(client): + assert compressed_info.memory_usage != uncompressed_info.memory_usage + else: + assert compressed_info["memoryUsage"] != uncompressed_info["memoryUsage"] From a70913d838945fea61987632b43929f1bcd9bc44 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 15 Jun 2023 11:03:28 +0300 Subject: [PATCH 05/10] tests --- tests/test_search.py | 623 ++++++++++++++++++++++++++++++------------- 1 file changed, 433 insertions(+), 190 deletions(-) diff --git a/tests/test_search.py b/tests/test_search.py index 99bb327d23..7ee5d611b7 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -728,33 +728,61 @@ def test_spell_check(client): client.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {} + if is_resp2_connection(client): + + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = client.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" in res["impornant"][0].keys() + + res = client.ft().spellcheck("contnt") + assert "content" in res["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {'vlis': []} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" in res["vlis"][0].keys() + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert "lorem" in res["lorm"][0].keys() + assert "lore" 
in res["lorm"][1].keys() + assert "lorm" in res["lorm"][2].keys() + assert (res["lorm"][0]["lorem"], res["lorm"][1]["lore"]) == (0.5, 0) + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} @pytest.mark.redismod @@ -932,102 +960,199 @@ def test_aggregations_groupby(client): }, ) - req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) + if is_resp2_connection(client): + req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinct("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinctish("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.sum("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.min("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.max("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.avg("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - index = res.index("__generated_aliasavgrandom_num") - assert res[index + 1] == "7" # (10+3+8)/3 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + index = res.index("__generated_aliasavgrandom_num") + assert res[index + 1] == "7" # (10+3+8)/3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.stddev("random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) - res = 
client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.quantile("@random_num", 0.5) - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.tolist("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.first_value("@title").alias("first") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) - res = client.ft().aggregate(req).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = client.ft().aggregate(req).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.random_sample("@title", 2).alias("random") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliascount"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliascount_distincttitle"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliascount_distinctishtitle"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliassumrandom_num"] == "21" # 10+8+3 + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasminrandom_num"] == "3" # min(10,8,3) + + req = 
aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasmaxrandom_num"] == "10" # max(10,8,3) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasavgrandom_num"] == "7" # (10+3+8)/3 + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasstddevrandom_num"] == "3.60555127546" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasquantilerandom_num,0.5"] == "8" # median of 3,8,10 + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert set( + res["fields"]["__generated_aliastolisttitle"] + ) == {"RediSearch", "RedisAI", "RedisJson"} + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["fields"] == {"parent": "redis", "first": "RediSearch"} + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + res = client.ft().aggregate(req)["results"][0] + assert res["fields"]["parent"] == "redis" + assert "random" in res["fields"].keys() + assert len(res["fields"]["random"]) == 2 + assert res["fields"]["random"][0] in ["RediSearch", "RedisAI", "RedisJson"] @pytest.mark.redismod def test_aggregations_sort_by_and_limit(client): @@ -1036,30 +1161,56 @@ def test_aggregations_sort_by_and_limit(client): client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] + if is_resp2_connection(client): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 + # test sort_by with max + req = 
aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req)["results"] + assert res[0]["fields"] == {"t2": "a", "t1": "b"} + assert res[1]["fields"] == {"t2": "b", "t1": "a"} + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req)["results"] + assert res[0]["fields"] == {"t1": "a"} + assert res[1]["fields"] == {"t1": "b"} + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["fields"] == {"t1": "b"} @pytest.mark.redismod @@ -1068,20 +1219,36 @@ def test_aggregations_load(client): client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) - # load t1 - req = aggregations.AggregateRequest("*").load("t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello"] + if is_resp2_connection(client): + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello"] - # load t2 - req = aggregations.AggregateRequest("*").load("t2") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "world"] + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "world"] - # load all - req = aggregations.AggregateRequest("*").load() - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello", "t2", "world"] + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello", "t2", "world"] + else: + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res["results"][0]["fields"] == {"t1": "hello"} + + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res["results"][0]["fields"] == {"t2": "world"} + + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res["results"][0]["fields"] == {"t1": "hello", "t2": "world"} @pytest.mark.redismod @@ -1106,8 +1273,15 @@ def test_aggregations_apply(client): CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" ) res = client.ft().aggregate(req) - res_set = set([res.rows[0][1], res.rows[1][1]]) - assert res_set == set(["6373878785249699840", "6373878758592700416"]) + if is_resp2_connection(client): + res_set = set([res.rows[0][1], res.rows[1][1]]) + assert res_set == set(["6373878785249699840", "6373878758592700416"]) + else: + res_set = set( + [res["results"][0]["fields"]["CreatedDateTimeUTC"], + res["results"][1]["fields"]["CreatedDateTimeUTC"]], + ) + assert res_set == 
set(["6373878785249699840", "6373878758592700416"]) @pytest.mark.redismod @@ -1126,19 +1300,34 @@ def test_aggregations_filter(client): .dialect(dialect) ) res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["name", "foo", "age", "19"] - - req = ( - aggregations.AggregateRequest("*") - .filter("@age > 15") - .sort_by("@age") - .dialect(dialect) - ) - res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["age", "19"] - assert res.rows[1] == ["age", "25"] + if is_resp2_connection(client): + assert len(res.rows) == 1 + assert res.rows[0] == ["name", "foo", "age", "19"] + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["age", "19"] + assert res.rows[1] == ["age", "25"] + else: + assert len(res["results"]) == 1 + assert res["results"][0]["fields"] == {"name": "foo", "age": "19"} + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res["results"]) == 2 + assert res["results"][0]["fields"] == {"age": "19"} + assert res["results"][1]["fields"] == {"age": "25"} @pytest.mark.redismod @@ -1303,8 +1492,7 @@ def test_create_client_definition_json(client): assert res.total == 1 else: assert res["results"][0]["id"] == "king:1" - # assert res["results"][0]["payload"] is None - # assert res["results"][0]["json"] == '{"name":"henry"}' + assert res["results"][0]["fields"]["$"] == '{"name":"henry"}' assert res["total_results"] == 1 @@ -1685,8 +1873,12 @@ def test_vector_field(modclient): q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) res = modclient.ft().search(q, query_params={"vec": "aaaaaaaa"}) - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") + if is_resp2_connection(modclient): + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["fields"]["__v_score"] @pytest.mark.redismod @@ -1716,9 +1908,14 @@ def test_text_params(modclient): params_dict = {"name1": "Alice", "name2": "Bob"} q = Query("@name:($name1 | $name2 )").dialect(2) res = modclient.ft().search(q, query_params=params_dict) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id + if is_resp2_connection(modclient): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] @pytest.mark.redismod @@ -1735,9 +1932,14 @@ def test_numeric_params(modclient): q = Query("@numval:[$min $max]").dialect(2) res = modclient.ft().search(q, query_params=params_dict) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id + if is_resp2_connection(modclient): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] @pytest.mark.redismod @@ -1753,10 +1955,16 @@ def test_geo_params(modclient): params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} q = Query("@g:[$lon $lat $radius $units]").dialect(2) res = modclient.ft().search(q, 
query_params=params_dict) - assert 3 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "doc3" == res.docs[2].id + if is_resp2_connection(modclient): + assert 3 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + assert "doc3" == res.docs[2].id + else: + assert 3 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] + assert "doc3" == res["results"][2]["id"] @pytest.mark.redismod @@ -1769,12 +1977,22 @@ def test_search_commands_in_pipeline(client): q = Query("foo bar").with_payloads() p.search(q) res = p.execute() - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] + if is_resp2_connection(client): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert res[3]["results"][0]["fields"] == res[3]["results"][1]["fields"] == { + "txt": "foo bar" + } @pytest.mark.redismod @@ -1829,12 +2047,20 @@ def test_expire_while_search(modclient: redis.Redis): modclient.hset("hset:1", "txt", "a") modclient.hset("hset:2", "txt", "b") modclient.hset("hset:3", "txt", "c") - assert 3 == modclient.ft().search(Query("*")).total - modclient.pexpire("hset:2", 300) - for _ in range(500): - modclient.ft().search(Query("*")).docs[1] - time.sleep(1) - assert 2 == modclient.ft().search(Query("*")).total + if is_resp2_connection(modclient): + assert 3 == modclient.ft().search(Query("*")).total + modclient.pexpire("hset:2", 300) + for _ in range(500): + modclient.ft().search(Query("*")).docs[1] + time.sleep(1) + assert 2 == modclient.ft().search(Query("*")).total + else: + assert 3 == modclient.ft().search(Query("*"))["total_results"] + modclient.pexpire("hset:2", 300) + for _ in range(500): + modclient.ft().search(Query("*"))["results"][1] + time.sleep(1) + assert 2 == modclient.ft().search(Query("*"))["total_results"] @pytest.mark.redismod @@ -1843,23 +2069,40 @@ def test_withsuffixtrie(modclient: redis.Redis): # create index assert modclient.ft().create_index((TextField("txt"),)) waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert modclient.ft().dropindex("idx") - - # create withsuffixtrie index (text fiels) - assert modclient.ft().create_index((TextField("t", withsuffixtrie=True))) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert modclient.ft().dropindex("idx") - - # create withsuffixtrie index (tag field) - assert modclient.ft().create_index((TagField("t", withsuffixtrie=True))) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - + if is_resp2_connection(modclient): + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert modclient.ft().dropindex("idx") + + # create withsuffixtrie index (text fiels) + assert 
modclient.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert modclient.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert modclient.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert modclient.ft().dropindex("idx") + + # create withsuffixtrie index (text fiels) + assert modclient.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert modclient.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert modclient.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) + info = modclient.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] @pytest.mark.redismod def test_query_timeout(modclient: redis.Redis): From baedc9c1dcfb7ee634e77a7ed869387167f3de1e Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 15 Jun 2023 12:27:12 +0300 Subject: [PATCH 06/10] finish sync search tests --- redis/commands/search/__init__.py | 2 - redis/commands/search/commands.py | 15 +- tests/test_search.py | 227 ++++++++++++++++++++---------- 3 files changed, 157 insertions(+), 87 deletions(-) diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 228b742035..ed55dd149f 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -8,7 +8,6 @@ PROFILE_CMD, SPELLCHECK_CMD, CONFIG_CMD, - SUGGET_COMMAND, SYNDUMP_CMD, AsyncSearchCommands, SearchCommands, @@ -108,7 +107,6 @@ def __init__(self, client, index_name="idx"): PROFILE_CMD: self._parse_profile, SPELLCHECK_CMD: self._parse_spellcheck, CONFIG_CMD: self._parse_config_get, - SUGGET_COMMAND: self._parse_sugget, SYNDUMP_CMD: self._parse_syndump, } diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index f448d1d84a..e9865368bc 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -140,14 +140,6 @@ def _parse_spellcheck(self, res, **kwargs): def _parse_config_get(self, res, **kwargs): return {kvs[0]: kvs[1] for kvs in res} if res else {} - def _parse_sugget(self, res, **kwargs): - results = [] - if not res: - return results - - parser = SuggestionParser(kwargs["with_scores"], kwargs["with_payloads"], res) - return [s for s in parser] - def _parse_syndump(self, res, **kwargs): return {res[i]: res[i + 1] for i in range(0, len(res), 2)} @@ -843,7 +835,12 @@ def sugget( args.append(WITHPAYLOADS) res = self.execute_command(*args) - return self._parse_results(SUGGET_COMMAND, res, with_scores=with_scores, with_payloads=with_payloads) + results = [] + if not res: + return results + + parser = SuggestionParser(with_scores, with_payloads, res) + return [s for s in parser] def synupdate(self, groupid, skipinitial=False, *terms): """ diff --git a/tests/test_search.py b/tests/test_search.py index 7ee5d611b7..1c8558887e 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -137,84 +137,159 @@ def 
test_client(client): assert num_docs == int(info["num_docs"]) res = client.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc["id"] - assert doc.play == "Henry IV" - assert doc["play"] == "Henry IV" + if is_resp2_connection(client): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc["id"] + assert doc.play == "Henry IV" + assert doc["play"] == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content()).total + vtotal = client.ft().search(Query("kings").no_content().verbatim()).total + assert total > vtotal + + # test in fields + txt_total = ( + client.ft().search(Query("henry").no_content().limit_fields("txt")).total + ) + play_total = ( + client.ft().search(Query("henry").no_content().limit_fields("play")).total + ) + both_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play", "txt")) + .total + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search(Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search(Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search(Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) 
- assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") + # test in-keys + ids = [x.id for x in client.ft().search(Query("henry")).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king")).total + assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total + assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total + assert 53 == client.ft().search(Query("henry king").slop(0)).total + assert 167 == client.ft().search(Query("henry king").slop(100)).total + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + client.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["fields"]["play"] == "Henry IV" + assert len(doc["fields"]["txt"]) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "fields" not in doc.keys() + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content())["total_results"] + vtotal = client.ft().search(Query("kings").no_content().verbatim())["total_results"] + assert total > vtotal + + # test in fields + txt_total = ( + client.ft().search(Query("henry").no_content().limit_fields("txt"))["total_results"] + ) + play_total = ( + client.ft().search(Query("henry").no_content().limit_fields("play"))["total_results"] + ) + both_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play", "txt"))["total_results"] + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [x["id"] for x in client.ft().search(Query("henry"))["results"]] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king"))["total_results"] + assert 3 == client.ft().search(Query("henry king").slop(0).in_order())["total_results"] + assert 52 == client.ft().search(Query("king 
henry").slop(0).in_order())["total_results"] + assert 53 == client.ft().search(Query("henry king").slop(0))["total_results"] + assert 167 == client.ft().search(Query("henry king").slop(100))["total_results"] + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + client.ft().delete_document("doc-5ghs2") @pytest.mark.redismod From 20f7a46a786dc5a31d56fd19f8ab07c906b92623 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 15 Jun 2023 12:57:10 +0300 Subject: [PATCH 07/10] linters --- redis/client.py | 2 - redis/commands/search/__init__.py | 6 +- redis/commands/search/commands.py | 16 +++-- redis/commands/timeseries/__init__.py | 1 - tests/test_search.py | 86 ++++++++++++++-------- tests/test_timeseries.py | 100 +++++++++++++++++--------- 6 files changed, 138 insertions(+), 73 deletions(-) diff --git a/redis/client.py b/redis/client.py index d4bdfbd46a..a2162f8344 100755 --- a/redis/client.py +++ b/redis/client.py @@ -794,7 +794,6 @@ class AbstractRedis: "CONFIG SET": bool_ok, **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), "XCLAIM": parse_xclaim, - } RESP2_RESPONSE_CALLBACKS = { @@ -813,7 +812,6 @@ class AbstractRedis: "HGETALL": lambda r: r and pairs_to_dict(r) or {}, "MEMORY STATS": parse_memory_stats, "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], - # **string_keys_to_dict( # "COPY " # "HEXISTS HMSET MOVE MSETNX PERSIST " diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index ed55dd149f..7a7fdff844 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -2,12 +2,12 @@ from ...asyncio.client import Pipeline as AsyncioPipeline from .commands import ( - INFO_CMD, - SEARCH_CMD, AGGREGATE_CMD, + CONFIG_CMD, + INFO_CMD, PROFILE_CMD, + SEARCH_CMD, SPELLCHECK_CMD, - CONFIG_CMD, SYNDUMP_CMD, AsyncSearchCommands, SearchCommands, diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index e9865368bc..0bce9eb223 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -99,7 +99,7 @@ def _parse_profile(self, res, **kwargs): ) return result, parse_to_dict(res[1]) - + def _parse_spellcheck(self, res, **kwargs): corrections = {} if res == 0: @@ -136,7 +136,7 @@ def _parse_spellcheck(self, res, **kwargs): ] return corrections - + def _parse_config_get(self, res, **kwargs): return {kvs[0]: kvs[1] for kvs in res} if res else {} @@ -501,7 +501,9 @@ def search( if isinstance(res, Pipeline): return res - return self._parse_results(SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0) + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) def explain( self, @@ -546,7 +548,9 @@ def aggregate( cmd += self.get_params_args(query_params) raw = self.execute_command(*cmd) - return self._parse_results(AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) def _get_aggregate_result(self, raw, query, has_cursor): if has_cursor: @@ 
-604,7 +608,9 @@ def profile( res = self.execute_command(*cmd) - return self._parse_results(PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0) + return self._parse_results( + PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) def spellcheck(self, query, distance=None, include=None, exclude=None): """ diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 5b8a02466d..7e085af768 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -50,7 +50,6 @@ def __init__(self, client=None, **kwargs): MRANGE_CMD: parse_m_range, MREVRANGE_CMD: parse_m_range, INFO_CMD: TSInfo, - } RESP3_MODULE_CALLBACKS = {} diff --git a/tests/test_search.py b/tests/test_search.py index 1c8558887e..fc63bcc1d2 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -24,7 +24,12 @@ from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from .conftest import assert_resp_response, skip_if_redis_enterprise, skip_ifmodversion_lt, is_resp2_connection +from .conftest import ( + assert_resp_response, + is_resp2_connection, + skip_if_redis_enterprise, + skip_ifmodversion_lt, +) WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -49,7 +54,7 @@ def waitForIndex(env, idx, timeout=None): if int(res["indexing"]) == 0: break except ValueError: - break + break time.sleep(delay) if timeout is not None: @@ -235,20 +240,21 @@ def test_client(client): # test verbatim vs no verbatim total = client.ft().search(Query("kings").no_content())["total_results"] - vtotal = client.ft().search(Query("kings").no_content().verbatim())["total_results"] + vtotal = client.ft().search(Query("kings").no_content().verbatim())[ + "total_results" + ] assert total > vtotal # test in fields - txt_total = ( - client.ft().search(Query("henry").no_content().limit_fields("txt"))["total_results"] - ) - play_total = ( - client.ft().search(Query("henry").no_content().limit_fields("play"))["total_results"] - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt"))["total_results"] - ) + txt_total = client.ft().search(Query("henry").no_content().limit_fields("txt"))[ + "total_results" + ] + play_total = client.ft().search( + Query("henry").no_content().limit_fields("play") + )["total_results"] + both_total = client.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + )["total_results"] assert 129 == txt_total assert 494 == play_total assert 494 == both_total @@ -271,8 +277,18 @@ def test_client(client): # test slop and in order assert 193 == client.ft().search(Query("henry king"))["total_results"] - assert 3 == client.ft().search(Query("henry king").slop(0).in_order())["total_results"] - assert 52 == client.ft().search(Query("king henry").slop(0).in_order())["total_results"] + assert ( + 3 + == client.ft().search(Query("henry king").slop(0).in_order())[ + "total_results" + ] + ) + assert ( + 52 + == client.ft().search(Query("king henry").slop(0).in_order())[ + "total_results" + ] + ) assert 53 == client.ft().search(Query("henry king").slop(0))["total_results"] assert 167 == client.ft().search(Query("henry king").slop(100))["total_results"] @@ -385,7 +401,6 @@ def test_filters(client): assert ["doc1", "doc2"] == res - @pytest.mark.redismod def test_sort_by(client): client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) @@ -804,7 +819,7 @@ def 
test_spell_check(client): waitForIndex(client, getattr(client.ft(), "index_name", "idx")) if is_resp2_connection(client): - + # test spellcheck res = client.ft().spellcheck("impornant") assert "important" == res["impornant"][0]["suggestion"] @@ -842,7 +857,7 @@ def test_spell_check(client): # test spellcheck with Levenshtein distance res = client.ft().spellcheck("vlis") - assert res == {'vlis': []} + assert res == {"vlis": []} res = client.ft().spellcheck("vlis", distance=2) assert "valid" in res["vlis"][0].keys() @@ -1036,7 +1051,9 @@ def test_aggregations_groupby(client): ) if is_resp2_connection(client): - req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) res = client.ft().aggregate(req).rows[0] assert res[1] == "redis" @@ -1132,7 +1149,9 @@ def test_aggregations_groupby(client): assert len(res[3]) == 2 assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] else: - req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) res = client.ft().aggregate(req)["results"][0] assert res["fields"]["parent"] == "redis" @@ -1200,7 +1219,7 @@ def test_aggregations_groupby(client): res = client.ft().aggregate(req)["results"][0] assert res["fields"]["parent"] == "redis" - assert res["fields"]["__generated_aliasquantilerandom_num,0.5"] == "8" # median of 3,8,10 + assert res["fields"]["__generated_aliasquantilerandom_num,0.5"] == "8" req = aggregations.AggregateRequest("redis").group_by( "@parent", reducers.tolist("@title") @@ -1208,9 +1227,11 @@ def test_aggregations_groupby(client): res = client.ft().aggregate(req)["results"][0] assert res["fields"]["parent"] == "redis" - assert set( - res["fields"]["__generated_aliastolisttitle"] - ) == {"RediSearch", "RedisAI", "RedisJson"} + assert set(res["fields"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } req = aggregations.AggregateRequest("redis").group_by( "@parent", reducers.first_value("@title").alias("first") @@ -1229,6 +1250,7 @@ def test_aggregations_groupby(client): assert len(res["fields"]["random"]) == 2 assert res["fields"]["random"][0] in ["RediSearch", "RedisAI", "RedisJson"] + @pytest.mark.redismod def test_aggregations_sort_by_and_limit(client): client.ft().create_index((TextField("t1"), TextField("t2"))) @@ -1353,8 +1375,10 @@ def test_aggregations_apply(client): assert res_set == set(["6373878785249699840", "6373878758592700416"]) else: res_set = set( - [res["results"][0]["fields"]["CreatedDateTimeUTC"], - res["results"][1]["fields"]["CreatedDateTimeUTC"]], + [ + res["results"][0]["fields"]["CreatedDateTimeUTC"], + res["results"][1]["fields"]["CreatedDateTimeUTC"], + ], ) assert res_set == set(["6373878785249699840", "6373878758592700416"]) @@ -1851,7 +1875,6 @@ def test_json_with_jsonpath(client): assert res["results"][0]["fields"]["name"] == "RediSearch" - # @pytest.mark.redismod # @pytest.mark.onlynoncluster # @skip_if_redis_enterprise() @@ -2065,9 +2088,11 @@ def test_search_commands_in_pipeline(client): assert "doc1" == res[3]["results"][0]["id"] assert "doc2" == res[3]["results"][1]["id"] assert res[3]["results"][0]["payload"] is None - assert res[3]["results"][0]["fields"] == res[3]["results"][1]["fields"] == { - "txt": "foo bar" - } + assert ( + res[3]["results"][0]["fields"] + == res[3]["results"][1]["fields"] + == {"txt": "foo bar"} + ) 
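Note on the recurring pattern in these test hunks: under RESP2 the client parses module replies into rich objects (an FT.SEARCH reply becomes a Result with .total and .docs), while under RESP3 the server returns plain dicts keyed by "total_results" and "results", so each assertion now branches on is_resp2_connection. A minimal sketch of a helper that absorbs that split — total_and_ids is a hypothetical name, not something this patch adds; Query and is_resp2_connection are the imports these tests already use:

    from redis.commands.search.query import Query
    from tests.conftest import is_resp2_connection

    def total_and_ids(client, text):
        """Return (total, ids) for a search under either RESP version."""
        res = client.ft().search(Query(text))
        if is_resp2_connection(client):
            # RESP2: Result object assembled by the client-side parser.
            return res.total, [d.id for d in res.docs]
        # RESP3: plain dict shaped directly from the server reply.
        return res["total_results"], [d["id"] for d in res["results"]]
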
@pytest.mark.redismod @@ -2179,6 +2204,7 @@ def test_withsuffixtrie(modclient: redis.Redis): info = modclient.ft().info() assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + @pytest.mark.redismod def test_query_timeout(modclient: redis.Redis): q1 = Query("foo").timeout(5000) diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 4603161315..31e753c158 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -259,10 +259,8 @@ def test_range_advanced(client): res = client.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) - assert_resp_response(client, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) - res = client.ts().range( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 - ) + assert_resp_response(client, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) + res = client.ts().range(1, 0, 10, aggregation_type="twa", bucket_size_msec=10) assert_resp_response(client, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) @@ -283,13 +281,9 @@ def test_range_latest(client: redis.Redis): [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)], [[1, 1.0], [2, 3.0], [11, 7.0], [13, 1.0]], ) - assert_resp_response( - client, timeseries.range("t2", 0, 10), [(0, 4.0)], [[0, 4.0]] - ) + assert_resp_response(client, timeseries.range("t2", 0, 10), [(0, 4.0)], [[0, 4.0]]) res = timeseries.range("t2", 0, 10, latest=True) - assert_resp_response( - client, res, [(0, 4.0), (10, 8.0)], [[0, 4.0], [10, 8.0]] - ) + assert_resp_response(client, res, [(0, 4.0), (10, 8.0)], [[0, 4.0], [10, 8.0]]) assert_resp_response( client, timeseries.range("t2", 0, 9, latest=True), [(0, 4.0)], [[0, 4.0]] ) @@ -354,10 +348,22 @@ def test_range_empty(client: redis.Redis): if math.isnan(res[i][1]): res[i] = (res[i][0], None) resp2_expected = [ - (10, 4.0), (20, None), (30, None), (40, None), (50, 3.0), (60, None), (70, 5.0) + (10, 4.0), + (20, None), + (30, None), + (40, None), + (50, 3.0), + (60, None), + (70, 5.0), ] resp3_expected = [ - [10, 4.0], (20, None), (30, None), (40, None), [50, 3.0], (60, None), [70, 5.0] + [10, 4.0], + (20, None), + (30, None), + (40, None), + [50, 3.0], + (60, None), + [70, 5.0], ] assert_resp_response(client, res, resp2_expected, resp3_expected) @@ -404,9 +410,7 @@ def test_rev_range(client): ) assert_resp_response( client, - client.ts().revrange( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 - ), + client.ts().revrange(1, 0, 10, aggregation_type="twa", bucket_size_msec=10), [(10, 3.0), (0, 2.55)], [[10, 3.0], [0, 2.55]], ) @@ -452,7 +456,13 @@ def test_revrange_bucket_timestamp(client: redis.Redis): assert_resp_response( client, timeseries.range( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, bucket_timestamp="+" + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", ), [(20, 4.0), (60, 3.0), (80, 5.0)], [[20, 4.0], [60, 3.0], [80, 5.0]], @@ -484,10 +494,22 @@ def test_revrange_empty(client: redis.Redis): if math.isnan(res[i][1]): res[i] = (res[i][0], None) resp2_expected = [ - (70, 5.0), (60, None), (50, 3.0), (40, None), (30, None), (20, None), (10, 4.0) + (70, 5.0), + (60, None), + (50, 3.0), + (40, None), + (30, None), + (20, None), + (10, 4.0), ] resp3_expected = [ - [70, 5.0], (60, None), [50, 3.0], (40, None), (30, None), (20, None), [10, 4.0] + [70, 5.0], + (60, None), + [50, 3.0], + (40, None), + (30, None), + (20, None), + [10, 4.0], ] assert_resp_response(client, res, resp2_expected, resp3_expected) @@ -569,11 +591,17 @@ def 
test_multi_range_advanced(client): assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] # test groupby - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) assert 2 == len(res) assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] @@ -613,11 +641,17 @@ def test_multi_range_advanced(client): assert [[15, 1.0], [16, 2.0]] == res["1"][2] # test groupby - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) assert 2 == len(res) assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] @@ -667,9 +701,9 @@ def test_mrange_latest(client: redis.Redis): client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True), [{"t2": [{}, [(0, 4.0), (10, 8.0)]]}, {"t4": [{}, [(0, 4.0), (10, 8.0)]]}], { - 't2': [{}, {'aggregators': []}, [[0, 4.0], [10, 8.0]]], - 't4': [{}, {'aggregators': []}, [[0, 4.0], [10, 8.0]]], - } + "t2": [{}, {"aggregators": []}, [[0, 4.0], [10, 8.0]]], + "t4": [{}, {"aggregators": []}, [[0, 4.0], [10, 8.0]]], + }, ) @@ -816,8 +850,8 @@ def test_mrevrange_latest(client: redis.Redis): client.ts().mrevrange(0, 10, filters=["is_compaction=true"], latest=True), [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}], { - 't2': [{}, {'aggregators': []}, [[10, 8.0], [0, 4.0]]], - 't4': [{}, {'aggregators': []}, [[10, 8.0], [0, 4.0]]] + "t2": [{}, {"aggregators": []}, [[10, 8.0], [0, 4.0]]], + "t4": [{}, {"aggregators": []}, [[10, 8.0], [0, 4.0]]], }, ) @@ -845,7 +879,9 @@ def test_get_latest(client: redis.Redis): timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) assert_resp_response(client, timeseries.get("t2"), (0, 4.0), [0, 4.0]) - assert_resp_response(client, timeseries.get("t2", latest=True), (10, 8.0), [10, 8.0]) + assert_resp_response( + client, timeseries.get("t2", latest=True), (10, 8.0), [10, 8.0] + ) @pytest.mark.redismod @@ -897,9 +933,9 @@ def test_mget_latest(client: redis.Redis): timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) res = timeseries.mget(filters=["is_compaction=true"]) - assert_resp_response(client, res, [{"t2": [{}, 0, 4.0]}], {'t2': [{}, [0, 4.0]]}) + assert_resp_response(client, res, [{"t2": [{}, 0, 4.0]}], {"t2": [{}, [0, 4.0]]}) res 
= timeseries.mget(filters=["is_compaction=true"], latest=True) - assert_resp_response(client, res, [{"t2": [{}, 10, 8.0]}], {'t2': [{}, [10, 8.0]]}) + assert_resp_response(client, res, [{"t2": [{}, 10, 8.0]}], {"t2": [{}, [10, 8.0]]}) @pytest.mark.redismod From 423584e6f7391d993f62870ad8d328127788b7db Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 15 Jun 2023 16:25:18 +0300 Subject: [PATCH 08/10] async modules --- redis/commands/search/commands.py | 45 +- tests/test_asyncio/test_bloom.py | 83 +- tests/test_asyncio/test_json.py | 193 ++-- tests/test_asyncio/test_search.py | 1227 +++++++++++++++++-------- tests/test_asyncio/test_timeseries.py | 555 +++++++---- 5 files changed, 1415 insertions(+), 688 deletions(-) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 0bce9eb223..50ebf8c203 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -895,8 +895,7 @@ async def info(self): """ res = await self.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) + return self._parse_results(INFO_CMD, res) async def search( self, @@ -921,12 +920,8 @@ async def search( if isinstance(res, Pipeline): return res - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) async def aggregate( @@ -957,7 +952,9 @@ async def aggregate( cmd += self.get_params_args(query_params) raw = await self.execute_command(*cmd) - return self._get_aggregate_result(raw, query, has_cursor) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) async def spellcheck(self, query, distance=None, include=None, exclude=None): """ @@ -983,28 +980,9 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): if exclude: cmd.extend(["TERMS", "EXCLUDE", exclude]) - raw = await self.execute_command(*cmd) - - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] + res = await self.execute_command(*cmd) - return corrections + return self._parse_results(SPELLCHECK_CMD, res) async def config_set(self, option, value): """Set runtime configuration option. 
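Note on the hunks above: each async command (info, search, aggregate, spellcheck, config_get) now funnels its raw reply through self._parse_results instead of parsing inline; the definition itself sits outside these hunks. A minimal sketch of the dispatch this implies — the protocol attribute and the _RESP2_PARSERS registry are assumed names for illustration, not the exact redis-py internals:

    def _parse_results(self, cmd, res, **kwargs):
        # RESP3 replies arrive already structured (dicts/lists), so they
        # can be returned as-is.
        if self.protocol in ("3", 3):  # assumed attribute, for the sketch
            return res
        # RESP2 replies are flat arrays: run the per-command parser, e.g.
        # SEARCH_CMD -> Result(...), SPELLCHECK_CMD -> corrections dict.
        # kwargs carries call-site context such as query= and duration=.
        return self._RESP2_PARSERS[cmd](res, **kwargs)
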
@@ -1031,11 +1009,8 @@ async def config_get(self, option): """ # noqa cmd = [CONFIG_CMD, "GET", option] res = {} - raw = await self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res + res = await self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) async def load_document(self, id): """ diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index 9f4a805c4c..0c9a933b12 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -5,7 +5,7 @@ import redis.asyncio as redis from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt def intlist(obj): @@ -45,7 +45,6 @@ async def test_tdigest_create(modclient: redis.Redis): assert await modclient.tdigest().create("tDigest", 100) -# region Test Bloom Filter @pytest.mark.redismod async def test_bf_add(modclient: redis.Redis): assert await modclient.bf().create("bloom", 0.01, 1000) @@ -70,9 +69,24 @@ async def test_bf_insert(modclient: redis.Redis): assert 0 == await modclient.bf().exists("bloom", "noexist") assert [1, 0] == intlist(await modclient.bf().mexists("bloom", "foo", "noexist")) info = await modclient.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum + assert_resp_response( + modclient, + 2, + info.get("insertedNum"), + info.get("Number of items inserted"), + ) + assert_resp_response( + modclient, + 1000, + info.get("capacity"), + info.get("Capacity"), + ) + assert_resp_response( + modclient, + 1, + info.get("filterNum"), + info.get("Number of filters"), + ) @pytest.mark.redismod @@ -133,11 +147,21 @@ async def test_bf_info(modclient: redis.Redis): # Store a filter await modclient.bf().create("nonscaling", "0.0001", "1000", noScale=True) info = await modclient.bf().info("nonscaling") - assert info.expansionRate is None + assert_resp_response( + modclient, + None, + info.get("expansionRate"), + info.get("Expansion rate"), + ) await modclient.bf().create("expanding", "0.0001", "1000", expansion=expansion) info = await modclient.bf().info("expanding") - assert info.expansionRate == 4 + assert_resp_response( + modclient, + 4, + info.get("expansionRate"), + info.get("Expansion rate"), + ) try: # noScale mean no expansion @@ -164,7 +188,6 @@ async def test_bf_card(modclient: redis.Redis): await modclient.bf().card("setKey") -# region Test Cuckoo Filter @pytest.mark.redismod async def test_cf_add_and_insert(modclient: redis.Redis): assert await modclient.cf().create("cuckoo", 1000) @@ -180,9 +203,15 @@ async def test_cf_add_and_insert(modclient: redis.Redis): assert [1] == await modclient.cf().insert("empty1", ["foo"], capacity=1000) assert [1] == await modclient.cf().insertnx("empty2", ["bar"], capacity=1000) info = await modclient.cf().info("captest") - assert 5 == info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum + assert_resp_response( + modclient, 5, info.get("insertedNum"), info.get("Number of items inserted") + ) + assert_resp_response( + modclient, 0, info.get("deletedNum"), info.get("Number of items deleted") + ) + assert_resp_response( + modclient, 1, info.get("filterNum"), info.get("Number of filters") + ) @pytest.mark.redismod @@ -197,7 +226,6 @@ async def test_cf_exists_and_del(modclient: redis.Redis): assert 0 == await modclient.cf().count("cuckoo", "filter") -# 
region Test Count-Min Sketch @pytest.mark.redismod async def test_cms(modclient: redis.Redis): assert await modclient.cms().initbydim("dim", 1000, 5) @@ -208,9 +236,10 @@ async def test_cms(modclient: redis.Redis): assert [10, 15] == await modclient.cms().incrby("dim", ["foo", "bar"], [5, 15]) assert [10, 15] == await modclient.cms().query("dim", "foo", "bar") info = await modclient.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count + assert info["width"] + assert 1000 == info["width"] + assert 5 == info["depth"] + assert 25 == info["count"] @pytest.mark.redismod @@ -231,10 +260,6 @@ async def test_cms_merge(modclient: redis.Redis): assert [16, 15, 21] == await modclient.cms().query("C", "foo", "bar", "baz") -# endregion - - -# region Test Top-K @pytest.mark.redismod async def test_topk(modclient: redis.Redis): # test list with empty buckets @@ -310,10 +335,10 @@ async def test_topk(modclient: redis.Redis): res = await modclient.topk().list("topklist", withcount=True) assert ["A", 4, "B", 3, "E", 3] == res info = await modclient.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) + assert 3 == info["k"] + assert 50 == info["width"] + assert 3 == info["depth"] + assert 0.9 == round(float(info["decay"]), 1) @pytest.mark.redismod @@ -331,7 +356,6 @@ async def test_topk_incrby(modclient: redis.Redis): ) -# region Test T-Digest @pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_reset(modclient: redis.Redis): @@ -343,7 +367,10 @@ async def test_tdigest_reset(modclient: redis.Redis): assert await modclient.tdigest().reset("tDigest") # assert we have 0 unmerged nodes - assert 0 == (await modclient.tdigest().info("tDigest")).unmerged_nodes + info = await modclient.tdigest().info("tDigest") + assert_resp_response( + modclient, 0, info.get("unmerged_nodes"), info.get("Unmerged nodes") + ) @pytest.mark.redismod @@ -358,8 +385,10 @@ async def test_tdigest_merge(modclient: redis.Redis): assert await modclient.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram info = await modclient.tdigest().info("to-tDigest") - total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) - assert 20.0 == total_weight_to + if is_resp2_connection(modclient): + assert 20 == float(info["merged_weight"]) + float(info["unmerged_weight"]) + else: + assert 20 == float(info["Merged weight"]) + float(info["Unmerged weight"]) # test override assert await modclient.tdigest().create("from-override", 10) assert await modclient.tdigest().create("from-override-2", 10) diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index fc530c63c1..58c0601ea7 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -3,7 +3,7 @@ import redis.asyncio as redis from redis import exceptions from redis.commands.json.path import Path -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import assert_resp_response, skip_ifmodversion_lt @pytest.mark.redismod @@ -17,7 +17,7 @@ async def test_json_setbinarykey(modclient: redis.Redis): @pytest.mark.redismod async def test_json_setgetdeleteforget(modclient: redis.Redis): assert await modclient.json().set("foo", Path.root_path(), "bar") - assert await modclient.json().get("foo") == "bar" + assert_resp_response(modclient, await modclient.json().get("foo"), "bar", [["bar"]]) assert await modclient.json().get("baz") is None assert 
await modclient.json().delete("foo") == 1 assert await modclient.json().forget("foo") == 0 # second delete @@ -27,13 +27,13 @@ async def test_json_setgetdeleteforget(modclient: redis.Redis): @pytest.mark.redismod async def test_jsonget(modclient: redis.Redis): await modclient.json().set("foo", Path.root_path(), "bar") - assert await modclient.json().get("foo") == "bar" + assert_resp_response(modclient, await modclient.json().get("foo"), "bar", [["bar"]]) @pytest.mark.redismod async def test_json_get_jset(modclient: redis.Redis): assert await modclient.json().set("foo", Path.root_path(), "bar") - assert "bar" == await modclient.json().get("foo") + assert_resp_response(modclient, await modclient.json().get("foo"), "bar", [["bar"]]) assert await modclient.json().get("baz") is None assert 1 == await modclient.json().delete("foo") assert await modclient.exists("foo") == 0 @@ -42,7 +42,10 @@ async def test_json_get_jset(modclient: redis.Redis): @pytest.mark.redismod async def test_nonascii_setgetdelete(modclient: redis.Redis): assert await modclient.json().set("notascii", Path.root_path(), "hyvää-élève") - assert "hyvää-élève" == await modclient.json().get("notascii", no_escape=True) + res = "hyvää-élève" + assert_resp_response( + modclient, await modclient.json().get("notascii", no_escape=True), res, [[res]] + ) assert 1 == await modclient.json().delete("notascii") assert await modclient.exists("notascii") == 0 @@ -79,22 +82,37 @@ async def test_mgetshouldsucceed(modclient: redis.Redis): async def test_clear(modclient: redis.Redis): await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 1 == await modclient.json().clear("arr", Path.root_path()) - assert [] == await modclient.json().get("arr") + assert_resp_response(modclient, await modclient.json().get("arr"), [], [[[]]]) @pytest.mark.redismod async def test_type(modclient: redis.Redis): await modclient.json().set("1", Path.root_path(), 1) - assert "integer" == await modclient.json().type("1", Path.root_path()) - assert "integer" == await modclient.json().type("1") + assert_resp_response( + modclient, + await modclient.json().type("1", Path.root_path()), + "integer", + ["integer"], + ) + assert_resp_response( + modclient, await modclient.json().type("1"), "integer", ["integer"] + ) @pytest.mark.redismod async def test_numincrby(modclient): await modclient.json().set("num", Path.root_path(), 1) - assert 2 == await modclient.json().numincrby("num", Path.root_path(), 1) - assert 2.5 == await modclient.json().numincrby("num", Path.root_path(), 0.5) - assert 1.25 == await modclient.json().numincrby("num", Path.root_path(), -1.25) + assert_resp_response( + modclient, await modclient.json().numincrby("num", Path.root_path(), 1), 2, [2] + ) + res = await modclient.json().numincrby("num", Path.root_path(), 0.5) + assert_resp_response( + modclient, res, 2.5, [2.5] + ) + res = await modclient.json().numincrby("num", Path.root_path(), -1.25) + assert_resp_response( + modclient, res, 1.25, [1.25] + ) @pytest.mark.redismod @@ -102,9 +120,18 @@ async def test_nummultby(modclient: redis.Redis): await modclient.json().set("num", Path.root_path(), 1) with pytest.deprecated_call(): - assert 2 == await modclient.json().nummultby("num", Path.root_path(), 2) - assert 5 == await modclient.json().nummultby("num", Path.root_path(), 2.5) - assert 2.5 == await modclient.json().nummultby("num", Path.root_path(), 0.5) + res = await modclient.json().nummultby("num", Path.root_path(), 2) + assert_resp_response( + modclient, res, 2, [2] + ) + res = await 
modclient.json().nummultby("num", Path.root_path(), 2.5) + assert_resp_response( + modclient, res, 5, [5] + ) + res = await modclient.json().nummultby("num", Path.root_path(), 0.5) + assert_resp_response( + modclient, res, 2.5, [2.5] + ) @pytest.mark.redismod @@ -123,7 +150,10 @@ async def test_toggle(modclient: redis.Redis): async def test_strappend(modclient: redis.Redis): await modclient.json().set("jsonkey", Path.root_path(), "foo") assert 6 == await modclient.json().strappend("jsonkey", "bar") - assert "foobar" == await modclient.json().get("jsonkey", Path.root_path()) + res = await modclient.json().get("jsonkey", Path.root_path()) + assert_resp_response( + modclient, res, "foobar", [["foobar"]] + ) @pytest.mark.redismod @@ -159,13 +189,15 @@ async def test_arrindex(modclient: redis.Redis): @pytest.mark.redismod async def test_arrinsert(modclient: redis.Redis): await modclient.json().set("arr", Path.root_path(), [0, 4]) - assert 5 - -await modclient.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) - assert [0, 1, 2, 3, 4] == await modclient.json().get("arr") + assert 5 == await modclient.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) + res = [0, 1, 2, 3, 4] + assert_resp_response(modclient, await modclient.json().get("arr"), res, [[res]]) # test prepends await modclient.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) await modclient.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) - assert await modclient.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] + res = [["some", "thing"], 5, 6, 7, 8, 9] + assert_resp_response(modclient, await modclient.json().get("val2"), res, [[res]]) @pytest.mark.redismod @@ -183,7 +215,7 @@ async def test_arrpop(modclient: redis.Redis): assert 3 == await modclient.json().arrpop("arr", Path.root_path(), -1) assert 2 == await modclient.json().arrpop("arr", Path.root_path()) assert 0 == await modclient.json().arrpop("arr", Path.root_path(), 0) - assert [1] == await modclient.json().get("arr") + assert_resp_response(modclient, await modclient.json().get("arr"), [1], [[[1]]]) # test out of bounds await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -198,7 +230,8 @@ async def test_arrpop(modclient: redis.Redis): async def test_arrtrim(modclient: redis.Redis): await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 3 == await modclient.json().arrtrim("arr", Path.root_path(), 1, 3) - assert [1, 2, 3] == await modclient.json().get("arr") + res = await modclient.json().get("arr") + assert_resp_response(modclient, res, [1, 2, 3], [[[1, 2, 3]]]) # <0 test, should be 0 equivalent await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -284,13 +317,15 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): assert await modclient.json().set("doc1", "$", doc1) assert await modclient.json().delete("doc1", "$..a") == 2 r = await modclient.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + res = [{"nested": {"b": 3}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert await modclient.json().set("doc2", "$", doc2) assert await modclient.json().delete("doc2", "$..a") == 1 res = await modclient.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(modclient, await modclient.json().get("doc2", "$"), res, 
[res]) doc3 = [ { @@ -322,7 +357,7 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): ] ] res = await modclient.json().get("doc3", "$") - assert res == doc3val + assert_resp_response(modclient, res, doc3val, [doc3val]) # Test async default path assert await modclient.json().delete("doc3") == 1 @@ -336,14 +371,14 @@ async def test_json_forget_with_dollar(modclient: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert await modclient.json().set("doc1", "$", doc1) assert await modclient.json().forget("doc1", "$..a") == 2 - r = await modclient.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + res = [{"nested": {"b": 3}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert await modclient.json().set("doc2", "$", doc2) assert await modclient.json().forget("doc2", "$..a") == 1 - res = await modclient.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(modclient, await modclient.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -375,7 +410,7 @@ async def test_json_forget_with_dollar(modclient: redis.Redis): ] ] res = await modclient.json().get("doc3", "$") - assert res == doc3val + assert_resp_response(modclient, res, doc3val, [doc3val]) # Test async default path assert await modclient.json().forget("doc3") == 1 @@ -398,8 +433,14 @@ async def test_json_mget_dollar(modclient: redis.Redis): {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, ) # Compare also to single JSON.GET - assert await modclient.json().get("doc1", "$..a") == [1, 3, None] - assert await modclient.json().get("doc2", "$..a") == [4, 6, [None]] + res = [1, 3, None] + assert_resp_response( + modclient, await modclient.json().get("doc1", "$..a"), res, [res] + ) + res = [4, 6, [None]] + assert_resp_response( + modclient, await modclient.json().get("doc2", "$..a"), res, [res] + ) # Test mget with single path await modclient.json().mget("doc1", "$..a") == [1, 3, None] @@ -479,15 +520,14 @@ async def test_strappend_dollar(modclient: redis.Redis): # Test multi await modclient.json().strappend("doc1", "bar", "$..a") == [6, 8, None] - await modclient.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + res = [{"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) + # Test single await modclient.json().strappend("doc1", "baz", "$.nested1.a") == [11] - await modclient.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] + res = [{"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -495,9 +535,8 @@ async def test_strappend_dollar(modclient: redis.Redis): # Test multi await modclient.json().strappend("doc1", "bar", ".*.a") == 8 - await modclient.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + res = [{"a": "foobar", "nested1": {"a": "hellobarbazbar"}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing path with 
pytest.raises(exceptions.ResponseError): @@ -539,23 +578,25 @@ async def test_arrappend_dollar(modclient: redis.Redis): ) # Test multi await modclient.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test single assert await modclient.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -574,22 +615,24 @@ async def test_arrappend_dollar(modclient: redis.Redis): # Test multi (all paths are updated, but return result of last path) assert await modclient.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test single assert await modclient.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -611,22 +654,24 @@ async def test_arrinsert_dollar(modclient: redis.Redis): res = await modclient.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") assert res == [3, 5, None] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test single assert await modclient.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -692,12 +737,11 @@ async def test_arrpop_dollar(modclient: redis.Redis): }, ) - # # # Test multi + # Test multi assert await modclient.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -715,9 +759,8 @@ async def test_arrpop_dollar(modclient: redis.Redis): ) # Test multi (all paths are updated, but return result of last path) await modclient.json().arrpop("doc1", "..a", "1") is None - assert await modclient.json().get("doc1", "$") == 
[ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # # Test missing key with pytest.raises(exceptions.ResponseError): @@ -738,19 +781,16 @@ async def test_arrtrim_dollar(modclient: redis.Redis): ) # Test multi assert await modclient.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) assert await modclient.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test single assert await modclient.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": []}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -772,9 +812,8 @@ async def test_arrtrim_dollar(modclient: redis.Redis): # Test single assert await modclient.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -872,13 +911,20 @@ async def test_type_dollar(modclient: redis.Redis): jdata, jtypes = load_types_data("a") await modclient.json().set("doc1", "$", jdata) # Test multi - assert await modclient.json().type("doc1", "$..a") == jtypes + assert_resp_response( + modclient, await modclient.json().type("doc1", "$..a"), jtypes, [jtypes] + ) # Test single - assert await modclient.json().type("doc1", "$.nested2.a") == [jtypes[1]] + res = await modclient.json().type("doc1", "$.nested2.a") + assert_resp_response( + modclient, res, [jtypes[1]], [[jtypes[1]]] + ) # Test missing key - assert await modclient.json().type("non_existing_doc", "..a") is None + assert_resp_response( + modclient, await modclient.json().type("non_existing_doc", "..a"), None, [None] + ) @pytest.mark.redismod @@ -898,9 +944,10 @@ async def test_clear_dollar(modclient: redis.Redis): # Test multi assert await modclient.json().clear("doc1", "$..a") == 3 - assert await modclient.json().get("doc1", "$") == [ + res = [ {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test single await modclient.json().set( @@ -914,7 +961,7 @@ async def test_clear_dollar(modclient: redis.Redis): }, ) assert await modclient.json().clear("doc1", "$.nested1.a") == 1 - assert await modclient.json().get("doc1", "$") == [ + res = [ { "nested1": {"a": {}}, "a": ["foo"], @@ -922,10 +969,13 @@ async def 
test_clear_dollar(modclient: redis.Redis): "nested3": {"a": {"baz": 50}}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing path (async defaults to root) assert await modclient.json().clear("doc1") == 1 - assert await modclient.json().get("doc1", "$") == [{}] + assert_resp_response( + modclient, await modclient.json().get("doc1", "$"), [{}], [[{}]] + ) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -946,7 +996,7 @@ async def test_toggle_dollar(modclient: redis.Redis): ) # Test multi assert await modclient.json().toggle("doc1", "$..a") == [None, 1, None, 0] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo"], "nested1": {"a": True}, @@ -954,6 +1004,7 @@ async def test_toggle_dollar(modclient: redis.Redis): "nested3": {"a": False}, } ] + assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 8707cdf61b..1e83efae66 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -16,7 +16,12 @@ from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from tests.conftest import skip_if_redis_enterprise, skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_if_redis_enterprise, + skip_ifmodversion_lt, +) WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -32,12 +37,16 @@ async def waitForIndex(env, idx, timeout=None): while True: res = await env.execute_command("FT.INFO", idx) try: - res.index("indexing") + if int(res[res.index("indexing") + 1]) == 0: + break except ValueError: break - - if int(res[res.index("indexing") + 1]) == 0: - break + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break time.sleep(delay) if timeout is not None: @@ -119,89 +128,185 @@ async def test_client(modclient: redis.Redis): assert num_docs == int(info["num_docs"]) res = await modclient.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" + if is_resp2_connection(modclient): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc.play == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = await modclient.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = (await modclient.ft().search(Query("kings").no_content())).total + vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim())).total + assert total > vtotal + + # test in fields + txt_total = ( + await modclient.ft().search(Query("henry").no_content().limit_fields("txt")) + ).total + play_total = ( + await modclient.ft().search(Query("henry").no_content().limit_fields("play")) + ).total + both_total = ( + await ( + modclient.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + ) + ) + 
).total + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await modclient.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = await modclient.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = (await modclient.ft().search(Query("kings").no_content())).total - vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim())).total - assert total > vtotal - - # test in fields - txt_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("txt")) - ).total - play_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("play")) - ).total - both_total = ( - await ( - modclient.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ) + # test in-keys + ids = [x.id for x in (await modclient.ft().search(Query("henry"))).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = await modclient.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == (await modclient.ft().search(Query("henry king"))).total + assert ( + 3 == (await modclient.ft().search(Query("henry king").slop(0).in_order())).total ) - ).total - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = await modclient.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in (await modclient.ft().search(Query("henry"))).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = await modclient.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == (await modclient.ft().search(Query("henry king"))).total - assert ( - 3 == (await modclient.ft().search(Query("henry king").slop(0).in_order())).total - ) - assert ( - 52 - == (await modclient.ft().search(Query("king henry").slop(0).in_order())).total - ) - assert 53 == (await modclient.ft().search(Query("henry king").slop(0))).total - assert 167 == (await modclient.ft().search(Query("henry king").slop(100))).total + assert ( + 52 + == (await modclient.ft().search(Query("king henry").slop(0).in_order())).total + ) + assert 53 == (await modclient.ft().search(Query("henry king").slop(0))).total + assert 167 == (await modclient.ft().search(Query("henry king").slop(100))).total + + # test delete document + await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await modclient.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == await modclient.ft().delete_document("doc-5ghs2") + res = await modclient.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == await modclient.ft().delete_document("doc-5ghs2") + + await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await modclient.ft().search(Query("death of a salesman")) + assert 1 == res.total + await 
modclient.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["fields"]["play"] == "Henry IV" + assert len(doc["fields"]["txt"]) > 0 + + # test no content + res = await modclient.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "fields" not in doc.keys() + + # test verbatim vs no verbatim + total = (await modclient.ft().search( + Query("kings").no_content() + ))["total_results"] + vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim()))[ + "total_results" + ] + assert total > vtotal + + # test in fields + txt_total = (await modclient.ft().search( + Query("henry").no_content().limit_fields("txt") + ))["total_results"] + play_total = (await modclient.ft().search( + Query("henry").no_content().limit_fields("play") + ))["total_results"] + both_total = (await modclient.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + ))["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await modclient.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [ + x["id"] for x in (await modclient.ft().search(Query("henry")))["results"] + ] + assert 10 == len(ids) + subset = ids[:5] + docs = await modclient.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == ( + await modclient.ft().search(Query("henry king")) + )["total_results"] + assert ( + 3 + == (await modclient.ft().search(Query("henry king").slop(0).in_order()))[ + "total_results" + ] + ) + assert ( + 52 + == (await modclient.ft().search(Query("king henry").slop(0).in_order()))[ + "total_results" + ] + ) + assert 53 == (await modclient.ft().search( + Query("henry king").slop(0) + ))["total_results"] + assert 167 == (await modclient.ft().search( + Query("henry king").slop(100) + ))["total_results"] - # test delete document - await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await modclient.ft().search(Query("death of a salesman")) - assert 1 == res.total + # test delete document + await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await modclient.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] - assert 1 == await modclient.ft().delete_document("doc-5ghs2") - res = await modclient.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == await modclient.ft().delete_document("doc-5ghs2") + assert 1 == await modclient.ft().delete_document("doc-5ghs2") + res = await modclient.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == await modclient.ft().delete_document("doc-5ghs2") - await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await modclient.ft().search(Query("death of a salesman")) - assert 1 == res.total - await modclient.ft().delete_document("doc-5ghs2") + await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await modclient.ft().search(Query("death of a salesman")) + 
assert 1 == res["total_results"] + await modclient.ft().delete_document("doc-5ghs2") @pytest.mark.redismod @@ -214,12 +319,16 @@ async def test_scores(modclient: redis.Redis): q = Query("foo ~bar").with_scores() res = await modclient.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) + if is_resp2_connection(modclient): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] @pytest.mark.redismod @@ -233,8 +342,13 @@ async def test_stopwords(modclient: redis.Redis): q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total + if is_resp2_connection(modclient): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] + @pytest.mark.redismod @@ -263,24 +377,40 @@ async def test_filters(modclient: redis.Redis): ) res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id + if is_resp2_connection(modclient): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] # Test geo filter q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id - - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res + if is_resp2_connection(modclient): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id + + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] + + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res @pytest.mark.redismod @@ -299,14 +429,24 @@ async def test_sort_by(modclient: redis.Redis): q2 = Query("foo").sort_by("num", asc=False).no_content() res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id + if is_resp2_connection(modclient): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert 
"doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] @pytest.mark.redismod @@ -424,27 +564,50 @@ async def test_no_index(modclient: redis.Redis): ) await waitForIndex(modclient, "idx") - res = await modclient.ft().search(Query("@text:aa*")) - assert 0 == res.total + if is_resp2_connection(modclient): + res = await modclient.ft().search(Query("@text:aa*")) + assert 0 == res.total - res = await modclient.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = await modclient.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = await modclient.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = await modclient.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = await modclient.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = await modclient.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = await modclient.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id + res = await modclient.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = await modclient.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] + + res = await modclient.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] + + res = await modclient.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + + res = await modclient.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await modclient.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await modclient.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await modclient.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] # Ensure exception is raised for non-indexable, non-sortable fields with pytest.raises(Exception): @@ -481,21 +644,38 @@ async def test_summarize(modclient: redis.Redis): q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") - doc = sorted((await modclient.ft().search(q)).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa - == doc.txt - ) + if is_resp2_connection(modclient): + doc = sorted((await modclient.ft().search(q)).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted((await modclient.ft().search(q)).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + doc = sorted((await modclient.ft().search(q)).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + else: + doc = sorted((await modclient.ft().search(q))["results"])[0] + assert "Henry IV" == doc["fields"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["fields"]["txt"] + ) + + q = Query("king henry").paging(0, 1).summarize().highlight() + + doc = sorted((await modclient.ft().search(q))["results"])[0] + assert "Henry ... " == doc["fields"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["fields"]["txt"] + ) @pytest.mark.redismod @@ -515,25 +695,46 @@ async def test_alias(modclient: redis.Redis): await index1.hset("index1:lonestar", mapping={"name": "lonestar"}) await index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - res = (await ftindex1.search("*")).docs[0] - assert "index1:lonestar" == res.id + if is_resp2_connection(modclient): + res = (await ftindex1.search("*")).docs[0] + assert "index1:lonestar" == res.id - # create alias and check for results - await ftindex1.aliasadd("spaceballs") - alias_client = getClient(modclient).ft("spaceballs") - res = (await alias_client.search("*")).docs[0] - assert "index1:lonestar" == res.id + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = getClient(modclient).ft("spaceballs") + res = (await alias_client.search("*")).docs[0] + assert "index1:lonestar" == res.id - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await ftindex2.aliasadd("spaceballs") + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + await ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(modclient).ft("spaceballs") - # update alias and ensure new results - await ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(modclient).ft("spaceballs") + res = (await alias_client2.search("*")).docs[0] + assert "index2:yogurt" == res.id + else: + res = (await ftindex1.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - res = (await alias_client2.search("*")).docs[0] - assert "index2:yogurt" == res.id + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = getClient(await modclient).ft("spaceballs") + res = (await alias_client.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] + + # Throw an exception when trying to add an alias that already exists + with 
+            await ftindex2.aliasadd("spaceballs")
+
+        # update alias and ensure new results
+        await ftindex2.aliasupdate("spaceballs")
+        alias_client2 = getClient(modclient).ft("spaceballs")
+
+        res = (await alias_client2.search("*"))["results"][0]
+        assert "index2:yogurt" == res["id"]
 
     await ftindex2.aliasdel("spaceballs")
     with pytest.raises(Exception):
@@ -557,18 +758,34 @@ async def test_alias_basic(modclient: redis.Redis):
     # add the actual alias and check
    await index1.aliasadd("myalias")
     alias_client = getClient(modclient).ft("myalias")
-    res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id)
-    assert "doc1" == res[0].id
-
-    # Throw an exception when trying to add an alias that already exists
-    with pytest.raises(Exception):
-        await index2.aliasadd("myalias")
-
-    # update the alias and ensure we get doc2
-    await index2.aliasupdate("myalias")
-    alias_client2 = getClient(modclient).ft("myalias")
-    res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id)
-    assert "doc1" == res[0].id
+    if is_resp2_connection(modclient):
+        res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id)
+        assert "doc1" == res[0].id
+
+        # Throw an exception when trying to add an alias that already exists
+        with pytest.raises(Exception):
+            await index2.aliasadd("myalias")
+
+        # update the alias and ensure we get doc2
+        await index2.aliasupdate("myalias")
+        alias_client2 = getClient(modclient).ft("myalias")
+        res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id)
+        assert "doc1" == res[0].id
+    else:
+        res = sorted((await alias_client.search("*"))["results"], key=lambda x: x["id"])
+        assert "doc1" == res[0]["id"]
+
+        # Throw an exception when trying to add an alias that already exists
+        with pytest.raises(Exception):
+            await index2.aliasadd("myalias")
+
+        # update the alias and ensure we get doc2
+        await index2.aliasupdate("myalias")
+        alias_client2 = getClient(modclient).ft("myalias")
+        res = sorted(
+            (await alias_client2.search("*"))["results"], key=lambda x: x["id"]
+        )
+        assert "doc1" == res[0]["id"]
 
     # delete the alias and expect an error if we try to query again
     await index2.aliasdel("myalias")
@@ -576,34 +793,34 @@ async def test_alias_basic(modclient: redis.Redis):
         _ = (await alias_client2.search("*")).docs[0]
 
 
-@pytest.mark.redismod
-async def test_tags(modclient: redis.Redis):
-    await modclient.ft().create_index((TextField("txt"), TagField("tags")))
-    tags = "foo,foo bar,hello;world"
-    tags2 = "soba,ramen"
+# @pytest.mark.redismod
+# async def test_tags(modclient: redis.Redis):
+#     await modclient.ft().create_index((TextField("txt"), TagField("tags")))
+#     tags = "foo,foo bar,hello;world"
+#     tags2 = "soba,ramen"
 
-    await modclient.hset("doc1", mapping={"txt": "fooz barz", "tags": tags})
-    await modclient.hset("doc2", mapping={"txt": "noodles", "tags": tags2})
-    await waitForIndex(modclient, "idx")
+#     await modclient.hset("doc1", mapping={"txt": "fooz barz", "tags": tags})
+#     await modclient.hset("doc2", mapping={"txt": "noodles", "tags": tags2})
+#     await waitForIndex(modclient, "idx")
 
-    q = Query("@tags:{foo}")
-    res = await modclient.ft().search(q)
-    assert 1 == res.total
+#     q = Query("@tags:{foo}")
+#     res = await modclient.ft().search(q)
+#     assert 1 == res.total
 
-    q = Query("@tags:{foo bar}")
-    res = await modclient.ft().search(q)
-    assert 1 == res.total
+#     q = Query("@tags:{foo bar}")
+#     res = await modclient.ft().search(q)
+#     assert 1 == res.total
 
-    q = Query("@tags:{foo\\ bar}")
-    res = await modclient.ft().search(q)
-    assert 
1 == res.total +# q = Query("@tags:{foo\\ bar}") +# res = await modclient.ft().search(q) +# assert 1 == res.total - q = Query("@tags:{hello\\;world}") - res = await modclient.ft().search(q) - assert 1 == res.total +# q = Query("@tags:{hello\\;world}") +# res = await modclient.ft().search(q) +# assert 1 == res.total - q2 = await modclient.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() +# q2 = await modclient.ft().tagvals("tags") +# assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() @pytest.mark.redismod @@ -613,8 +830,12 @@ async def test_textfield_sortable_nostem(modclient: redis.Redis): # Now get the index info to confirm its contents response = await modclient.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] + if is_resp2_connection(modclient): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] @pytest.mark.redismod @@ -635,7 +856,10 @@ async def test_alter_schema_add(modclient: redis.Redis): # Ensure we find the result searching on the added body field res = await modclient.ft().search(q) - assert 1 == res.total + if is_resp2_connection(modclient): + assert 1 == res.total + else: + assert 1 == res["total_results"] @pytest.mark.redismod @@ -650,33 +874,60 @@ async def test_spell_check(modclient: redis.Redis): await modclient.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) await waitForIndex(modclient, "idx") - # test spellcheck - res = await modclient.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = await modclient.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = await modclient.ft().spellcheck("vlis") - assert res == {} - res = await modclient.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - await modclient.ft().dict_add("dict", "lore", "lorem", "lorm") - res = await modclient.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = await modclient.ft().spellcheck("lorm", exclude="dict") - assert res == {} + if is_resp2_connection(modclient): + # test spellcheck + res = await modclient.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = await modclient.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = await modclient.ft().spellcheck("vlis") + assert res == {} + res = await modclient.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + await modclient.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await modclient.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = 
await modclient.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = await modclient.ft().spellcheck("impornant") + assert "important" in res["impornant"][0].keys() + + res = await modclient.ft().spellcheck("contnt") + assert "content" in res["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = await modclient.ft().spellcheck("vlis") + assert res == {"vlis": []} + res = await modclient.ft().spellcheck("vlis", distance=2) + assert "valid" in res["vlis"][0].keys() + + # test spellcheck include + await modclient.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await modclient.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert "lorem" in res["lorm"][0].keys() + assert "lore" in res["lorm"][1].keys() + assert "lorm" in res["lorm"][2].keys() + assert (res["lorm"][0]["lorem"], res["lorm"][1]["lore"]) == (0.5, 0) + + # test spellcheck exclude + res = await modclient.ft().spellcheck("lorm", exclude="dict") + assert res == {} @pytest.mark.redismod @@ -692,7 +943,7 @@ async def test_dict_operations(modclient: redis.Redis): # Dump dict and inspect content res = await modclient.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res + assert_resp_response(modclient, res, ["item1", "item3"], {"item1", "item3"}) # Remove rest of the items before reload await modclient.ft().dict_del("custom_dict", *res) @@ -705,8 +956,12 @@ async def test_phonetic_matcher(modclient: redis.Redis): await modclient.hset("doc2", mapping={"name": "John"}) res = await modclient.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name + if is_resp2_connection(modclient): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["fields"]["name"] # Drop and create index with phonetic matcher await modclient.flushdb() @@ -716,8 +971,12 @@ async def test_phonetic_matcher(modclient: redis.Redis): await modclient.hset("doc2", mapping={"name": "John"}) res = await modclient.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) + if is_resp2_connection(modclient): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted(d["fields"]["name"] for d in res["results"]) @pytest.mark.redismod @@ -735,23 +994,51 @@ async def test_scorer(modclient: redis.Redis): }, ) - # default scorer is TFIDF - res = await modclient.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = await ( - modclient.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - ) - assert 0.1111111111111111 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + if is_resp2_connection(modclient): + # default scorer is TFIDF + res = await modclient.ft().search(Query("quick").with_scores()) + 
assert 1.0 == res.docs[0].score + res = await modclient.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = await ( + modclient.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + ) + assert 0.1111111111111111 == res.docs[0].score + res = await modclient.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = await modclient.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = await modclient.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await modclient.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res.docs[0].score + else: + res = await modclient.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await modclient.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await modclient.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.1111111111111111 == res["results"][0]["score"] + res = await modclient.ft().search( + Query("quick").scorer("BM25").with_scores() + ) + assert 0.17699114465425977 == res["results"][0]["score"] + res = await modclient.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = await modclient.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await modclient.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res["results"][0]["score"] @pytest.mark.redismod @@ -833,126 +1120,252 @@ async def test_aggregations_groupby(modclient: redis.Redis): ) for dialect in [1, 2]: - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count()) - .dialect(dialect) - ) + if is_resp2_connection(modclient): + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinct("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinctish("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.sum("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) - res = 
(await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.min("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.max("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.avg("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "7" # (10+3+8)/3 + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "7" # (10+3+8)/3 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.stddev("random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.quantile("@random_num", 0.5)) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.tolist("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.first_value("@title").alias("first")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) - res = (await 
modclient.ft().aggregate(req)).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = (await modclient.ft().aggregate(req)).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.random_sample("@title", 2).alias("random")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.random_sample("@title", 2).alias("random")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliascount"] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliascount_distincttitle"] == "3" - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliascount_distinctishtitle"] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliassumrandom_num"] == "21" # 10+8+3 + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasminrandom_num"] == "3" # min(10,8,3) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasmaxrandom_num"] == "10" # max(10,8,3) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasavgrandom_num"] == "7" # (10+3+8)/3 + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasstddevrandom_num"] == "3.60555127546" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) + + res = (await 
modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert res["fields"]["__generated_aliasquantilerandom_num,0.5"] == "8" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert set(res["fields"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"] == {"parent": "redis", "first": "RediSearch"} + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.random_sample("@title", 2).alias("random")) + .dialect(dialect) + ) + + res = (await modclient.ft().aggregate(req))["results"][0] + assert res["fields"]["parent"] == "redis" + assert "random" in res["fields"].keys() + assert len(res["fields"]["random"]) == 2 + assert res["fields"]["random"][0] in ["RediSearch", "RedisAI", "RedisJson"] @pytest.mark.redismod @@ -962,30 +1375,56 @@ async def test_aggregations_sort_by_and_limit(modclient: redis.Redis): await modclient.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) await modclient.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = await modclient.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] + if is_resp2_connection(modclient): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = await modclient.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = await modclient.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await modclient.ft().aggregate(req) + assert len(res.rows) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await modclient.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = (await modclient.ft().aggregate(req))["results"] + assert res[0]["fields"] == {"t2": "a", "t1": "b"} + assert res[1]["fields"] == {"t2": "b", "t1": "a"} - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = await modclient.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = (await modclient.ft().aggregate(req))["results"] + assert res[0]["fields"] == {"t1": "a"} + assert res[1]["fields"] == {"t1": "b"} - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = await modclient.ft().aggregate(req) - assert 
len(res.rows) == 1
+        # test sort_by with max
+        req = aggregations.AggregateRequest("*").sort_by("@t1", max=1)
+        res = await modclient.ft().aggregate(req)
+        assert len(res["results"]) == 1
 
-    # test limit
-    req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1)
-    res = await modclient.ft().aggregate(req)
-    assert len(res.rows) == 1
-    assert res.rows[0] == ["t1", "b"]
+        # test limit
+        req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1)
+        res = await modclient.ft().aggregate(req)
+        assert len(res["results"]) == 1
+        assert res["results"][0]["fields"] == {"t1": "b"}
 
 
 @pytest.mark.redismod
@@ -994,22 +1433,40 @@ async def test_withsuffixtrie(modclient: redis.Redis):
     # create index
     assert await modclient.ft().create_index((TextField("txt"),))
     await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
-    info = await modclient.ft().info()
-    assert "WITHSUFFIXTRIE" not in info["attributes"][0]
-    assert await modclient.ft().dropindex("idx")
-
-    # create withsuffixtrie index (text field)
-    assert await modclient.ft().create_index((TextField("t", withsuffixtrie=True)))
-    await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
-    info = await modclient.ft().info()
-    assert "WITHSUFFIXTRIE" in info["attributes"][0]
-    assert await modclient.ft().dropindex("idx")
-
-    # create withsuffixtrie index (tag field)
-    assert await modclient.ft().create_index((TagField("t", withsuffixtrie=True)))
-    await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
-    info = await modclient.ft().info()
-    assert "WITHSUFFIXTRIE" in info["attributes"][0]
+    if is_resp2_connection(modclient):
+        info = await modclient.ft().info()
+        assert "WITHSUFFIXTRIE" not in info["attributes"][0]
+        assert await modclient.ft().dropindex("idx")
+
+        # create withsuffixtrie index (text field)
+        assert await modclient.ft().create_index((TextField("t", withsuffixtrie=True)))
+        await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
+        info = await modclient.ft().info()
+        assert "WITHSUFFIXTRIE" in info["attributes"][0]
+        assert await modclient.ft().dropindex("idx")
+
+        # create withsuffixtrie index (tag field)
+        assert await modclient.ft().create_index((TagField("t", withsuffixtrie=True)))
+        await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
+        info = await modclient.ft().info()
+        assert "WITHSUFFIXTRIE" in info["attributes"][0]
+    else:
+        info = await modclient.ft().info()
+        assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"]
+        assert await modclient.ft().dropindex("idx")
+
+        # create withsuffixtrie index (text field)
+        assert await modclient.ft().create_index((TextField("t", withsuffixtrie=True)))
+        await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
+        info = await modclient.ft().info()
+        assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
+        assert await modclient.ft().dropindex("idx")
+
+        # create withsuffixtrie index (tag field)
+        assert await modclient.ft().create_index((TagField("t", withsuffixtrie=True)))
+        await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx"))
+        info = await modclient.ft().info()
+        assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
 
 
 @pytest.mark.redismod
@@ -1022,12 +1479,24 @@ async def test_search_commands_in_pipeline(modclient: redis.Redis):
     q = Query("foo bar").with_payloads()
     await p.search(q)
     res = await p.execute()
-    assert res[:3] == ["OK", True, True]
-    assert 2 == res[3][0]
-    assert "doc1" == res[3][1]
-    assert "doc2" == res[3][4]
-    assert 
res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] + if is_resp2_connection(modclient): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["fields"] + == res[3]["results"][1]["fields"] + == {"txt": "foo bar"} + ) @pytest.mark.redismod diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index a7109938f2..f2580b7f97 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -4,7 +4,7 @@ import pytest import redis.asyncio as redis -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt @pytest.mark.redismod @@ -14,13 +14,15 @@ async def test_create(modclient: redis.Redis): assert await modclient.ts().create(3, labels={"Redis": "Labs"}) assert await modclient.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) info = await modclient.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] + assert_resp_response( + modclient, 20, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Series" == info["labels"]["Time"] # Test for a chunk size of 128 Bytes assert await modclient.ts().create("time-serie-1", chunk_size=128) info = await modclient.ts().info("time-serie-1") - assert 128, info.chunk_size + assert_resp_response(modclient, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -31,24 +33,35 @@ async def test_create_duplicate_policy(modclient: redis.Redis): ts_name = f"time-serie-ooo-{duplicate_policy}" assert await modclient.ts().create(ts_name, duplicate_policy=duplicate_policy) info = await modclient.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy + assert_resp_response( + modclient, + duplicate_policy, + info.get("duplicate_policy"), + info.get("duplicatePolicy"), + ) @pytest.mark.redismod async def test_alter(modclient: redis.Redis): assert await modclient.ts().create(1) res = await modclient.ts().info(1) - assert 0 == res.retention_msecs + assert_resp_response( + modclient, 0, res.get("retention_msecs"), res.get("retentionTime") + ) assert await modclient.ts().alter(1, retention_msecs=10) res = await modclient.ts().info(1) - assert {} == res.labels - res = await modclient.ts().info(1) - assert 10 == res.retention_msecs + assert {} == (await modclient.ts().info(1))["labels"] + info = await modclient.ts().info(1) + assert_resp_response( + modclient, 10, info.get("retention_msecs"), info.get("retentionTime") + ) assert await modclient.ts().alter(1, labels={"Time": "Series"}) res = await modclient.ts().info(1) - assert "Series" == res.labels["Time"] - res = await modclient.ts().info(1) - assert 10 == res.retention_msecs + assert "Series" == (await modclient.ts().info(1))["labels"]["Time"] + info = await modclient.ts().info(1) + assert_resp_response( + modclient, 10, info.get("retention_msecs"), info.get("retentionTime") + ) @pytest.mark.redismod @@ -56,10 +69,14 @@ async def test_alter(modclient: redis.Redis): async def test_alter_diplicate_policy(modclient: redis.Redis): assert await 
modclient.ts().create(1) info = await modclient.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + modclient, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) assert await modclient.ts().alter(1, duplicate_policy="min") info = await modclient.ts().info(1) - assert "min" == info.duplicate_policy + assert_resp_response( + modclient, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @@ -74,13 +91,15 @@ async def test_add(modclient: redis.Redis): assert abs(time.time() - round(float(res) / 1000)) < 1.0 info = await modclient.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] + assert_resp_response( + modclient, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Labs" == info["labels"]["Redis"] # Test for a chunk size of 128 Bytes on TS.ADD assert await modclient.ts().add("time-serie-1", 1, 10.0, chunk_size=128) info = await modclient.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(modclient, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -147,21 +166,21 @@ async def test_incrby_decrby(modclient: redis.Redis): assert 0 == (await modclient.ts().get(1))[1] assert await modclient.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == await modclient.ts().get(2) + assert_resp_response(modclient, await modclient.ts().get(2), (5, 1.5), [5, 1.5]) assert await modclient.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == await modclient.ts().get(2) + assert_resp_response(modclient, await modclient.ts().get(2), (7, 3.75), [7, 3.75]) assert await modclient.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == await modclient.ts().get(2) + assert_resp_response(modclient, await modclient.ts().get(2), (15, 2.25), [15, 2.25]) # Test for a chunk size of 128 Bytes on TS.INCRBY assert await modclient.ts().incrby("time-serie-1", 10, chunk_size=128) info = await modclient.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(modclient, 128, info.get("chunk_size"), info.get("chunkSize")) # Test for a chunk size of 128 Bytes on TS.DECRBY assert await modclient.ts().decrby("time-serie-2", 10, chunk_size=128) info = await modclient.ts().info("time-serie-2") - assert 128 == info.chunk_size + assert_resp_response(modclient, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -177,12 +196,15 @@ async def test_create_and_delete_rule(modclient: redis.Redis): await modclient.ts().add(1, time * 2, 1.5) assert round((await modclient.ts().get(2))[1], 5) == 1.5 info = await modclient.ts().info(1) - assert info.rules[0][1] == 100 + if is_resp2_connection(modclient): + assert info.rules[0][1] == 100 + else: + assert info["rules"]["2"][0] == 100 # test rule deletion await modclient.ts().deleterule(1, 2) info = await modclient.ts().info(1) - assert not info.rules + assert not info["rules"] @pytest.mark.redismod @@ -197,7 +219,7 @@ async def test_del_range(modclient: redis.Redis): await modclient.ts().add(1, i, i % 7) assert 22 == await modclient.ts().delete(1, 0, 21) assert [] == await modclient.ts().range(1, 0, 21) - assert [(22, 1.0)] == await modclient.ts().range(1, 22, 22) + assert_resp_response(modclient, await modclient.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]]) @pytest.mark.redismod @@ -234,15 +256,16 @@ async def test_range_advanced(modclient: redis.Redis): filter_by_max_value=2, ) ) - assert [(0, 10.0), (10, 1.0)] == await modclient.ts().range( + res = 
await modclient.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" ) - assert [(0, 5.0), (5, 6.0)] == await modclient.ts().range( + assert_resp_response(modclient, res, [(0, 10.0), (10, 1.0)], [[0, 10.0], [10, 1.0]]) + res = await modclient.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) - assert [(0, 2.55), (10, 3.0)] == await modclient.ts().range( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 - ) + assert_resp_response(modclient, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) + res = await modclient.ts().range(1, 0, 10, aggregation_type="twa", bucket_size_msec=10) + assert_resp_response(modclient, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) @pytest.mark.redismod @@ -271,17 +294,27 @@ async def test_rev_range(modclient: redis.Redis): filter_by_max_value=2, ) ) - assert [(10, 1.0), (0, 10.0)] == await modclient.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + assert_resp_response( + modclient, + await modclient.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ), + [(10, 1.0), (0, 10.0)], + [[10, 1.0], [0, 10.0]], ) - assert [(1, 10.0), (0, 1.0)] == await modclient.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + assert_resp_response( + modclient, + await modclient.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ), + [(1, 10.0), (0, 1.0)], + [[1, 10.0], [0, 1.0]], ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def testMultiRange(modclient: redis.Redis): +async def test_multi_range(modclient: redis.Redis): await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) await modclient.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} @@ -292,23 +325,42 @@ async def testMultiRange(modclient: redis.Redis): res = await modclient.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(modclient): + assert 100 == len(res[0]["1"][1]) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - res = await modclient.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) + for i in range(100): + await modclient.ts().add(1, i + 200, i % 7) + res = await modclient.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert 100 == len(res["1"][2]) + + res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + await modclient.ts().add(1, i + 200, i % 7) + res = await modclient.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) - # test withlabels - assert {} == res[0]["1"][0] - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", 
"team": "ny"} == res[0]["1"][0] + # test withlabels + assert {} == res["1"][0] + res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res["1"][0] @pytest.mark.redismod @@ -327,55 +379,106 @@ async def test_multi_range_advanced(modclient: redis.Redis): res = await modclient.ts().mrange( 0, 200, filters=["Test=This"], select_labels=["team"] ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(modclient): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test with filterby - res = await modclient.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + # test with filterby + res = await modclient.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - # test groupby - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + # test groupby + res = await modclient.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = await modclient.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = await modclient.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - # test align - res = await modclient.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = await modclient.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + # test align + res = await modclient.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = await modclient.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test with filterby + res = await modclient.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[15, 1.0], [16, 2.0]] == res["1"][2] + + # test groupby + res = await 
modclient.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] + res = await modclient.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] + res = await modclient.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] + + # test align + res = await modclient.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[0, 10.0], [10, 1.0]] == res["1"][2] + res = await modclient.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [[0, 5.0], [5, 6.0]] == res["1"][2] @pytest.mark.redismod @@ -392,86 +495,161 @@ async def test_multi_reverse_range(modclient: redis.Redis): res = await modclient.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(modclient): + assert 100 == len(res[0]["1"][1]) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - res = await modclient.ts().mrevrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] + for i in range(100): + await modclient.ts().add(1, i + 200, i % 7) + res = await modclient.ts().mrevrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] - # test withlabels - res = await modclient.ts().mrevrange( - 0, 200, filters=["Test=This"], with_labels=True - ) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + # test withlabels + res = await modclient.ts().mrevrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - # test with selected labels - res = await modclient.ts().mrevrange( - 0, 200, filters=["Test=This"], select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test filterby - res = await modclient.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + # test with selected labels + res = await modclient.ts().mrevrange( + 0, 200, filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test groupby - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="team", 
reduce="min" - ) - assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] - - # test align - res = await modclient.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] - res = await modclient.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=1, - ) - assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + # test filterby + res = await modclient.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + + # test groupby + res = await modclient.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + res = await modclient.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + res = await modclient.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + + # test align + res = await modclient.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + res = await modclient.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + else: + assert 100 == len(res["1"][2]) + + res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + await modclient.ts().add(1, i + 200, i % 7) + res = await modclient.ts().mrevrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) + assert {} == res["1"][0] + + # test withlabels + res = await modclient.ts().mrevrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res["1"][0] + + # test with selected labels + res = await modclient.ts().mrevrange( + 0, 200, filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test filterby + res = await modclient.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[16, 2.0], [15, 1.0]] == res["1"][2] + + # test groupby + res = await modclient.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[3, 6.0], [2, 4.0], [1, 2.0], [0, 0.0]] == res["Test=This"][3] + res = await modclient.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["Test=This"][3] + res = await modclient.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=ny"][3] + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == 
res["team=sf"][3] + + # test align + res = await modclient.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[10, 1.0], [0, 10.0]] == res["1"][2] + res = await modclient.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [[1, 10.0], [0, 1.0]] == res["1"][2] @pytest.mark.redismod async def test_get(modclient: redis.Redis): name = "test" await modclient.ts().create(name) - assert await modclient.ts().get(name) is None + assert not await modclient.ts().get(name) await modclient.ts().add(name, 2, 3) assert 2 == (await modclient.ts().get(name))[0] await modclient.ts().add(name, 3, 4) @@ -485,19 +663,33 @@ async def test_mget(modclient: redis.Redis): await modclient.ts().create(2, labels={"Test": "This", "Taste": "That"}) act_res = await modclient.ts().mget(["Test=This"]) exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res + exp_res_resp3 = {"1": [{}, []], "2": [{}, []]} + assert_resp_response(modclient, act_res, exp_res, exp_res_resp3) await modclient.ts().add(1, "*", 15) await modclient.ts().add(2, "*", 25) res = await modclient.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] + if is_resp2_connection(modclient): + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + else: + assert 15 == res["1"][1][1] + assert 25 == res["2"][1][1] res = await modclient.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] + if is_resp2_connection(modclient): + assert 25 == res[0]["2"][2] + else: + assert 25 == res["2"][1][1] # test with_labels - assert {} == res[0]["2"][0] + if is_resp2_connection(modclient): + assert {} == res[0]["2"][0] + else: + assert {} == res["2"][0] res = await modclient.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + if is_resp2_connection(modclient): + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + else: + assert {"Taste": "That", "Test": "This"} == res["2"][0] @pytest.mark.redismod @@ -506,8 +698,10 @@ async def test_info(modclient: redis.Redis): 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) info = await modclient.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" + assert_resp_response( + modclient, 5, info.get("retention_msecs"), info.get("retentionTime") + ) + assert info["labels"]["currentLabel"] == "currentData" @pytest.mark.redismod @@ -517,11 +711,15 @@ async def testInfoDuplicatePolicy(modclient: redis.Redis): 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) info = await modclient.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + modclient, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) await modclient.ts().create("time-serie-2", duplicate_policy="min") info = await modclient.ts().info("time-serie-2") - assert "min" == info.duplicate_policy + assert_resp_response( + modclient, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @@ -531,7 +729,9 @@ async def test_query_index(modclient: redis.Redis): await modclient.ts().create(2, labels={"Test": "This", "Taste": "That"}) assert 2 == len(await modclient.ts().queryindex(["Test=This"])) assert 1 == len(await modclient.ts().queryindex(["Taste=That"])) - assert [2] == await modclient.ts().queryindex(["Taste=That"]) + assert_resp_response( + modclient, await 
modclient.ts().queryindex(["Taste=That"]), [2], {"2"} + ) # @pytest.mark.redismod @@ -554,4 +754,7 @@ async def test_uncompressed(modclient: redis.Redis): await modclient.ts().create("uncompressed", uncompressed=True) compressed_info = await modclient.ts().info("compressed") uncompressed_info = await modclient.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage + if is_resp2_connection(modclient): + assert compressed_info.memory_usage != uncompressed_info.memory_usage + else: + assert compressed_info["memoryUsage"] != uncompressed_info["memoryUsage"] From bb4a0f934b966a97d7bc2574143f4ff1c4bc943e Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 15 Jun 2023 16:46:21 +0300 Subject: [PATCH 09/10] linters --- tests/test_asyncio/test_bloom.py | 6 +- tests/test_asyncio/test_json.py | 29 +++------ tests/test_asyncio/test_search.py | 86 +++++++++++++++++---------- tests/test_asyncio/test_timeseries.py | 22 +++++-- 4 files changed, 82 insertions(+), 61 deletions(-) diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index 0c9a933b12..bb1f0d58ad 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -5,7 +5,11 @@ import redis.asyncio as redis from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_ifmodversion_lt, +) def intlist(obj): diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index 58c0601ea7..551e307805 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -106,13 +106,9 @@ async def test_numincrby(modclient): modclient, await modclient.json().numincrby("num", Path.root_path(), 1), 2, [2] ) res = await modclient.json().numincrby("num", Path.root_path(), 0.5) - assert_resp_response( - modclient, res, 2.5, [2.5] - ) + assert_resp_response(modclient, res, 2.5, [2.5]) res = await modclient.json().numincrby("num", Path.root_path(), -1.25) - assert_resp_response( - modclient, res, 1.25, [1.25] - ) + assert_resp_response(modclient, res, 1.25, [1.25]) @pytest.mark.redismod @@ -121,17 +117,11 @@ async def test_nummultby(modclient: redis.Redis): with pytest.deprecated_call(): res = await modclient.json().nummultby("num", Path.root_path(), 2) - assert_resp_response( - modclient, res, 2, [2] - ) + assert_resp_response(modclient, res, 2, [2]) res = await modclient.json().nummultby("num", Path.root_path(), 2.5) - assert_resp_response( - modclient, res, 5, [5] - ) + assert_resp_response(modclient, res, 5, [5]) res = await modclient.json().nummultby("num", Path.root_path(), 0.5) - assert_resp_response( - modclient, res, 2.5, [2.5] - ) + assert_resp_response(modclient, res, 2.5, [2.5]) @pytest.mark.redismod @@ -151,9 +141,7 @@ async def test_strappend(modclient: redis.Redis): await modclient.json().set("jsonkey", Path.root_path(), "foo") assert 6 == await modclient.json().strappend("jsonkey", "bar") res = await modclient.json().get("jsonkey", Path.root_path()) - assert_resp_response( - modclient, res, "foobar", [["foobar"]] - ) + assert_resp_response(modclient, res, "foobar", [["foobar"]]) @pytest.mark.redismod @@ -316,7 +304,6 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert await modclient.json().set("doc1", "$", doc1) assert await 
modclient.json().delete("doc1", "$..a") == 2 - r = await modclient.json().get("doc1", "$") res = [{"nested": {"b": 3}}] assert_resp_response(modclient, await modclient.json().get("doc1", "$"), res, [res]) @@ -917,9 +904,7 @@ async def test_type_dollar(modclient: redis.Redis): # Test single res = await modclient.json().type("doc1", "$.nested2.a") - assert_resp_response( - modclient, res, [jtypes[1]], [[jtypes[1]]] - ) + assert_resp_response(modclient, res, [jtypes[1]], [[jtypes[1]]]) # Test missing key assert_resp_response( diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 1e83efae66..599631bfc9 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -149,7 +149,9 @@ async def test_client(modclient: redis.Redis): # test verbatim vs no verbatim total = (await modclient.ft().search(Query("kings").no_content())).total - vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim())).total + vtotal = ( + await modclient.ft().search(Query("kings").no_content().verbatim()) + ).total assert total > vtotal # test in fields @@ -157,7 +159,9 @@ async def test_client(modclient: redis.Redis): await modclient.ft().search(Query("henry").no_content().limit_fields("txt")) ).total play_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("play")) + await modclient.ft().search( + Query("henry").no_content().limit_fields("play") + ) ).total both_total = ( await ( @@ -189,11 +193,16 @@ async def test_client(modclient: redis.Redis): # test slop and in order assert 193 == (await modclient.ft().search(Query("henry king"))).total assert ( - 3 == (await modclient.ft().search(Query("henry king").slop(0).in_order())).total + 3 + == ( + await modclient.ft().search(Query("henry king").slop(0).in_order()) + ).total ) assert ( 52 - == (await modclient.ft().search(Query("king henry").slop(0).in_order())).total + == ( + await modclient.ft().search(Query("king henry").slop(0).in_order()) + ).total ) assert 53 == (await modclient.ft().search(Query("henry king").slop(0))).total assert 167 == (await modclient.ft().search(Query("henry king").slop(100))).total @@ -230,24 +239,28 @@ async def test_client(modclient: redis.Redis): assert "fields" not in doc.keys() # test verbatim vs no verbatim - total = (await modclient.ft().search( - Query("kings").no_content() - ))["total_results"] + total = (await modclient.ft().search(Query("kings").no_content()))[ + "total_results" + ] vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim()))[ "total_results" ] assert total > vtotal # test in fields - txt_total = (await modclient.ft().search( - Query("henry").no_content().limit_fields("txt") - ))["total_results"] - play_total = (await modclient.ft().search( - Query("henry").no_content().limit_fields("play") - ))["total_results"] - both_total = (await modclient.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ))["total_results"] + txt_total = ( + await modclient.ft().search(Query("henry").no_content().limit_fields("txt")) + )["total_results"] + play_total = ( + await modclient.ft().search( + Query("henry").no_content().limit_fields("play") + ) + )["total_results"] + both_total = ( + await modclient.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + ) + )["total_results"] assert 129 == txt_total assert 494 == play_total assert 494 == both_total @@ -271,9 +284,9 @@ async def test_client(modclient: redis.Redis): assert set(ids) == set(subset) # test slop and in order 
-        assert 193 == (
-            await modclient.ft().search(Query("henry king"))
-        )["total_results"]
+        assert (
+            193 == (await modclient.ft().search(Query("henry king")))["total_results"]
+        )
         assert (
             3
             == (await modclient.ft().search(Query("henry king").slop(0).in_order()))[
                 "total_results"
             ]
         )
         assert (
             52
             == (await modclient.ft().search(Query("king henry").slop(0).in_order()))[
                 "total_results"
             ]
         )
-        assert 53 == (await modclient.ft().search(
-            Query("henry king").slop(0)
-        ))["total_results"]
-        assert 167 == (await modclient.ft().search(
-            Query("henry king").slop(100)
-        ))["total_results"]
+        assert (
+            53
+            == (await modclient.ft().search(Query("henry king").slop(0)))[
+                "total_results"
+            ]
+        )
+        assert (
+            167
+            == (await modclient.ft().search(Query("henry king").slop(100)))[
+                "total_results"
+            ]
+        )
 
         # test delete document
         await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"})
@@ -350,7 +369,6 @@ async def test_stopwords(modclient: redis.Redis):
     assert 1 == res2["total_results"]
 
 
-
 @pytest.mark.redismod
 async def test_filters(modclient: redis.Redis):
     await (
@@ -406,7 +424,7 @@ async def test_filters(modclient: redis.Redis):
     assert 1 == res1["total_results"]
     assert 2 == res2["total_results"]
     assert "doc1" == res1["results"][0]["id"]
-    
+
     # Sort results, after RDB reload order may change
     res = [res2["results"][0]["id"], res2["results"][1]["id"]]
     res.sort()
@@ -1025,9 +1043,7 @@ async def test_scorer(modclient: redis.Redis):
         Query("quick").scorer("TFIDF.DOCNORM").with_scores()
     )
     assert 0.1111111111111111 == res["results"][0]["score"]
-    res = await modclient.ft().search(
-        Query("quick").scorer("BM25").with_scores()
-    )
+    res = await modclient.ft().search(Query("quick").scorer("BM25").with_scores())
     assert 0.17699114465425977 == res["results"][0]["score"]
     res = await modclient.ft().search(Query("quick").scorer("DISMAX").with_scores())
     assert 2.0 == res["results"][0]["score"]
@@ -1232,7 +1248,9 @@ async def test_aggregations_groupby(modclient: redis.Redis):
 
     req = (
         aggregations.AggregateRequest("redis")
-        .group_by("@parent", reducers.random_sample("@title", 2).alias("random"))
+        .group_by(
+            "@parent", reducers.random_sample("@title", 2).alias("random")
+        )
         .dialect(dialect)
     )
 
@@ -1300,7 +1318,7 @@ async def test_aggregations_groupby(modclient: redis.Redis):
 
     res = (await modclient.ft().aggregate(req))["results"][0]
     assert res["fields"]["parent"] == "redis"
-    assert res["fields"]["__generated_aliasmaxrandom_num"] == "10"  # max(10,8,3)
+    assert res["fields"]["__generated_aliasmaxrandom_num"] == "10"
 
     req = (
         aggregations.AggregateRequest("redis")
@@ -1357,7 +1375,9 @@ async def test_aggregations_groupby(modclient: redis.Redis):
 
     req = (
         aggregations.AggregateRequest("redis")
-        .group_by("@parent", reducers.random_sample("@title", 2).alias("random"))
+        .group_by(
+            "@parent", reducers.random_sample("@title", 2).alias("random")
+        )
         .dialect(dialect)
     )
 
diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py
index f2580b7f97..d09e992a7b 100644
--- a/tests/test_asyncio/test_timeseries.py
+++ b/tests/test_asyncio/test_timeseries.py
@@ -4,7 +4,11 @@
 import pytest
 
 import redis.asyncio as redis
-from tests.conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt
+from tests.conftest import (
+    assert_resp_response,
+    is_resp2_connection,
+    skip_ifmodversion_lt,
+)
 
 
 @pytest.mark.redismod
@@ -219,7 +223,9 @@ async def test_del_range(modclient: redis.Redis):
         await modclient.ts().add(1, i, i % 7)
     assert 22 == await modclient.ts().delete(1, 0, 21)
     assert [] == await modclient.ts().range(1, 0, 21)
-    assert_resp_response(modclient, await modclient.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]])
+    assert_resp_response(
+        modclient, await modclient.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]]
+    )
 
 
 @pytest.mark.redismod
@@ -264,7 +270,9 @@ async def test_range_advanced(modclient: redis.Redis):
         1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5
     )
     assert_resp_response(modclient, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]])
-    res = await modclient.ts().range(1, 0, 10, aggregation_type="twa", bucket_size_msec=10)
+    res = await modclient.ts().range(
+        1, 0, 10, aggregation_type="twa", bucket_size_msec=10
+    )
     assert_resp_response(modclient, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]])
 
 
@@ -341,7 +349,9 @@ async def test_multi_range(modclient: redis.Redis):
 
         # test withlabels
         assert {} == res[0]["1"][0]
-        res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True)
+        res = await modclient.ts().mrange(
+            0, 200, filters=["Test=This"], with_labels=True
+        )
         assert {"Test": "This", "team": "ny"} == res[0]["1"][0]
     else:
         assert 100 == len(res["1"][2])
@@ -359,7 +369,9 @@
 
         # test withlabels
         assert {} == res["1"][0]
-        res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True)
+        res = await modclient.ts().mrange(
+            0, 200, filters=["Test=This"], with_labels=True
+        )
         assert {"Test": "This", "team": "ny"} == res["1"][0]

From 2e494ef29180f1aee85a89da8b9c3b094c26ba37 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Thu, 15 Jun 2023 18:22:45 +0300
Subject: [PATCH 10/10] revert redismod-url change

---
 tests/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 187be1189e..6454750353 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,7 +16,7 @@
 
 REDIS_INFO = {}
 default_redis_url = "redis://localhost:6379/0"
-default_redismod_url = "redis://localhost:6379"
+default_redismod_url = "redis://localhost:36379"
 default_redis_unstable_url = "redis://localhost:6378"
 
 # default ssl client ignores verification for the purpose of testing
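
Note on the helpers used throughout this series: every protocol-sensitive assertion in the hunks above funnels through two functions imported from tests/conftest.py. is_resp2_connection picks the branch, and assert_resp_response compares the reply against either the RESP2 shape (flat lists and tuples) or the RESP3 shape (maps, sets, and doubles). Their definitions are not part of these patches, so here is a minimal sketch of what they presumably look like; the protocol lookup via connection_kwargs is an assumption for illustration, and the real conftest.py code may differ:

    def get_protocol_version(r):
        # Assumption: the requested RESP version travels in the client's
        # connection kwargs; it is absent (or "2") for RESP2, and 3/"3"
        # when protocol=3 was passed to the client constructor.
        return r.connection_pool.connection_kwargs.get("protocol")


    def is_resp2_connection(r):
        # RESP2 is redis-py's default, so a missing kwarg counts as RESP2.
        return get_protocol_version(r) in [2, "2", None]


    def assert_resp_response(r, response, resp2_expected, resp3_expected):
        # Compare the actual reply against the expectation matching the
        # active protocol, e.g. [(22, 1.0)] under RESP2 vs [[22, 1.0]]
        # under RESP3 for the TS.RANGE assertion above.
        if is_resp2_connection(r):
            assert response == resp2_expected
        else:
            assert response == resp3_expected

The same split explains the smaller rewrites above, such as test_get asserting "not await modclient.ts().get(name)" rather than "is None": an empty series reads back as None under RESP2 but as an empty container under RESP3, and the truthiness check covers both.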