Skip to content

Commit 75f3845

Browse files
authored
Merge pull request #231 from dhalbert/partial-send
handle partial socket send()'s
2 parents 76f8c28 + 9be1a4c commit 75f3845

File tree

2 files changed

+37
-24
lines changed

2 files changed

+37
-24
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,21 @@ def connect(
461461
raise MMQTTException(exc_msg) from last_exception
462462
raise MMQTTException(exc_msg)
463463

464+
def _send_bytes(
465+
self,
466+
buffer: Union[bytes, bytearray, memoryview],
467+
):
468+
bytes_sent: int = 0
469+
bytes_to_send = len(buffer)
470+
view = memoryview(buffer)
471+
while bytes_sent < bytes_to_send:
472+
try:
473+
bytes_sent += self._sock.send(view[bytes_sent:])
474+
except OSError as exc:
475+
if exc.errno == EAGAIN:
476+
continue
477+
raise
478+
464479
def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
465480
self,
466481
clean_session: bool = True,
@@ -529,8 +544,8 @@ def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
529544
self.logger.debug("Sending CONNECT to broker...")
530545
self.logger.debug(f"Fixed Header: {fixed_header}")
531546
self.logger.debug(f"Variable Header: {var_header}")
532-
self._sock.send(fixed_header)
533-
self._sock.send(var_header)
547+
self._send_bytes(fixed_header)
548+
self._send_bytes(var_header)
534549
# [MQTT-3.1.3-4]
535550
self._send_str(self.client_id)
536551
if self._lw_topic:
@@ -591,7 +606,7 @@ def disconnect(self) -> None:
591606
self._connected()
592607
self.logger.debug("Sending DISCONNECT packet to broker")
593608
try:
594-
self._sock.send(MQTT_DISCONNECT)
609+
self._send_bytes(MQTT_DISCONNECT)
595610
except (MemoryError, OSError, RuntimeError) as e:
596611
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
597612
self._close_socket()
@@ -608,7 +623,7 @@ def ping(self) -> list[int]:
608623
"""
609624
self._connected()
610625
self.logger.debug("Sending PINGREQ")
611-
self._sock.send(MQTT_PINGREQ)
626+
self._send_bytes(MQTT_PINGREQ)
612627
ping_timeout = self.keep_alive
613628
stamp = ticks_ms()
614629

@@ -683,9 +698,9 @@ def publish( # noqa: PLR0912, Too many branches
683698
qos,
684699
retain,
685700
)
686-
self._sock.send(pub_hdr_fixed)
687-
self._sock.send(pub_hdr_var)
688-
self._sock.send(msg)
701+
self._send_bytes(pub_hdr_fixed)
702+
self._send_bytes(pub_hdr_var)
703+
self._send_bytes(msg)
689704
self._last_msg_sent_timestamp = ticks_ms()
690705
if qos == 0 and self.on_publish is not None:
691706
self.on_publish(self, self.user_data, topic, self._pid)
@@ -749,12 +764,12 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
749764
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
750765
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
751766
self.logger.debug(f"Fixed Header: {fixed_header}")
752-
self._sock.send(fixed_header)
767+
self._send_bytes(fixed_header)
753768
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
754769
packet_id_bytes = self._pid.to_bytes(2, "big")
755770
var_header = packet_id_bytes
756771
self.logger.debug(f"Variable Header: {var_header}")
757-
self._sock.send(var_header)
772+
self._send_bytes(var_header)
758773
# attaching topic and QOS level to the packet
759774
payload = b""
760775
for t, q in topics:
@@ -764,7 +779,7 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
764779
for t, q in topics:
765780
self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}")
766781
self.logger.debug(f"payload: {payload}")
767-
self._sock.send(payload)
782+
self._send_bytes(payload)
768783
stamp = ticks_ms()
769784
self._last_msg_sent_timestamp = stamp
770785
while True:
@@ -829,19 +844,19 @@ def unsubscribe( # noqa: PLR0912, Too many branches
829844
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
830845
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
831846
self.logger.debug(f"Fixed Header: {fixed_header}")
832-
self._sock.send(fixed_header)
847+
self._send_bytes(fixed_header)
833848
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
834849
packet_id_bytes = self._pid.to_bytes(2, "big")
835850
var_header = packet_id_bytes
836851
self.logger.debug(f"Variable Header: {var_header}")
837-
self._sock.send(var_header)
852+
self._send_bytes(var_header)
838853
payload = b""
839854
for t in topics:
840855
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
841856
payload += topic_size + t.encode()
842857
for t in topics:
843858
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
844-
self._sock.send(payload)
859+
self._send_bytes(payload)
845860
self._last_msg_sent_timestamp = ticks_ms()
846861
self.logger.debug("Waiting for UNSUBACK...")
847862
while True:
@@ -1028,7 +1043,7 @@ def _wait_for_msg( # noqa: PLR0912, Too many branches
10281043
if res[0] & 0x06 == 0x02:
10291044
pkt = bytearray(b"\x40\x02\0\0")
10301045
struct.pack_into("!H", pkt, 2, pid)
1031-
self._sock.send(pkt)
1046+
self._send_bytes(pkt)
10321047
elif res[0] & 6 == 4:
10331048
assert 0
10341049

@@ -1109,11 +1124,11 @@ def _send_str(self, string: str) -> None:
11091124
11101125
"""
11111126
if isinstance(string, str):
1112-
self._sock.send(struct.pack("!H", len(string.encode("utf-8"))))
1113-
self._sock.send(str.encode(string, "utf-8"))
1127+
self._send_bytes(struct.pack("!H", len(string.encode("utf-8"))))
1128+
self._send_bytes(str.encode(string, "utf-8"))
11141129
else:
1115-
self._sock.send(struct.pack("!H", len(string)))
1116-
self._sock.send(string)
1130+
self._send_bytes(struct.pack("!H", len(string)))
1131+
self._send_bytes(string)
11171132

11181133
@staticmethod
11191134
def _valid_topic(topic: str) -> None:

tests/test_recv_timeout.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from unittest import TestCase, main
1010
from unittest.mock import Mock
1111

12+
from mocket import Mocket
13+
1214
import adafruit_minimqtt.adafruit_minimqtt as MQTT
1315

1416

@@ -34,7 +36,7 @@ def test_recv_timeout_vs_keepalive(self) -> None:
3436
)
3537

3638
# Create a mock socket that will accept anything and return nothing.
37-
socket_mock = Mock()
39+
socket_mock = Mocket(b"")
3840
socket_mock.recv_into = Mock(side_effect=side_effect)
3941
mqtt_client._sock = socket_mock
4042

@@ -43,12 +45,8 @@ def test_recv_timeout_vs_keepalive(self) -> None:
4345
with self.assertRaises(MQTT.MMQTTException):
4446
mqtt_client.ping()
4547

46-
# Verify the mock interactions.
47-
socket_mock.send.assert_called_once()
48-
socket_mock.recv_into.assert_called()
49-
5048
now = time.monotonic()
51-
assert recv_timeout <= (now - start) <= (keep_alive + 0.1)
49+
assert recv_timeout <= (now - start) <= (keep_alive + 0.2)
5250

5351

5452
if __name__ == "__main__":

0 commit comments

Comments
 (0)