Skip to content

Commit b6d4b4f

Browse files
committed
tests: add 'connect' tests for all Redis connection classes
1 parent 363f6e3 commit b6d4b4f

File tree

2 files changed

+313
-0
lines changed

2 files changed

+313
-0
lines changed

tests/test_asyncio/test_connect.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import asyncio
2+
import logging
3+
import re
4+
import socket
5+
import ssl
6+
7+
import pytest
8+
9+
from redis.asyncio.connection import (
10+
Connection,
11+
SSLConnection,
12+
UnixDomainSocketConnection,
13+
)
14+
15+
from ..ssl_utils import get_ssl_filename
16+
17+
_logger = logging.getLogger(__name__)
18+
19+
20+
_CLIENT_NAME = "test-suite-client"
21+
_CMD_SEP = b"\r\n"
22+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
23+
_ERROR_RESP = b"-ERR" + _CMD_SEP
24+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
25+
26+
27+
@pytest.fixture
28+
def tcp_address():
29+
with socket.socket() as sock:
30+
sock.bind(("127.0.0.1", 0))
31+
return sock.getsockname()
32+
33+
34+
@pytest.fixture
35+
def uds_address(tmpdir):
36+
return tmpdir / "uds.sock"
37+
38+
39+
async def test_tcp_connect(tcp_address):
40+
host, port = tcp_address
41+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
42+
await _assert_connect(conn, tcp_address)
43+
44+
45+
async def test_uds_connect(uds_address):
46+
path = str(uds_address)
47+
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10)
48+
await _assert_connect(conn, path)
49+
50+
51+
@pytest.mark.ssl
52+
async def test_tcp_ssl_connect(tcp_address):
53+
host, port = tcp_address
54+
certfile = get_ssl_filename("server-cert.pem")
55+
keyfile = get_ssl_filename("server-key.pem")
56+
conn = SSLConnection(
57+
host=host,
58+
port=port,
59+
client_name=_CLIENT_NAME,
60+
ssl_ca_certs=certfile,
61+
socket_timeout=10,
62+
)
63+
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
64+
65+
66+
async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
67+
async def _handler(reader, writer):
68+
await _redis_request_handler(reader, writer, server)
69+
70+
if isinstance(server_address, str):
71+
server = await asyncio.start_unix_server(_handler, path=server_address)
72+
elif certfile:
73+
host, port = server_address
74+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
75+
context.minimum_version = ssl.TLSVersion.TLSv1_2
76+
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
77+
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
78+
else:
79+
host, port = server_address
80+
server = await asyncio.start_server(_handler, host=host, port=port)
81+
82+
async with server:
83+
await server.start_serving()
84+
try:
85+
await conn.connect()
86+
await conn.disconnect()
87+
finally:
88+
server.close()
89+
await server.wait_closed()
90+
91+
92+
async def _redis_request_handler(reader, writer, server):
93+
buffer = b""
94+
command = None
95+
command_ptr = None
96+
fragment_length = None
97+
while server.is_serving() or buffer:
98+
buffer += await reader.read()
99+
if not buffer:
100+
continue
101+
parts = re.split(_CMD_SEP, buffer)
102+
buffer = parts[-1]
103+
for fragment in parts[:-1]:
104+
fragment = fragment.decode()
105+
_logger.info("Command fragment: %s", fragment)
106+
107+
if fragment.startswith("*") and command is None:
108+
command = [None for _ in range(int(fragment[1:]))]
109+
command_ptr = 0
110+
fragment_length = None
111+
continue
112+
113+
if fragment.startswith("$") and command[command_ptr] is None:
114+
fragment_length = int(fragment[1:])
115+
continue
116+
117+
assert len(fragment) == fragment_length
118+
command[command_ptr] = fragment
119+
command_ptr += 1
120+
121+
if command_ptr < len(command):
122+
continue
123+
124+
command = " ".join(command)
125+
_logger.info("Command in %s", command)
126+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
127+
_logger.info("Response from %s", resp)
128+
await writer.write(resp)
129+
command = None

tests/test_connect.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import logging
2+
import re
3+
import socket
4+
import socketserver
5+
import ssl
6+
import threading
7+
8+
import pytest
9+
10+
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
11+
12+
from .ssl_utils import get_ssl_filename
13+
14+
_logger = logging.getLogger(__name__)
15+
16+
17+
_CLIENT_NAME = "test-suite-client"
18+
_CMD_SEP = b"\r\n"
19+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
20+
_ERROR_RESP = b"-ERR" + _CMD_SEP
21+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
22+
23+
24+
@pytest.fixture
25+
def tcp_address():
26+
with socket.socket() as sock:
27+
sock.bind(("127.0.0.1", 0))
28+
return sock.getsockname()
29+
30+
31+
@pytest.fixture
32+
def uds_address(tmpdir):
33+
return tmpdir / "uds.sock"
34+
35+
36+
def test_tcp_connect(tcp_address):
37+
host, port = tcp_address
38+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
39+
_assert_connect(conn, tcp_address)
40+
41+
42+
def test_uds_connect(uds_address):
43+
path = str(uds_address)
44+
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10)
45+
_assert_connect(conn, path)
46+
47+
48+
@pytest.mark.ssl
49+
def test_tcp_ssl_connect(tcp_address):
50+
host, port = tcp_address
51+
certfile = get_ssl_filename("server-cert.pem")
52+
keyfile = get_ssl_filename("server-key.pem")
53+
conn = SSLConnection(
54+
host=host,
55+
port=port,
56+
client_name=_CLIENT_NAME,
57+
ssl_ca_certs=certfile,
58+
socket_timeout=10,
59+
)
60+
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
61+
62+
63+
def _assert_connect(conn, server_address, certfile=None, keyfile=None):
64+
if isinstance(server_address, str):
65+
server = _RedisUDSServer(server_address, _RedisRequestHandler)
66+
else:
67+
server = _RedisTCPServer(
68+
server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile
69+
)
70+
with server as aserver:
71+
t = threading.Thread(target=aserver.serve_forever)
72+
t.start()
73+
try:
74+
aserver.wait_online()
75+
conn.connect()
76+
conn.disconnect()
77+
finally:
78+
aserver.stop()
79+
t.join(timeout=5)
80+
81+
82+
class _RedisTCPServer(socketserver.TCPServer):
83+
def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None:
84+
self._ready_event = threading.Event()
85+
self._stop_requested = False
86+
self._certfile = certfile
87+
self._keyfile = keyfile
88+
super().__init__(*args, **kw)
89+
90+
def service_actions(self):
91+
self._ready_event.set()
92+
93+
def wait_online(self):
94+
self._ready_event.wait()
95+
96+
def stop(self):
97+
self._stop_requested = True
98+
self.shutdown()
99+
100+
def is_serving(self):
101+
return not self._stop_requested
102+
103+
def get_request(self):
104+
if self._certfile is None:
105+
return super().get_request()
106+
newsocket, fromaddr = self.socket.accept()
107+
connstream = ssl.wrap_socket(
108+
newsocket,
109+
server_side=True,
110+
certfile=self._certfile,
111+
keyfile=self._keyfile,
112+
ssl_version=ssl.PROTOCOL_TLSv1_2,
113+
)
114+
return connstream, fromaddr
115+
116+
117+
class _RedisUDSServer(socketserver.UnixStreamServer):
118+
def __init__(self, *args, **kw) -> None:
119+
self._ready_event = threading.Event()
120+
self._stop_requested = False
121+
super().__init__(*args, **kw)
122+
123+
def service_actions(self):
124+
self._ready_event.set()
125+
126+
def wait_online(self):
127+
self._ready_event.wait()
128+
129+
def stop(self):
130+
self._stop_requested = True
131+
self.shutdown()
132+
133+
def is_serving(self):
134+
return not self._stop_requested
135+
136+
137+
class _RedisRequestHandler(socketserver.StreamRequestHandler):
138+
def setup(self):
139+
_logger.info("%s connected", self.client_address)
140+
141+
def finish(self):
142+
_logger.info("%s disconnected", self.client_address)
143+
144+
def handle(self):
145+
buffer = b""
146+
command = None
147+
command_ptr = None
148+
fragment_length = None
149+
while self.server.is_serving() or buffer:
150+
try:
151+
buffer += self.request.recv(1024)
152+
except socket.timeout:
153+
continue
154+
if not buffer:
155+
continue
156+
parts = re.split(_CMD_SEP, buffer)
157+
buffer = parts[-1]
158+
for fragment in parts[:-1]:
159+
fragment = fragment.decode()
160+
_logger.info("Command fragment: %s", fragment)
161+
162+
if fragment.startswith("*") and command is None:
163+
command = [None for _ in range(int(fragment[1:]))]
164+
command_ptr = 0
165+
fragment_length = None
166+
continue
167+
168+
if fragment.startswith("$") and command[command_ptr] is None:
169+
fragment_length = int(fragment[1:])
170+
continue
171+
172+
assert len(fragment) == fragment_length
173+
command[command_ptr] = fragment
174+
command_ptr += 1
175+
176+
if command_ptr < len(command):
177+
continue
178+
179+
command = " ".join(command)
180+
_logger.info("Command %s", command)
181+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
182+
_logger.info("Response %s", resp)
183+
self.request.sendall(resp)
184+
command = None

0 commit comments

Comments
 (0)