Skip to content

Commit 52b6b2c

Browse files
committed
Add ability to post a file via the data param
1 parent fac4012 commit 52b6b2c

File tree

2 files changed

+60
-23
lines changed

2 files changed

+60
-23
lines changed

adafruit_requests.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
if not sys.implementation.name == "circuitpython":
5252
from types import TracebackType
53-
from typing import Any, Dict, Optional, Type
53+
from typing import IO, Any, Dict, Optional, Type
5454

5555
from circuitpython_typing.socket import (
5656
SocketpoolModuleType,
@@ -387,19 +387,7 @@ def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals
387387
boundary_objects.append("\r\n")
388388

389389
if hasattr(file_handle, "read"):
390-
is_binary = False
391-
try:
392-
content = file_handle.read(1)
393-
is_binary = isinstance(content, bytes)
394-
except UnicodeError:
395-
is_binary = False
396-
397-
if not is_binary:
398-
raise ValueError("Files must be opened in binary mode")
399-
400-
file_handle.seek(0, SEEK_END)
401-
content_length += file_handle.tell()
402-
file_handle.seek(0)
390+
content_length += self._get_file_length(file_handle)
403391

404392
boundary_objects.append(file_handle)
405393
boundary_objects.append("\r\n")
@@ -428,6 +416,24 @@ def _check_headers(headers: Dict[str, str]):
428416
f"Header part ({value}) from {key} must be of type str or bytes, not {type(value)}"
429417
)
430418

419+
@staticmethod
420+
def _get_file_length(file_handle: IO):
421+
is_binary = False
422+
try:
423+
file_handle.seek(0)
424+
content = file_handle.read(1)
425+
is_binary = isinstance(content, bytes)
426+
except UnicodeError:
427+
is_binary = False
428+
429+
if not is_binary:
430+
raise ValueError("Files must be opened in binary mode")
431+
432+
file_handle.seek(0, SEEK_END)
433+
content_length = file_handle.tell()
434+
file_handle.seek(0)
435+
return content_length
436+
431437
@staticmethod
432438
def _send(socket: SocketType, data: bytes):
433439
total_sent = 0
@@ -458,13 +464,16 @@ def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any):
458464
if isinstance(boundary_object, str):
459465
self._send_as_bytes(socket, boundary_object)
460466
else:
461-
chunk_size = 32
462-
b = bytearray(chunk_size)
463-
while True:
464-
size = boundary_object.readinto(b)
465-
if size == 0:
466-
break
467-
self._send(socket, b[:size])
467+
self._send_file(socket, boundary_object)
468+
469+
def _send_file(self, socket: SocketType, file_handle: IO):
470+
chunk_size = 36
471+
b = bytearray(chunk_size)
472+
while True:
473+
size = file_handle.readinto(b)
474+
if size == 0:
475+
break
476+
self._send(socket, b[:size])
468477

469478
def _send_header(self, socket, header, value):
470479
if value is None:
@@ -517,12 +526,16 @@ def _send_request( # pylint: disable=too-many-arguments
517526

518527
# If files are send, build data to send and calculate length
519528
content_length = 0
529+
data_is_file = False
520530
boundary_objects = None
521531
if files and isinstance(files, dict):
522532
boundary_string, content_length, boundary_objects = (
523533
self._build_boundary_data(files)
524534
)
525535
content_type_header = f"multipart/form-data; boundary={boundary_string}"
536+
elif data and hasattr(data, "read"):
537+
data_is_file = True
538+
content_length = self._get_file_length(data)
526539
else:
527540
if data is None:
528541
data = b""
@@ -551,7 +564,9 @@ def _send_request( # pylint: disable=too-many-arguments
551564
self._send(socket, b"\r\n")
552565

553566
# Send data
554-
if data:
567+
if data_is_file:
568+
self._send_file(socket, data)
569+
elif data:
555570
self._send(socket, bytes(data))
556571
elif boundary_objects:
557572
self._send_boundary_objects(socket, boundary_objects)

tests/files_test.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_actual_request_data(log_stream):
5050
boundary = boundary_search[0]
5151
if content_length_search:
5252
content_length = content_length_search[0]
53-
if "Content-Disposition" in log_arg:
53+
if "Content-Disposition" in log_arg or "\\x" in log_arg:
5454
# this will look like:
5555
# b\'{content}\'
5656
# and escaped characters look like:
@@ -63,6 +63,28 @@ def get_actual_request_data(log_stream):
6363
return boundary, content_length, actual_request_post
6464

6565

66+
def test_post_file_as_data( # pylint: disable=unused-argument
67+
requests, sock, log_stream, post_url, request_logging
68+
):
69+
with open("tests/files/red_green.png", "rb") as file_1:
70+
python_requests.post(post_url, data=file_1, timeout=30)
71+
__, content_length, actual_request_post = get_actual_request_data(log_stream)
72+
73+
requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=file_1)
74+
75+
sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80))
76+
sock.send.assert_has_calls(
77+
[
78+
mock.call(b"Content-Length"),
79+
mock.call(b": "),
80+
mock.call(content_length.encode()),
81+
mock.call(b"\r\n"),
82+
]
83+
)
84+
sent = b"".join(sock.sent_data)
85+
assert sent.endswith(actual_request_post)
86+
87+
6688
def test_post_files_text( # pylint: disable=unused-argument
6789
sock, requests, log_stream, post_url, request_logging
6890
):

0 commit comments

Comments
 (0)