|
1 | 1 | import binascii
|
2 | 2 | import datetime
|
| 3 | +import select |
| 4 | +import socket |
| 5 | +import socketserver |
| 6 | +import threading |
3 | 7 | import warnings
|
4 | 8 | from queue import LifoQueue, Queue
|
5 | 9 | from time import sleep
|
6 | 10 | from unittest.mock import DEFAULT, Mock, call, patch
|
| 11 | +from urllib.parse import urlparse |
7 | 12 |
|
8 | 13 | import pytest
|
9 | 14 |
|
|
53 | 58 | ]
|
54 | 59 |
|
55 | 60 |
|
| 61 | +class ProxyRequestHandler(socketserver.BaseRequestHandler): |
| 62 | + def recv(self, sock): |
| 63 | + """A recv with a timeout""" |
| 64 | + r = select.select([sock], [], [], 0.01) |
| 65 | + if not r[0]: |
| 66 | + return None |
| 67 | + return sock.recv(1000) |
| 68 | + |
| 69 | + def handle(self): |
| 70 | + self.server.proxy.n_connections += 1 |
| 71 | + conn = socket.create_connection(self.server.proxy.redis_addr) |
| 72 | + stop = False |
| 73 | + |
| 74 | + def from_server(): |
| 75 | + # read from server and pass to client |
| 76 | + while not stop: |
| 77 | + data = self.recv(conn) |
| 78 | + if data is None: |
| 79 | + continue |
| 80 | + if not data: |
| 81 | + return |
| 82 | + self.request.sendall(data) |
| 83 | + |
| 84 | + thread = threading.Thread(target=from_server) |
| 85 | + thread.start() |
| 86 | + try: |
| 87 | + while True: |
| 88 | + # read from client and send to server |
| 89 | + data = self.request.recv(1000) |
| 90 | + if not data: |
| 91 | + return |
| 92 | + conn.sendall(data) |
| 93 | + finally: |
| 94 | + stop = True |
| 95 | + conn.close() |
| 96 | + thread.join() |
| 97 | + |
| 98 | + |
| 99 | +class NodeProxy: |
| 100 | + """A class to proxy a node connection to a different port""" |
| 101 | + |
| 102 | + def __init__(self, addr, redis_addr): |
| 103 | + self.addr = addr |
| 104 | + self.redis_addr = redis_addr |
| 105 | + self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) |
| 106 | + self.server.proxy = self |
| 107 | + self.server.socket_reuse_address = True |
| 108 | + self.thread = None |
| 109 | + self.n_connections = 0 |
| 110 | + |
| 111 | + def start(self): |
| 112 | + # test that we can connect to redis |
| 113 | + s = socket.create_connection(self.redis_addr, timeout=2) |
| 114 | + s.close() |
| 115 | + # Start a thread with the server -- that thread will then start one |
| 116 | + # more thread for each request |
| 117 | + self.thread = threading.Thread(target=self.server.serve_forever) |
| 118 | + # Exit the server thread when the main thread terminates |
| 119 | + self.thread.daemon = True |
| 120 | + self.thread.start() |
| 121 | + |
| 122 | + def close(self): |
| 123 | + self.server.shutdown() |
| 124 | + |
| 125 | + |
| 126 | +@pytest.fixture |
| 127 | +def redis_addr(request): |
| 128 | + redis_url = request.config.getoption("--redis-url") |
| 129 | + scheme, netloc = urlparse(redis_url)[:2] |
| 130 | + assert scheme == "redis" |
| 131 | + if ":" in netloc: |
| 132 | + host, port = netloc.split(":") |
| 133 | + return host, int(port) |
| 134 | + else: |
| 135 | + return netloc, 6379 |
| 136 | + |
| 137 | + |
56 | 138 | @pytest.fixture()
|
57 | 139 | def slowlog(request, r):
|
58 | 140 | """
|
@@ -823,6 +905,47 @@ def raise_connection_error():
|
823 | 905 | assert "myself" not in nodes.get(curr_default_node.name).get("flags")
|
824 | 906 | assert r.get_default_node() != curr_default_node
|
825 | 907 |
|
| 908 | + def test_host_port_remap(self, request, redis_addr): |
| 909 | + """Test that we can create a rediscluster object with |
| 910 | + a host-port remapper and map connections through proxy objects |
| 911 | + """ |
| 912 | + |
| 913 | + # we remap the first n nodes |
| 914 | + offset = 1000 |
| 915 | + n = 6 |
| 916 | + ports = [redis_addr[1] + i for i in range(n)] |
| 917 | + |
| 918 | + def host_port_remap(host, port): |
| 919 | + # remap first three nodes to our local proxy |
| 920 | + old = host, port |
| 921 | + if int(port) in ports: |
| 922 | + host, port = "127.0.0.1", int(port) + offset |
| 923 | + # print(f"{old} {host, port}") |
| 924 | + return host, port |
| 925 | + |
| 926 | + # create the proxies |
| 927 | + proxies = [ |
| 928 | + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) |
| 929 | + for port in ports |
| 930 | + ] |
| 931 | + for p in proxies: |
| 932 | + p.start() |
| 933 | + try: |
| 934 | + # create cluster: |
| 935 | + r = _get_client( |
| 936 | + RedisCluster, request, flushdb=False, host_port_remap=host_port_remap |
| 937 | + ) |
| 938 | + try: |
| 939 | + assert r.ping() is True |
| 940 | + finally: |
| 941 | + r.close() |
| 942 | + finally: |
| 943 | + for p in proxies: |
| 944 | + p.close() |
| 945 | + |
| 946 | + # verify that the proxies were indeed used |
| 947 | + assert any(p.n_connections for p in proxies) |
| 948 | + |
826 | 949 |
|
827 | 950 | @pytest.mark.onlycluster
|
828 | 951 | class TestClusterRedisCommands:
|
|
0 commit comments