Skip to content

Better handle errors by retrying #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 64 additions & 35 deletions adafruit_requests.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
__version__ = "0.0.0-auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git"

import errno


class _RawResponse:
def __init__(self, response):
Expand All @@ -73,6 +75,10 @@ def readinto(self, buf):
return self._response._readinto(buf) # pylint: disable=protected-access


class _SendFailed(Exception):
"""Custom exception to abort sending a request."""


class Response:
"""The response from a request, contains all the headers/content"""

Expand All @@ -94,11 +100,13 @@ def __init__(self, sock, session=None):
self._chunked = False

self._backwards_compatible = not hasattr(sock, "recv_into")
if self._backwards_compatible:
print("Socket missing recv_into. Using more memory to be compatible")

http = self._readto(b" ")
if not http:
if session:
session._close_socket(self.socket)
else:
self.socket.close()
raise RuntimeError("Unable to read HTTP response.")
self.status_code = int(bytes(self._readto(b" ")))
self.reason = self._readto(b"\r\n")
Expand Down Expand Up @@ -414,30 +422,41 @@ def _get_socket(self, host, port, proto, *, timeout=1):
addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]
sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2])
connect_host = addr_info[-1][0]
if proto == "https:":
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout) # socket read timeout
ok = True
try:
ok = sock.connect((connect_host, port))
except MemoryError:
if not any(self._socket_free.items()):
raise
ok = False

# We couldn't connect due to memory so clean up the open sockets.
if not ok:
self._free_sockets()
# Recreate the socket because the ESP-IDF won't retry the connection if it failed once.
sock = None # Clear first so the first socket can be cleaned up.
sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2])
retry_count = 0
sock = None
while retry_count < 5 and sock is None:
if retry_count > 0:
if any(self._socket_free.items()):
self._free_sockets()
else:
raise RuntimeError("Sending request failed")
retry_count += 1

try:
sock = self._socket_pool.socket(
addr_info[0], addr_info[1], addr_info[2]
)
except OSError:
continue

connect_host = addr_info[-1][0]
if proto == "https:":
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout) # socket read timeout
sock.connect((connect_host, port))

try:
sock.connect((connect_host, port))
except MemoryError:
sock.close()
sock = None
except OSError:
sock.close()
sock = None

if sock is None:
raise RuntimeError("Repeated socket failures")

self._open_sockets[key] = sock
self._socket_free[sock] = False
return sock
Expand All @@ -446,11 +465,15 @@ def _get_socket(self, host, port, proto, *, timeout=1):
def _send(socket, data):
total_sent = 0
while total_sent < len(data):
sent = socket.send(data[total_sent:])
# ESP32SPI sockets raise a RuntimeError when unable to send.
try:
sent = socket.send(data[total_sent:])
except RuntimeError:
sent = 0
if sent is None:
sent = len(data)
if sent == 0:
raise RuntimeError("Connection closed")
raise _SendFailed()
total_sent += sent

def _send_request(self, socket, host, method, path, headers, data, json):
Expand Down Expand Up @@ -532,12 +555,19 @@ def request(
self._last_response.close()
self._last_response = None

socket = self._get_socket(host, port, proto, timeout=timeout)
try:
self._send_request(socket, host, method, path, headers, data, json)
except:
self._close_socket(socket)
raise
# We may fail to send the request if the socket we got is closed already. So, try a second
# time in that case.
retry_count = 0
while retry_count < 2:
retry_count += 1
socket = self._get_socket(host, port, proto, timeout=timeout)
try:
self._send_request(socket, host, method, path, headers, data, json)
break
except _SendFailed:
self._close_socket(socket)
if retry_count > 1:
raise

resp = Response(socket, self) # our response
if "location" in resp.headers and 300 <= resp.status_code <= 399:
Expand Down Expand Up @@ -588,10 +618,9 @@ def __init__(self, socket, tls_mode):
def connect(self, address):
"""connect wrapper to add non-standard mode parameter"""
try:
self._socket.connect(address, self._mode)
return True
except RuntimeError:
return False
return self._socket.connect(address, self._mode)
except RuntimeError as error:
raise OSError(errno.ENOMEM) from error


class _FakeSSLContext:
Expand Down
86 changes: 86 additions & 0 deletions tests/concurrent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from unittest import mock
import mocket
import pytest
import errno
import adafruit_requests

ip = "1.2.3.4"
host = "wifitest.adafruit.com"
host2 = "wifitest2.adafruit.com"
path = "/testwifi/index.html"
text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)"
response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text


def test_second_connect_fails_memoryerror():
pool = mocket.MocketPool()
pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(response)
sock2 = mocket.Mocket(response)
sock3 = mocket.Mocket(response)
pool.socket.call_count = 0 # Reset call count
pool.socket.side_effect = [sock, sock2, sock3]
sock2.connect.side_effect = MemoryError()

ssl = mocket.SSLContext()

s = adafruit_requests.Session(pool, ssl)
r = s.get("https://" + host + path)

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),]
)
assert r.text == str(text, "utf-8")

host2 = "test.adafruit.com"
s.get("https://" + host2 + path)

sock.connect.assert_called_once_with((host, 443))
sock2.connect.assert_called_once_with((host2, 443))
sock3.connect.assert_called_once_with((host2, 443))
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
sock2.close.assert_called_once()
assert sock3.close.call_count == 0
assert pool.socket.call_count == 3


def test_second_connect_fails_oserror():
pool = mocket.MocketPool()
pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(response)
sock2 = mocket.Mocket(response)
sock3 = mocket.Mocket(response)
pool.socket.call_count = 0 # Reset call count
pool.socket.side_effect = [sock, sock2, sock3]
sock2.connect.side_effect = OSError(errno.ENOMEM)

ssl = mocket.SSLContext()

s = adafruit_requests.Session(pool, ssl)
r = s.get("https://" + host + path)

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),]
)
assert r.text == str(text, "utf-8")

host2 = "test.adafruit.com"
s.get("https://" + host2 + path)

sock.connect.assert_called_once_with((host, 443))
sock2.connect.assert_called_once_with((host2, 443))
sock3.connect.assert_called_once_with((host2, 443))
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
sock2.close.assert_called_once()
assert sock3.close.call_count == 0
assert pool.socket.call_count == 3
7 changes: 6 additions & 1 deletion tests/legacy_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ def __init__(self, response):
self.send = mock.Mock(side_effect=self._send)
self.readline = mock.Mock(side_effect=self._readline)
self.recv = mock.Mock(side_effect=self._recv)
self.fail_next_send = False
self._response = response
self._position = 0

def _send(self, data):
return len(data)
if self.fail_next_send:
self.fail_next_send = False
raise RuntimeError("Send failed")
return None

def _readline(self):
i = self._response.find(b"\r\n", self._position)
Expand All @@ -32,4 +36,5 @@ def _recv(self, count):
end = self._position + count
r = self._response[self._position : end]
self._position = end
print(r)
return r
120 changes: 120 additions & 0 deletions tests/legacy_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock
import legacy_mocket as mocket
import json
import pytest
import adafruit_requests

ip = "1.2.3.4"
Expand Down Expand Up @@ -49,3 +50,122 @@ def test_post_string():
sock.connect.assert_called_once_with((ip, 80))
sock.send.assert_called_with(b"31F")
r.close()


def test_second_tls_send_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(headers + encoded)
sock2 = mocket.Mocket(headers + encoded)
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock, sock2]

adafruit_requests.set_socket(mocket, mocket.interface)
r = adafruit_requests.get("https://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)
assert r.text == str(encoded, "utf-8")

sock.fail_next_send = True
adafruit_requests.get("https://" + host + "/get2")

sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE)
sock2.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE)
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
assert sock2.close.call_count == 0
assert mocket.socket.call_count == 2


def test_second_send_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(headers + encoded)
sock2 = mocket.Mocket(headers + encoded)
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock, sock2]

adafruit_requests.set_socket(mocket, mocket.interface)
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)
assert r.text == str(encoded, "utf-8")

sock.fail_next_send = True
adafruit_requests.get("http://" + host + "/get2")

sock.connect.assert_called_once_with((ip, 80))
sock2.connect.assert_called_once_with((ip, 80))
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
assert sock2.close.call_count == 0
assert mocket.socket.call_count == 2


def test_first_read_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(b"")
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock]

adafruit_requests.set_socket(mocket, mocket.interface)

with pytest.raises(RuntimeError):
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)

sock.connect.assert_called_once_with((ip, 80))
# Make sure that the socket is closed after the first receive fails.
sock.close.assert_called_once()
assert mocket.socket.call_count == 1


def test_second_tls_connect_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(headers + encoded)
sock2 = mocket.Mocket(headers + encoded)
sock3 = mocket.Mocket(headers + encoded)
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock, sock2, sock3]
sock2.connect.side_effect = RuntimeError("error connecting")

adafruit_requests.set_socket(mocket, mocket.interface)
r = adafruit_requests.get("https://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)
assert r.text == str(encoded, "utf-8")

host2 = "test.adafruit.com"
r = adafruit_requests.get("https://" + host2 + "/get2")

sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE)
sock2.connect.assert_called_once_with((host2, 443), mocket.interface.TLS_MODE)
sock3.connect.assert_called_once_with((host2, 443), mocket.interface.TLS_MODE)
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
sock2.close.assert_called_once()
assert sock3.close.call_count == 0
assert mocket.socket.call_count == 3
Loading