Skip to content

Commit e372574

Browse files
committed
Fix close, test it and run black
1 parent 2afe50d commit e372574

14 files changed

+312
-133
lines changed

adafruit_requests.py

Lines changed: 83 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
__version__ = "0.0.0-auto.0"
5656
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git"
5757

58+
5859
class _RawResponse:
5960
def __init__(self, response):
6061
self._response = response
@@ -67,6 +68,7 @@ def read(self, size=-1):
6768
def readinto(self, buf):
6869
return self._response._readinto(buf)
6970

71+
7072
class Response:
7173
"""The response from a request, contains all the headers/content"""
7274

@@ -105,9 +107,9 @@ def __enter__(self):
105107
def __exit__(self, exc_type, exc_value, traceback):
106108
self.close()
107109

108-
def _recv_into(self, buf, size=None):
110+
def _recv_into(self, buf, size=0):
109111
if self._backwards_compatible:
110-
size = len(buf) if size is None else size
112+
size = len(buf) if size == 0 else size
111113
b = self.socket.recv(size)
112114
read_size = len(b)
113115
buf[:read_size] = b
@@ -119,7 +121,6 @@ def _readto(self, first, second=b""):
119121
buf = self._receive_buffer
120122
end = self._received_length
121123
while True:
122-
print("searching", buf[:end])
123124
firsti = buf.find(first, 0, end)
124125
secondi = -1
125126
if second:
@@ -149,7 +150,7 @@ def _readto(self, first, second=b""):
149150
if end == len(buf):
150151
new_size = len(buf) + 32
151152
new_buf = bytearray(new_size)
152-
new_buf[:len(buf)] = buf
153+
new_buf[: len(buf)] = buf
153154
buf = new_buf
154155
self._receive_buffer = buf
155156

@@ -172,13 +173,18 @@ def _read_from_buffer(self, buf=None, nbytes=None):
172173
buf[:read] = membuf[:read]
173174
if read < self._received_length:
174175
new_end = self._received_length - read
175-
self._receive_buffer[:new_end] = membuf[read:self._received_length]
176+
self._receive_buffer[:new_end] = membuf[read : self._received_length]
176177
self._received_length = new_end
177178
else:
178179
self._received_length = 0
179180
return read
180181

181182
def _readinto(self, buf):
183+
if not self.socket:
184+
raise RuntimeError(
185+
"Newer Response closed this one. Use Responses immediately."
186+
)
187+
182188
if not self._remaining:
183189
# Consume the chunk header if need be.
184190
if self._chunked:
@@ -204,33 +210,6 @@ def _readinto(self, buf):
204210
read = self._recv_into(buf, nbytes)
205211
self._remaining -= read
206212

207-
# else:
208-
# print("chunked")
209-
# pending_bytes = 0
210-
# buf = memoryview(bytearray(chunk_size))
211-
# while True:
212-
# print("chunk", self._content_read, self._content_length)
213-
# print("chunk header", chunk_header)
214-
# self._content_length = http_chunk_size
215-
# self._content_read = 0
216-
# remaining_in_http_chunk = http_chunk_size
217-
218-
# pending_bytes = 0
219-
# while remaining_in_http_chunk:
220-
# read_now = chunk_size - pending_bytes
221-
# if read_now > remaining_in_http_chunk:
222-
# read_now = remaining_in_http_chunk
223-
# read_now = self._readinto(buf[pending_bytes:pending_bytes+read_now])
224-
# pending_bytes += read_now
225-
# if pending_bytes == chunk_size:
226-
# break
227-
# yield bytes(buf)
228-
229-
# self._throw_away(2) # Read the trailing CR LF
230-
#
231-
# if pending_bytes > 0:
232-
# yield bytes(buf[:pending_bytes])
233-
234213
return read
235214

236215
def _throw_away(self, nbytes):
@@ -243,27 +222,27 @@ def _throw_away(self, nbytes):
243222
if remaining:
244223
self._recv_into(buf, remaining)
245224

246-
def _close(self):
225+
def close(self):
247226
"""Drain the remaining ESP socket buffers. We assume we already got what we wanted."""
248-
if self.socket:
249-
# Make sure we've read all of our response.
250-
# print("Content length:", content_length)
251-
if self._cached is None:
252-
if self._remaining > 0:
253-
self._throw_away(self._remaining)
254-
elif self._chunked:
255-
while True:
256-
chunk_header = self._readto(b";", b"\r\n")
257-
chunk_size = int(chunk_header, 16)
258-
if chunk_size == 0:
259-
break
260-
self._throw_away(chunk_size + 2)
261-
self._parse_headers()
262-
if self._session:
263-
self._session.free_socket(self.socket)
264-
else:
265-
self.socket.close()
266-
self.socket = None
227+
if not self.socket:
228+
return
229+
# Make sure we've read all of our response.
230+
if self._cached is None:
231+
if self._remaining > 0:
232+
self._throw_away(self._remaining)
233+
elif self._chunked:
234+
while True:
235+
chunk_header = self._readto(b";", b"\r\n")
236+
chunk_size = int(chunk_header, 16)
237+
if chunk_size == 0:
238+
break
239+
self._throw_away(chunk_size + 2)
240+
self._parse_headers()
241+
if self._session:
242+
self._session.free_socket(self.socket)
243+
else:
244+
self.socket.close()
245+
self.socket = None
267246

268247
def _parse_headers(self):
269248
"""
@@ -277,14 +256,20 @@ def _parse_headers(self):
277256

278257
content = self._readto(b"\r\n")
279258
if title and content:
280-
title = str(title, 'utf-8')
281-
content = str(content, 'utf-8')
259+
title = str(title, "utf-8")
260+
content = str(content, "utf-8")
282261
# Check len first so we can skip the .lower allocation most of the time.
283-
if len(title) == len("content-length") and title.lower() == "content-length":
262+
if (
263+
len(title) == len("content-length")
264+
and title.lower() == "content-length"
265+
):
284266
self._remaining = int(content)
285-
if len(title) == len("transfer-encoding") and title.lower() == "transfer-encoding":
267+
if (
268+
len(title) == len("transfer-encoding")
269+
and title.lower() == "transfer-encoding"
270+
):
286271
self._chunked = content.lower() == "chunked"
287-
self._headers[title] = content
272+
self._headers[title] = content
288273

289274
@property
290275
def headers(self):
@@ -332,7 +317,7 @@ def json(self):
332317
obj = json.load(self._raw)
333318
if not self._cached:
334319
self._cached = obj
335-
self._close()
320+
self.close()
336321
return obj
337322

338323
def iter_content(self, chunk_size=1, decode_unicode=False):
@@ -351,7 +336,8 @@ def iter_content(self, chunk_size=1, decode_unicode=False):
351336
else:
352337
chunk = bytes(b)
353338
yield chunk
354-
self._close()
339+
self.close()
340+
355341

356342
class Session:
357343
def __init__(self, socket_pool, ssl_context=None):
@@ -375,8 +361,12 @@ def _get_socket(self, host, port, proto, *, timeout=1):
375361
self._socket_free[sock] = False
376362
return sock
377363
if proto == "https:" and not self._ssl_context:
378-
raise RuntimeError("ssl_context must be set before using adafruit_requests for https")
379-
addr_info = self._socket_pool.getaddrinfo(host, port, 0, self._socket_pool.SOCK_STREAM)[0]
364+
raise RuntimeError(
365+
"ssl_context must be set before using adafruit_requests for https"
366+
)
367+
addr_info = self._socket_pool.getaddrinfo(
368+
host, port, 0, self._socket_pool.SOCK_STREAM
369+
)[0]
380370
sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2])
381371
if proto == "https:":
382372
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
@@ -391,15 +381,22 @@ def _get_socket(self, host, port, proto, *, timeout=1):
391381

392382
# We couldn't connect due to memory so clean up the open sockets.
393383
if not ok:
384+
free_sockets = []
394385
for s in self._socket_free:
395386
if self._socket_free[s]:
396387
s.close()
397-
del self._socket_free[s]
398-
for k in self._open_sockets:
399-
if self._open_sockets[k] == s:
400-
del self._open_sockets[k]
388+
free_sockets.append(s)
389+
for s in free_sockets:
390+
del self._socket_free[s]
391+
key = None
392+
for k in self._open_sockets:
393+
if self._open_sockets[k] == s:
394+
key = k
395+
break
396+
if key:
397+
del self._open_sockets[key]
401398
# Recreate the socket because the ESP-IDF won't retry the connection if it failed once.
402-
sock = None # Clear first so the first socket can be cleaned up.
399+
sock = None # Clear first so the first socket can be cleaned up.
403400
sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2])
404401
if proto == "https:":
405402
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
@@ -410,7 +407,9 @@ def _get_socket(self, host, port, proto, *, timeout=1):
410407
return sock
411408

412409
# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
413-
def request(self, method, url, data=None, json=None, headers=None, stream=False, timeout=60):
410+
def request(
411+
self, method, url, data=None, json=None, headers=None, stream=False, timeout=60
412+
):
414413
"""Perform an HTTP request to the given url which we will parse to determine
415414
whether to use SSL ('https://') or not. We can also send some provided 'data'
416415
or a json dictionary which we will stringify. 'headers' is optional HTTP headers
@@ -439,11 +438,13 @@ def request(self, method, url, data=None, json=None, headers=None, stream=False,
439438
port = int(port)
440439

441440
if self._last_response:
442-
self._last_response._close()
441+
self._last_response.close()
443442
self._last_response = None
444443

445444
socket = self._get_socket(host, port, proto, timeout=timeout)
446-
socket.send(b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8")))
445+
socket.send(
446+
b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))
447+
)
447448
if "Host" not in headers:
448449
socket.send(b"Host: %s\r\n" % bytes(host, "utf-8"))
449450
if "User-Agent" not in headers:
@@ -489,47 +490,54 @@ def head(self, url, **kw):
489490
"""Send HTTP HEAD request"""
490491
return self.request("HEAD", url, **kw)
491492

492-
493493
def get(self, url, **kw):
494494
"""Send HTTP GET request"""
495495
return self.request("GET", url, **kw)
496496

497-
498497
def post(self, url, **kw):
499498
"""Send HTTP POST request"""
500499
return self.request("POST", url, **kw)
501500

502-
503501
def put(self, url, **kw):
504502
"""Send HTTP PUT request"""
505503
return self.request("PUT", url, **kw)
506504

507-
508505
def patch(self, url, **kw):
509506
"""Send HTTP PATCH request"""
510507
return self.request("PATCH", url, **kw)
511508

512-
513509
def delete(self, url, **kw):
514510
"""Send HTTP DELETE request"""
515511
return self.request("DELETE", url, **kw)
516512

513+
517514
# Backwards compatible API:
518515

519516
_default_session = None
520517

518+
521519
class FakeSSLContext:
522520
def wrap_socket(self, socket, server_hostname=None):
523521
return socket
524522

523+
525524
def set_socket(sock, iface=None):
526525
global _default_session
527526
_default_session = Session(sock, FakeSSLContext())
528527
if iface:
529528
sock.set_interface(iface)
530529

530+
531531
def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1):
532-
_default_session.request(method, url, data=data, json=json, headers=headers, stream=stream, timeout=timeout)
532+
_default_session.request(
533+
method,
534+
url,
535+
data=data,
536+
json=json,
537+
headers=headers,
538+
stream=stream,
539+
timeout=timeout,
540+
)
533541

534542

535543
def head(url, **kw):

examples/requests_advanced_cpython.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import socket
2-
import adafruit_requests as requests
3-
requests.socket_module = socket
2+
import adafruit_requests
3+
4+
http = adafruit_requests.Session(socket)
45

56
JSON_GET_URL = "http://httpbin.org/get"
67

78
# Define a custom header as a dict.
89
headers = {"user-agent": "blinka/1.0.0"}
910

1011
print("Fetching JSON data from %s..." % JSON_GET_URL)
11-
response = requests.get(JSON_GET_URL, headers=headers)
12+
response = http.get(JSON_GET_URL, headers=headers)
1213
print("-" * 60)
1314

1415
json_data = response.json()

examples/requests_github_cpython.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
print("Getting CircuitPython star count")
99
headers = {"Transfer-Encoding": "chunked"}
10-
response = http.get("https://api.github.com/repos/adafruit/circuitpython", headers=headers)
11-
print(response.headers)
10+
response = http.get(
11+
"https://api.github.com/repos/adafruit/circuitpython", headers=headers
12+
)
1213
print("circuitpython stars", response.json()["stargazers_count"])
13-

examples/requests_https_cpython.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ssl
44
import adafruit_requests as requests
55

6-
http = requests.Session(socket, ssl.create_default_context())
6+
https = requests.Session(socket, ssl.create_default_context())
77

88
TEXT_URL = "https://wifitest.adafruit.com/testwifi/index.html"
99
JSON_GET_URL = "https://httpbin.org/get"
@@ -18,15 +18,15 @@
1818
# response.close()
1919

2020
print("Fetching JSON data from %s" % JSON_GET_URL)
21-
response = http.get(JSON_GET_URL)
21+
response = https.get(JSON_GET_URL)
2222
print("-" * 40)
2323

2424
print("JSON Response: ", response.json())
2525
print("-" * 40)
2626

2727
data = "31F"
2828
print("POSTing data to {0}: {1}".format(JSON_POST_URL, data))
29-
response = http.post(JSON_POST_URL, data=data)
29+
response = https.post(JSON_POST_URL, data=data)
3030
print("-" * 40)
3131

3232
json_resp = response.json()
@@ -36,7 +36,7 @@
3636

3737
json_data = {"Date": "July 25, 2019"}
3838
print("POSTing data to {0}: {1}".format(JSON_POST_URL, json_data))
39-
response = http.post(JSON_POST_URL, json=json_data)
39+
response = https.post(JSON_POST_URL, json=json_data)
4040
print("-" * 40)
4141

4242
json_resp = response.json()

0 commit comments

Comments
 (0)