Skip to content

Commit 46007a6

Browse files
authored
Merge pull request #36 from tannewt/check_send
Check for send to return 0
2 parents 9aaf781 + c044eab commit 46007a6

File tree

8 files changed

+222
-93
lines changed

8 files changed

+222
-93
lines changed

adafruit_requests.py

Lines changed: 73 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -377,26 +377,28 @@ def __init__(self, socket_pool, ssl_context=None):
377377
self._last_response = None
378378

379379
def _free_socket(self, socket):
380-
381380
if socket not in self._open_sockets.values():
382381
raise RuntimeError("Socket not from session")
383382
self._socket_free[socket] = True
384383

384+
def _close_socket(self, sock):
385+
sock.close()
386+
del self._socket_free[sock]
387+
key = None
388+
for k in self._open_sockets:
389+
if self._open_sockets[k] == sock:
390+
key = k
391+
break
392+
if key:
393+
del self._open_sockets[key]
394+
385395
def _free_sockets(self):
386396
free_sockets = []
387397
for sock in self._socket_free:
388398
if self._socket_free[sock]:
389-
sock.close()
390399
free_sockets.append(sock)
391400
for sock in free_sockets:
392-
del self._socket_free[sock]
393-
key = None
394-
for k in self._open_sockets:
395-
if self._open_sockets[k] == sock:
396-
key = k
397-
break
398-
if key:
399-
del self._open_sockets[key]
401+
self._close_socket(sock)
400402

401403
def _get_socket(self, host, port, proto, *, timeout=1):
402404
key = (host, port, proto)
@@ -440,6 +442,61 @@ def _get_socket(self, host, port, proto, *, timeout=1):
440442
self._socket_free[sock] = False
441443
return sock
442444

445+
@staticmethod
446+
def _send(socket, data):
447+
total_sent = 0
448+
while total_sent < len(data):
449+
sent = socket.send(data[total_sent:])
450+
if sent is None:
451+
sent = len(data)
452+
if sent == 0:
453+
raise RuntimeError("Connection closed")
454+
total_sent += sent
455+
456+
def _send_request(self, socket, host, method, path, headers, data, json):
457+
# pylint: disable=too-many-arguments
458+
self._send(socket, bytes(method, "utf-8"))
459+
self._send(socket, b" /")
460+
self._send(socket, bytes(path, "utf-8"))
461+
self._send(socket, b" HTTP/1.1\r\n")
462+
if "Host" not in headers:
463+
self._send(socket, b"Host: ")
464+
self._send(socket, bytes(host, "utf-8"))
465+
self._send(socket, b"\r\n")
466+
if "User-Agent" not in headers:
467+
self._send(socket, b"User-Agent: Adafruit CircuitPython\r\n")
468+
# Iterate over keys to avoid tuple alloc
469+
for k in headers:
470+
self._send(socket, k.encode())
471+
self._send(socket, b": ")
472+
self._send(socket, headers[k].encode())
473+
self._send(socket, b"\r\n")
474+
if json is not None:
475+
assert data is None
476+
# pylint: disable=import-outside-toplevel
477+
try:
478+
import json as json_module
479+
except ImportError:
480+
import ujson as json_module
481+
data = json_module.dumps(json)
482+
self._send(socket, b"Content-Type: application/json\r\n")
483+
if data:
484+
if isinstance(data, dict):
485+
self._send(
486+
socket, b"Content-Type: application/x-www-form-urlencoded\r\n"
487+
)
488+
_post_data = ""
489+
for k in data:
490+
_post_data = "{}&{}={}".format(_post_data, k, data[k])
491+
data = _post_data[1:]
492+
self._send(socket, b"Content-Length: %d\r\n" % len(data))
493+
self._send(socket, b"\r\n")
494+
if data:
495+
if isinstance(data, bytearray):
496+
self._send(socket, bytes(data))
497+
else:
498+
self._send(socket, bytes(data, "utf-8"))
499+
443500
# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
444501
def request(
445502
self, method, url, data=None, json=None, headers=None, stream=False, timeout=60
@@ -476,42 +533,11 @@ def request(
476533
self._last_response = None
477534

478535
socket = self._get_socket(host, port, proto, timeout=timeout)
479-
socket.send(
480-
b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))
481-
)
482-
if "Host" not in headers:
483-
socket.send(b"Host: %s\r\n" % bytes(host, "utf-8"))
484-
if "User-Agent" not in headers:
485-
socket.send(b"User-Agent: Adafruit CircuitPython\r\n")
486-
# Iterate over keys to avoid tuple alloc
487-
for k in headers:
488-
socket.send(k.encode())
489-
socket.send(b": ")
490-
socket.send(headers[k].encode())
491-
socket.send(b"\r\n")
492-
if json is not None:
493-
assert data is None
494-
# pylint: disable=import-outside-toplevel
495-
try:
496-
import json as json_module
497-
except ImportError:
498-
import ujson as json_module
499-
data = json_module.dumps(json)
500-
socket.send(b"Content-Type: application/json\r\n")
501-
if data:
502-
if isinstance(data, dict):
503-
socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n")
504-
_post_data = ""
505-
for k in data:
506-
_post_data = "{}&{}={}".format(_post_data, k, data[k])
507-
data = _post_data[1:]
508-
socket.send(b"Content-Length: %d\r\n" % len(data))
509-
socket.send(b"\r\n")
510-
if data:
511-
if isinstance(data, bytearray):
512-
socket.send(bytes(data))
513-
else:
514-
socket.send(bytes(data, "utf-8"))
536+
try:
537+
self._send_request(socket, host, method, path, headers, data, json)
538+
except:
539+
self._close_socket(socket)
540+
raise
515541

516542
resp = Response(socket, self) # our response
517543
if "location" in resp.headers and 300 <= resp.status_code <= 399:
@@ -557,6 +583,7 @@ def __init__(self, socket, tls_mode):
557583
self.settimeout = socket.settimeout
558584
self.send = socket.send
559585
self.recv = socket.recv
586+
self.close = socket.close
560587

561588
def connect(self, address):
562589
"""connect wrapper to add non-standard mode parameter"""

tests/chunk_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,16 @@ def test_get_text():
3939
r = s.get("http://" + host + path)
4040

4141
sock.connect.assert_called_once_with((ip, 80))
42+
4243
sock.send.assert_has_calls(
4344
[
44-
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
45-
mock.call(b"Host: wifitest.adafruit.com\r\n"),
45+
mock.call(b"GET"),
46+
mock.call(b" /"),
47+
mock.call(b"testwifi/index.html"),
48+
mock.call(b" HTTP/1.1\r\n"),
4649
]
4750
)
51+
sock.send.assert_has_calls(
52+
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
53+
)
4854
assert r.text == str(text, "utf-8")

tests/header_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ def test_json():
1414
sock = mocket.Mocket(response_headers)
1515
pool.socket.return_value = sock
1616
sent = []
17-
sock.send.side_effect = sent.append
17+
18+
def _send(data):
19+
sent.append(data)
20+
return len(data)
21+
22+
sock.send.side_effect = _send
1823

1924
s = adafruit_requests.Session(pool)
2025
headers = {"user-agent": "blinka/1.0.0"}

tests/legacy_mocket.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ def __init__(self, response):
1313
self.settimeout = mock.Mock()
1414
self.close = mock.Mock()
1515
self.connect = mock.Mock()
16-
self.send = mock.Mock()
16+
self.send = mock.Mock(side_effect=self._send)
1717
self.readline = mock.Mock(side_effect=self._readline)
1818
self.recv = mock.Mock(side_effect=self._recv)
1919
self._response = response
2020
self._position = 0
2121

22+
def _send(self, data):
23+
return len(data)
24+
2225
def _readline(self):
2326
i = self._response.find(b"\r\n", self._position)
2427
r = self._response[self._position : i + 2]

tests/mocket.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@ def __init__(self, response):
1414
self.settimeout = mock.Mock()
1515
self.close = mock.Mock()
1616
self.connect = mock.Mock()
17-
self.send = mock.Mock()
17+
self.send = mock.Mock(side_effect=self._send)
1818
self.readline = mock.Mock(side_effect=self._readline)
1919
self.recv = mock.Mock(side_effect=self._recv)
2020
self.recv_into = mock.Mock(side_effect=self._recv_into)
2121
self._response = response
2222
self._position = 0
2323

24+
def _send(self, data):
25+
return len(data)
26+
2427
def _readline(self):
2528
i = self._response.find(b"\r\n", self._position)
2629
r = self._response[self._position : i + 2]

tests/post_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,17 @@ def test_method():
2121
s = adafruit_requests.Session(pool)
2222
r = s.post("http://" + host + "/post")
2323
sock.connect.assert_called_once_with((ip, 80))
24+
25+
sock.send.assert_has_calls(
26+
[
27+
mock.call(b"POST"),
28+
mock.call(b" /"),
29+
mock.call(b"post"),
30+
mock.call(b" HTTP/1.1\r\n"),
31+
]
32+
)
2433
sock.send.assert_has_calls(
25-
[mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")]
34+
[mock.call(b"Host: "), mock.call(b"httpbin.org"),]
2635
)
2736

2837

tests/protocol_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,18 @@ def test_get_https_text():
3232
r = s.get("https://" + host + path)
3333

3434
sock.connect.assert_called_once_with((host, 443))
35+
3536
sock.send.assert_has_calls(
3637
[
37-
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
38-
mock.call(b"Host: wifitest.adafruit.com\r\n"),
38+
mock.call(b"GET"),
39+
mock.call(b" /"),
40+
mock.call(b"testwifi/index.html"),
41+
mock.call(b" HTTP/1.1\r\n"),
3942
]
4043
)
44+
sock.send.assert_has_calls(
45+
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
46+
)
4147
assert r.text == str(text, "utf-8")
4248

4349
# Close isn't needed but can be called to release the socket early.
@@ -54,10 +60,16 @@ def test_get_http_text():
5460
r = s.get("http://" + host + path)
5561

5662
sock.connect.assert_called_once_with((ip, 80))
63+
5764
sock.send.assert_has_calls(
5865
[
59-
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
60-
mock.call(b"Host: wifitest.adafruit.com\r\n"),
66+
mock.call(b"GET"),
67+
mock.call(b" /"),
68+
mock.call(b"testwifi/index.html"),
69+
mock.call(b" HTTP/1.1\r\n"),
6170
]
6271
)
72+
sock.send.assert_has_calls(
73+
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
74+
)
6375
assert r.text == str(text, "utf-8")

0 commit comments

Comments
 (0)