Skip to content

Commit 3c3b44c

Browse files
committed
Don't trust send works. Do one recv before creating Response
1 parent 531e845 commit 3c3b44c

File tree

3 files changed

+60
-8
lines changed

3 files changed

+60
-8
lines changed

adafruit_requests.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def readinto(self, buf):
8787
class _SendFailed(Exception):
8888
"""Custom exception to abort sending a request."""
8989

90+
class OutOfRetries(Exception):
91+
"""Raised when requests has retried to make a request unsuccessfully."""
9092

9193
class Response:
9294
"""The response from a request, contains all the headers/content"""
@@ -570,13 +572,27 @@ def request(
570572
while retry_count < 2:
571573
retry_count += 1
572574
socket = self._get_socket(host, port, proto, timeout=timeout)
575+
ok = True
573576
try:
574577
self._send_request(socket, host, method, path, headers, data, json)
575-
break
576578
except _SendFailed:
577-
self._close_socket(socket)
578-
if retry_count > 1:
579-
raise
579+
ok = False
580+
if ok:
581+
# Read the H of "HTTP/1.1" to make sure the socket is alive. send can appear to work
582+
# even when the socket is closed.
583+
if hasattr(socket, "recv"):
584+
result = socket.recv(1)
585+
else:
586+
result = bytearray(1)
587+
socket.recv_into(result)
588+
if result == b"H":
589+
# Things seem to be ok so break with socket set.
590+
break
591+
self._close_socket(socket)
592+
socket = None
593+
594+
if not socket:
595+
raise OutOfRetries()
580596

581597
resp = Response(socket, self) # our response
582598
if "location" in resp.headers and 300 <= resp.status_code <= 399:

tests/legacy_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ def test_second_send_fails():
115115
def test_first_read_fails():
116116
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
117117
sock = mocket.Mocket(b"")
118+
sock2 = mocket.Mocket(headers + encoded)
118119
mocket.socket.call_count = 0 # Reset call count
119-
mocket.socket.side_effect = [sock]
120+
mocket.socket.side_effect = [sock, sock2]
120121

121122
adafruit_requests.set_socket(mocket, mocket.interface)
122123

123-
with pytest.raises(RuntimeError):
124-
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")
124+
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")
125125

126126
sock.send.assert_has_calls(
127127
[mock.call(b"testwifi/index.html"),]
@@ -131,10 +131,16 @@ def test_first_read_fails():
131131
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
132132
)
133133

134+
135+
sock2.send.assert_has_calls(
136+
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
137+
)
138+
134139
sock.connect.assert_called_once_with((ip, 80))
140+
sock2.connect.assert_called_once_with((ip, 80))
135141
# Make sure that the socket is closed after the first receive fails.
136142
sock.close.assert_called_once()
137-
assert mocket.socket.call_count == 1
143+
assert mocket.socket.call_count == 2
138144

139145

140146
def test_second_tls_connect_fails():

tests/reuse_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,33 @@ def test_second_send_fails():
170170
sock.close.assert_called_once()
171171
assert sock2.close.call_count == 0
172172
assert pool.socket.call_count == 2
173+
174+
def test_second_send_lies_recv_fails():
175+
pool = mocket.MocketPool()
176+
pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
177+
sock = mocket.Mocket(response)
178+
sock2 = mocket.Mocket(response)
179+
pool.socket.side_effect = [sock, sock2]
180+
181+
ssl = mocket.SSLContext()
182+
183+
s = adafruit_requests.Session(pool, ssl)
184+
r = s.get("https://" + host + path)
185+
186+
sock.send.assert_has_calls(
187+
[mock.call(b"testwifi/index.html"),]
188+
)
189+
190+
sock.send.assert_has_calls(
191+
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),]
192+
)
193+
assert r.text == str(text, "utf-8")
194+
195+
s.get("https://" + host + path + "2")
196+
197+
sock.connect.assert_called_once_with((host, 443))
198+
sock2.connect.assert_called_once_with((host, 443))
199+
# Make sure that the socket is closed after send fails.
200+
sock.close.assert_called_once()
201+
assert sock2.close.call_count == 0
202+
assert pool.socket.call_count == 2

0 commit comments

Comments
 (0)