Skip to content

Commit 97cb3bb

Browse files
authored
Merge branch 'adafruit:main' into connection-manager
2 parents e9e70fe + a05b19f commit 97cb3bb

File tree

2 files changed

+216
-20
lines changed

2 files changed

+216
-20
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def __init__(
187187
self._is_connected = False
188188
self._msg_size_lim = MQTT_MSG_SZ_LIM
189189
self._pid = 0
190-
self._timestamp: float = 0
190+
self._last_msg_sent_timestamp: float = 0
191191
self.logger = NullLogger()
192192
"""An optional logging attribute that can be set with with a Logger
193193
to enable debug logging."""
@@ -537,6 +537,7 @@ def _connect(
537537
if self._username is not None:
538538
self._send_str(self._username)
539539
self._send_str(self._password)
540+
self._last_msg_sent_timestamp = self.get_monotonic_time()
540541
self.logger.debug("Receiving CONNACK packet from broker")
541542
stamp = self.get_monotonic_time()
542543
while True:
@@ -591,6 +592,7 @@ def disconnect(self) -> None:
591592
self._connection_manager.free_socket(self._sock)
592593
self._is_connected = False
593594
self._subscribed_topics = []
595+
self._last_msg_sent_timestamp = 0
594596
if self.on_disconnect is not None:
595597
self.on_disconnect(self, self.user_data, 0)
596598

@@ -604,6 +606,7 @@ def ping(self) -> list[int]:
604606
self._sock.send(MQTT_PINGREQ)
605607
ping_timeout = self.keep_alive
606608
stamp = self.get_monotonic_time()
609+
self._last_msg_sent_timestamp = stamp
607610
rc, rcs = None, []
608611
while rc != MQTT_PINGRESP:
609612
rc = self._wait_for_msg()
@@ -678,6 +681,7 @@ def publish(
678681
self._sock.send(pub_hdr_fixed)
679682
self._sock.send(pub_hdr_var)
680683
self._sock.send(msg)
684+
self._last_msg_sent_timestamp = self.get_monotonic_time()
681685
if qos == 0 and self.on_publish is not None:
682686
self.on_publish(self, self.user_data, topic, self._pid)
683687
if qos == 1:
@@ -755,6 +759,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
755759
self.logger.debug(f"payload: {payload}")
756760
self._sock.send(payload)
757761
stamp = self.get_monotonic_time()
762+
self._last_msg_sent_timestamp = stamp
758763
while True:
759764
op = self._wait_for_msg()
760765
if op is None:
@@ -830,6 +835,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
830835
for t in topics:
831836
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
832837
self._sock.send(payload)
838+
self._last_msg_sent_timestamp = self.get_monotonic_time()
833839
self.logger.debug("Waiting for UNSUBACK...")
834840
while True:
835841
stamp = self.get_monotonic_time()
@@ -919,31 +925,41 @@ def reconnect(self, resub_topics: bool = True) -> int:
919925
return ret
920926

921927
def loop(self, timeout: float = 0) -> Optional[list[int]]:
922-
# pylint: disable = too-many-return-statements
923928
"""Non-blocking message loop. Use this method to check for incoming messages.
924929
Returns list of packet types of any messages received or None.
925930
926931
:param float timeout: return after this timeout, in seconds.
927932
928933
"""
934+
if timeout < self._socket_timeout:
935+
raise MMQTTException(
936+
# pylint: disable=consider-using-f-string
937+
"loop timeout ({}) must be bigger ".format(timeout)
938+
+ "than socket timeout ({}))".format(self._socket_timeout)
939+
)
940+
929941
self._connected()
930942
self.logger.debug(f"waiting for messages for {timeout} seconds")
931-
if self._timestamp == 0:
932-
self._timestamp = self.get_monotonic_time()
933-
current_time = self.get_monotonic_time()
934-
if current_time - self._timestamp >= self.keep_alive:
935-
self._timestamp = 0
936-
# Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server
937-
self.logger.debug(
938-
"KeepAlive period elapsed - requesting a PINGRESP from the server..."
939-
)
940-
rcs = self.ping()
941-
return rcs
942943

943944
stamp = self.get_monotonic_time()
944945
rcs = []
945946

946947
while True:
948+
if (
949+
self.get_monotonic_time() - self._last_msg_sent_timestamp
950+
>= self.keep_alive
951+
):
952+
# Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server
953+
self.logger.debug(
954+
"KeepAlive period elapsed - requesting a PINGRESP from the server..."
955+
)
956+
rcs.extend(self.ping())
957+
# ping() itself contains a _wait_for_msg() loop which might have taken a while,
958+
# so check here as well.
959+
if self.get_monotonic_time() - stamp > timeout:
960+
self.logger.debug(f"Loop timed out after {timeout} seconds")
961+
break
962+
947963
rc = self._wait_for_msg()
948964
if rc is not None:
949965
rcs.append(rc)
@@ -953,11 +969,13 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]:
953969

954970
return rcs if rcs else None
955971

956-
def _wait_for_msg(self) -> Optional[int]:
972+
def _wait_for_msg(self, timeout: Optional[float] = None) -> Optional[int]:
957973
# pylint: disable = too-many-return-statements
958974

959975
"""Reads and processes network events.
960976
Return the packet type or None if there is nothing to be received.
977+
978+
:param float timeout: return after this timeout, in seconds.
961979
"""
962980
# CPython socket module contains a timeout attribute
963981
if hasattr(self._socket_pool, "timeout"):
@@ -967,7 +985,7 @@ def _wait_for_msg(self) -> Optional[int]:
967985
return None
968986
else: # socketpool, esp32spi
969987
try:
970-
res = self._sock_exact_recv(1)
988+
res = self._sock_exact_recv(1, timeout=timeout)
971989
except OSError as error:
972990
if error.errno in (errno.ETIMEDOUT, errno.EAGAIN):
973991
# raised by a socket timeout if 0 bytes were present
@@ -1036,7 +1054,9 @@ def _decode_remaining_length(self) -> int:
10361054
return n
10371055
sh += 7
10381056

1039-
def _sock_exact_recv(self, bufsize: int) -> bytearray:
1057+
def _sock_exact_recv(
1058+
self, bufsize: int, timeout: Optional[float] = None
1059+
) -> bytearray:
10401060
"""Reads _exact_ number of bytes from the connected socket. Will only return
10411061
bytearray with the exact number of bytes requested.
10421062
@@ -1047,6 +1067,7 @@ def _sock_exact_recv(self, bufsize: int) -> bytearray:
10471067
bytes is returned or trigger a timeout exception.
10481068
10491069
:param int bufsize: number of bytes to receive
1070+
:param float timeout: timeout, in seconds. Defaults to keep_alive
10501071
:return: byte array
10511072
"""
10521073
stamp = self.get_monotonic_time()
@@ -1058,7 +1079,7 @@ def _sock_exact_recv(self, bufsize: int) -> bytearray:
10581079
to_read = bufsize - recv_len
10591080
if to_read < 0:
10601081
raise MMQTTException(f"negative number of bytes to read: {to_read}")
1061-
read_timeout = self.keep_alive
1082+
read_timeout = timeout if timeout is not None else self.keep_alive
10621083
mv = mv[recv_len:]
10631084
while to_read > 0:
10641085
recv_len = self._sock.recv_into(mv, to_read)

tests/test_loop.py

Lines changed: 178 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,109 @@
88
import socket
99
import ssl
1010
import time
11+
import errno
12+
1113
from unittest import TestCase, main
1214
from unittest.mock import patch
15+
from unittest import mock
1316

1417
import adafruit_minimqtt.adafruit_minimqtt as MQTT
1518

1619

20+
class Nulltet:
21+
"""
22+
Mock Socket that does nothing.
23+
24+
Inspired by the Mocket class from Adafruit_CircuitPython_Requests
25+
"""
26+
27+
def __init__(self):
28+
self.sent = bytearray()
29+
30+
self.timeout = mock.Mock()
31+
self.connect = mock.Mock()
32+
self.close = mock.Mock()
33+
34+
def send(self, bytes_to_send):
35+
"""
36+
Record the bytes. return the length of this bytearray.
37+
"""
38+
self.sent.extend(bytes_to_send)
39+
return len(bytes_to_send)
40+
41+
# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
42+
# pylint: disable=unused-argument,no-self-use
43+
def recv_into(self, retbuf, bufsize):
44+
"""Always raise timeout exception."""
45+
exc = OSError()
46+
exc.errno = errno.ETIMEDOUT
47+
raise exc
48+
49+
50+
class Pingtet:
51+
"""
52+
Mock Socket tailored for PINGREQ testing.
53+
Records sent data, hands out PINGRESP for each PINGREQ received.
54+
55+
Inspired by the Mocket class from Adafruit_CircuitPython_Requests
56+
"""
57+
58+
PINGRESP = bytearray([0xD0, 0x00])
59+
60+
def __init__(self):
61+
self._to_send = self.PINGRESP
62+
63+
self.sent = bytearray()
64+
65+
self.timeout = mock.Mock()
66+
self.connect = mock.Mock()
67+
self.close = mock.Mock()
68+
69+
self._got_pingreq = False
70+
71+
def send(self, bytes_to_send):
72+
"""
73+
Recognize PINGREQ and record the indication that it was received.
74+
Assumes it was sent in one chunk (of 2 bytes).
75+
Also record the bytes. return the length of this bytearray.
76+
"""
77+
self.sent.extend(bytes_to_send)
78+
if bytes_to_send == b"\xc0\0":
79+
self._got_pingreq = True
80+
return len(bytes_to_send)
81+
82+
# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
83+
def recv_into(self, retbuf, bufsize):
84+
"""
85+
If the PINGREQ indication is on, return PINGRESP, otherwise raise timeout exception.
86+
"""
87+
if self._got_pingreq:
88+
size = min(bufsize, len(self._to_send))
89+
if size == 0:
90+
return size
91+
chop = self._to_send[0:size]
92+
retbuf[0:] = chop
93+
self._to_send = self._to_send[size:]
94+
if len(self._to_send) == 0:
95+
self._got_pingreq = False
96+
self._to_send = self.PINGRESP
97+
return size
98+
99+
exc = OSError()
100+
exc.errno = errno.ETIMEDOUT
101+
raise exc
102+
103+
17104
class Loop(TestCase):
18105
"""basic loop() test"""
19106

20107
connect_times = []
21108
INITIAL_RCS_VAL = 42
22109
rcs_val = INITIAL_RCS_VAL
23110

24-
def fake_wait_for_msg(self):
111+
def fake_wait_for_msg(self, timeout=1):
25112
"""_wait_for_msg() replacement. Sleeps for 1 second and returns an integer."""
26-
time.sleep(1)
113+
time.sleep(timeout)
27114
retval = self.rcs_val
28115
self.rcs_val += 1
29116
return retval
@@ -54,6 +141,8 @@ def test_loop_basic(self) -> None:
54141

55142
time_before = time.monotonic()
56143
timeout = random.randint(3, 8)
144+
# pylint: disable=protected-access
145+
mqtt_client._last_msg_sent_timestamp = mqtt_client.get_monotonic_time()
57146
rcs = mqtt_client.loop(timeout=timeout)
58147
time_after = time.monotonic()
59148

@@ -62,12 +151,33 @@ def test_loop_basic(self) -> None:
62151

63152
# Check the return value.
64153
assert rcs is not None
65-
assert len(rcs) > 1
154+
assert len(rcs) >= 1
66155
expected_rc = self.INITIAL_RCS_VAL
156+
# pylint: disable=not-an-iterable
67157
for ret_code in rcs:
68158
assert ret_code == expected_rc
69159
expected_rc += 1
70160

161+
# pylint: disable=invalid-name
162+
def test_loop_timeout_vs_socket_timeout(self):
163+
"""
164+
loop() should throw MMQTTException if the timeout argument
165+
is bigger than the socket timeout.
166+
"""
167+
mqtt_client = MQTT.MQTT(
168+
broker="127.0.0.1",
169+
port=1883,
170+
socket_pool=socket,
171+
ssl_context=ssl.create_default_context(),
172+
socket_timeout=1,
173+
)
174+
175+
mqtt_client.is_connected = lambda: True
176+
with self.assertRaises(MQTT.MMQTTException) as context:
177+
mqtt_client.loop(timeout=0.5)
178+
179+
assert "loop timeout" in str(context.exception)
180+
71181
def test_loop_is_connected(self):
72182
"""
73183
loop() should throw MMQTTException if not connected
@@ -84,6 +194,71 @@ def test_loop_is_connected(self):
84194

85195
assert "not connected" in str(context.exception)
86196

197+
# pylint: disable=no-self-use
198+
def test_loop_ping_timeout(self):
199+
"""Verify that ping will be sent even with loop timeout bigger than keep alive timeout
200+
and no outgoing messages are sent."""
201+
202+
recv_timeout = 2
203+
keep_alive_timeout = recv_timeout * 2
204+
mqtt_client = MQTT.MQTT(
205+
broker="localhost",
206+
port=1883,
207+
ssl_context=ssl.create_default_context(),
208+
connect_retries=1,
209+
socket_timeout=1,
210+
recv_timeout=recv_timeout,
211+
keep_alive=keep_alive_timeout,
212+
)
213+
214+
# patch is_connected() to avoid CONNECT/CONNACK handling.
215+
mqtt_client.is_connected = lambda: True
216+
mocket = Pingtet()
217+
# pylint: disable=protected-access
218+
mqtt_client._sock = mocket
219+
220+
start = time.monotonic()
221+
res = mqtt_client.loop(timeout=2 * keep_alive_timeout)
222+
assert time.monotonic() - start >= 2 * keep_alive_timeout
223+
assert len(mocket.sent) > 0
224+
assert len(res) == 2
225+
assert set(res) == {int(0xD0)}
226+
227+
# pylint: disable=no-self-use
228+
def test_loop_ping_vs_msgs_sent(self):
229+
"""Verify that ping will not be sent unnecessarily."""
230+
231+
recv_timeout = 2
232+
keep_alive_timeout = recv_timeout * 2
233+
mqtt_client = MQTT.MQTT(
234+
broker="localhost",
235+
port=1883,
236+
ssl_context=ssl.create_default_context(),
237+
connect_retries=1,
238+
socket_timeout=1,
239+
recv_timeout=recv_timeout,
240+
keep_alive=keep_alive_timeout,
241+
)
242+
243+
# patch is_connected() to avoid CONNECT/CONNACK handling.
244+
mqtt_client.is_connected = lambda: True
245+
246+
# With QoS=0 no PUBACK message is sent, so Nulltet can be used.
247+
mocket = Nulltet()
248+
# pylint: disable=protected-access
249+
mqtt_client._sock = mocket
250+
251+
i = 0
252+
topic = "foo"
253+
message = "bar"
254+
for _ in range(3 * keep_alive_timeout):
255+
mqtt_client.publish(topic, message, qos=0)
256+
mqtt_client.loop(1)
257+
i += 1
258+
259+
# This means no other messages than the PUBLISH messages generated by the code above.
260+
assert len(mocket.sent) == i * (2 + 2 + len(topic) + len(message))
261+
87262

88263
if __name__ == "__main__":
89264
main()

0 commit comments

Comments
 (0)