Skip to content

Remove pre 7.0 compatibility; other cleanup #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

repos:
- repo: https://github.com/python/black
rev: 20.8b1
rev: 22.1.0
hooks:
- id: black
- repo: https://github.com/fsfe/reuse-tool
Expand Down
142 changes: 44 additions & 98 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@
import errno
import sys

import json as json_module

if sys.implementation.name == "circuitpython":

def cast(_t, value):
"""No-op shim for the typing.cast() function which is not available in CircuitPython."""
return value


else:
from ssl import SSLContext
from types import ModuleType, TracebackType
Expand Down Expand Up @@ -148,16 +149,6 @@ def TLS_MODE(self) -> int: # pylint: disable=invalid-name
SSLContextType = Union[SSLContext, "_FakeSSLContext"]


# CircuitPython 6.0 does not have the bytearray.split method.
# This function emulates buf.split(needle)[0], which is the functionality
# required.
def _buffer_split0(buf: Union[bytes, bytearray], needle: Union[bytes, bytearray]):
index = buf.find(needle)
if index == -1:
return buf
return buf[:index]


class _RawResponse:
def __init__(self, response: "Response") -> None:
self._response = response
Expand All @@ -177,10 +168,6 @@ def readinto(self, buf: bytearray) -> int:
return self._response._readinto(buf) # pylint: disable=protected-access


class _SendFailed(Exception):
"""Custom exception to abort sending a request."""


class OutOfRetries(Exception):
"""Raised when requests has retried to make a request unsuccessfully."""

Expand Down Expand Up @@ -240,56 +227,25 @@ def _recv_into(self, buf: bytearray, size: int = 0) -> int:
return read_size
return cast("SupportsRecvInto", self.socket).recv_into(buf, size)

@staticmethod
def _find(buf: bytes, needle: bytes, start: int, end: int) -> int:
if hasattr(buf, "find"):
return buf.find(needle, start, end)
result = -1
i = start
while i < end:
j = 0
while j < len(needle) and i + j < end and buf[i + j] == needle[j]:
j += 1
if j == len(needle):
result = i
break
i += 1

return result

def _readto(self, first: bytes, second: bytes = b"") -> bytes:
def _readto(self, stop: bytes) -> bytearray:
buf = self._receive_buffer
end = self._received_length
while True:
firsti = self._find(buf, first, 0, end)
secondi = -1
if second:
secondi = self._find(buf, second, 0, end)

i = -1
needle_len = 0
if firsti >= 0:
i = firsti
needle_len = len(first)
if secondi >= 0 and (firsti < 0 or secondi < firsti):
i = secondi
needle_len = len(second)
i = buf.find(stop, 0, end)
if i >= 0:
# Stop was found. Return everything up to but not including stop.
result = buf[:i]
new_start = i + needle_len

if i + needle_len <= end:
new_end = end - new_start
buf[:new_end] = buf[new_start:end]
self._received_length = new_end
new_start = i + len(stop)
# Remove everything up to and including stop from the buffer.
new_end = end - new_start
buf[:new_end] = buf[new_start:end]
self._received_length = new_end
return result

# Not found so load more.

# Not found so load more bytes.
# If our buffer is full, then make it bigger to load more.
if end == len(buf):
new_size = len(buf) + 32
new_buf = bytearray(new_size)
new_buf = bytearray(len(buf) + 32)
new_buf[: len(buf)] = buf
buf = new_buf
self._receive_buffer = buf
Expand All @@ -300,8 +256,6 @@ def _readto(self, first: bytes, second: bytes = b"") -> bytes:
return buf[:end]
end += read

return b""

def _read_from_buffer(
self, buf: Optional[bytearray] = None, nbytes: Optional[int] = None
) -> int:
Expand Down Expand Up @@ -333,7 +287,7 @@ def _readinto(self, buf: bytearray) -> int:
# Consume trailing \r\n for chunks 2+
if self._remaining == 0:
self._throw_away(2)
chunk_header = _buffer_split0(self._readto(b"\r\n"), b";")
chunk_header = bytes(self._readto(b"\r\n")).split(b";", 1)[0]
http_chunk_size = int(bytes(chunk_header), 16)
if http_chunk_size == 0:
self._chunked = False
Expand Down Expand Up @@ -374,7 +328,7 @@ def close(self) -> None:
self._throw_away(self._remaining)
elif self._chunked:
while True:
chunk_header = _buffer_split0(self._readto(b"\r\n"), b";")
chunk_header = bytes(self._readto(b"\r\n")).split(b";", 1)[0]
chunk_size = int(bytes(chunk_header), 16)
if chunk_size == 0:
break
Expand All @@ -392,11 +346,10 @@ def _parse_headers(self) -> None:
Expects first line of HTTP request/response to have been read already.
"""
while True:
title = self._readto(b": ", b"\r\n")
if not title:
header = self._readto(b"\r\n")
if not header:
break

content = self._readto(b"\r\n")
title, content = bytes(header).split(b": ", 1)
if title and content:
# enforce that all headers are lowercase
title = str(title, "utf-8").lower()
Expand All @@ -407,6 +360,17 @@ def _parse_headers(self) -> None:
self._chunked = content.strip().lower() == "chunked"
self._headers[title] = content

def _validate_not_gzip(self) -> None:
"""gzip encoding is not supported. Raise an exception if found."""
if (
"content-encoding" in self.headers
and self.headers["content-encoding"] == "gzip"
):
raise ValueError(
"Content-encoding is gzip, data cannot be accessed as json or text. "
"Use content property to access raw bytes."
)

@property
def headers(self) -> Dict[str, str]:
"""
Expand Down Expand Up @@ -435,22 +399,13 @@ def text(self) -> str:
return self._cached
raise RuntimeError("Cannot access text after getting content or json")

if (
"content-encoding" in self.headers
and self.headers["content-encoding"] == "gzip"
):
raise ValueError(
"Content-encoding is gzip, data cannot be accessed as json or text. "
"Use content property to access raw bytes."
)
self._validate_not_gzip()

self._cached = str(self.content, self.encoding)
return self._cached

def json(self) -> Any:
"""The HTTP content, parsed into a json dictionary"""
# pylint: disable=import-outside-toplevel
import json

# The cached JSON will be a list or dictionary.
if self._cached:
if isinstance(self._cached, (list, dict)):
Expand All @@ -459,20 +414,9 @@ def json(self) -> Any:
if not self._raw:
self._raw = _RawResponse(self)

if (
"content-encoding" in self.headers
and self.headers["content-encoding"] == "gzip"
):
raise ValueError(
"Content-encoding is gzip, data cannot be accessed as json or text. "
"Use content property to access raw bytes."
)
try:
obj = json.load(self._raw)
except OSError:
# <5.3.1 doesn't piecemeal load json from any object with readinto so load the whole
# string.
obj = json.loads(self._raw.read())
self._validate_not_gzip()

obj = json_module.load(self._raw)
if not self._cached:
self._cached = obj
self.close()
Expand Down Expand Up @@ -599,12 +543,19 @@ def _send(socket: SocketType, data: bytes):
# ESP32SPI sockets raise a RuntimeError when unable to send.
try:
sent = socket.send(data[total_sent:])
except RuntimeError:
sent = 0
except OSError as exc:
if exc.errno == errno.EAGAIN:
# Can't send right now (e.g., no buffer space), try again.
continue
# Some worse error.
raise
except RuntimeError as exc:
raise OSError(errno.EIO) from exc
if sent is None:
sent = len(data)
if sent == 0:
raise _SendFailed()
# Not EAGAIN; that was already handled.
raise OSError(errno.EIO)
total_sent += sent

def _send_request(
Expand Down Expand Up @@ -636,11 +587,6 @@ def _send_request(
self._send(socket, b"\r\n")
if json is not None:
assert data is None
# pylint: disable=import-outside-toplevel
try:
import json as json_module
except ImportError:
import ujson as json_module
data = json_module.dumps(json)
self._send(socket, b"Content-Type: application/json\r\n")
if data:
Expand Down Expand Up @@ -711,7 +657,7 @@ def request(
ok = True
try:
self._send_request(socket, host, method, path, headers, data, json)
except (_SendFailed, OSError):
except OSError:
ok = False
if ok:
# Read the H of "HTTP/1.1" to make sure the socket is alive. send can appear to work
Expand Down
2 changes: 1 addition & 1 deletion tests/legacy_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class Mocket: # pylint: disable=too-few-public-methods
""" Mock Socket """
"""Mock Socket"""

def __init__(self, response):
self.settimeout = mock.Mock()
Expand Down
6 changes: 3 additions & 3 deletions tests/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class MocketPool: # pylint: disable=too-few-public-methods
""" Mock SocketPool """
"""Mock SocketPool"""

SOCK_STREAM = 0

Expand All @@ -18,7 +18,7 @@ def __init__(self):


class Mocket: # pylint: disable=too-few-public-methods
""" Mock Socket """
"""Mock Socket"""

def __init__(self, response):
self.settimeout = mock.Mock()
Expand Down Expand Up @@ -62,7 +62,7 @@ def _recv_into(self, buf, nbytes=0):


class SSLContext: # pylint: disable=too-few-public-methods
""" Mock SSL Context """
"""Mock SSL Context"""

def __init__(self):
self.wrap_socket = mock.Mock(side_effect=self._wrap_socket)
Expand Down