Skip to content

Commit e9e70fe

Browse files
authored
Merge branch 'adafruit:main' into connection-manager
2 parents fe12510 + 70faa4f commit e9e70fe

File tree

4 files changed

+499
-64
lines changed

4 files changed

+499
-64
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@
6969
MQTT_PINGREQ = b"\xc0\0"
7070
MQTT_PINGRESP = const(0xD0)
7171
MQTT_PUBLISH = const(0x30)
72-
MQTT_SUB = b"\x82"
73-
MQTT_UNSUB = b"\xA2"
72+
MQTT_SUB = const(0x82)
73+
MQTT_UNSUB = const(0xA2)
7474
MQTT_DISCONNECT = b"\xe0\0"
7575

7676
MQTT_PKT_TYPE_MASK = const(0xF0)
@@ -494,13 +494,12 @@ def _connect(
494494
exception_passthrough=True,
495495
)
496496

497-
# Fixed Header
498497
fixed_header = bytearray([0x10])
499498

500499
# Variable CONNECT header [MQTT 3.1.2]
501500
# The byte array is used as a template.
502-
var_header = bytearray(b"\x04MQTT\x04\x02\0\0")
503-
var_header[6] = clean_session << 1
501+
var_header = bytearray(b"\x00\x04MQTT\x04\x02\0\0")
502+
var_header[7] = clean_session << 1
504503

505504
# Set up variable header and remaining_length
506505
remaining_length = 12 + len(self.client_id.encode("utf-8"))
@@ -511,36 +510,19 @@ def _connect(
511510
+ 2
512511
+ len(self._password.encode("utf-8"))
513512
)
514-
var_header[6] |= 0xC0
513+
var_header[7] |= 0xC0
515514
if self.keep_alive:
516515
assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT
517-
var_header[7] |= self.keep_alive >> 8
518-
var_header[8] |= self.keep_alive & 0x00FF
516+
var_header[8] |= self.keep_alive >> 8
517+
var_header[9] |= self.keep_alive & 0x00FF
519518
if self._lw_topic:
520519
remaining_length += (
521520
2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg)
522521
)
523-
var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
524-
var_header[6] |= self._lw_retain << 5
525-
526-
# Remaining length calculation
527-
large_rel_length = False
528-
if remaining_length > 0x7F:
529-
large_rel_length = True
530-
# Calculate Remaining Length [2.2.3]
531-
while remaining_length > 0:
532-
encoded_byte = remaining_length % 0x80
533-
remaining_length = remaining_length // 0x80
534-
# if there is more data to encode, set the top bit of the byte
535-
if remaining_length > 0:
536-
encoded_byte |= 0x80
537-
fixed_header.append(encoded_byte)
538-
if large_rel_length:
539-
fixed_header.append(0x00)
540-
else:
541-
fixed_header.append(remaining_length)
542-
fixed_header.append(0x00)
522+
var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
523+
var_header[7] |= self._lw_retain << 5
543524

525+
self._encode_remaining_length(fixed_header, remaining_length)
544526
self.logger.debug("Sending CONNECT to broker...")
545527
self.logger.debug(f"Fixed Header: {fixed_header}")
546528
self.logger.debug(f"Variable Header: {var_header}")
@@ -577,6 +559,26 @@ def _connect(
577559
f"No data received from broker for {self._recv_timeout} seconds."
578560
)
579561

562+
# pylint: disable=no-self-use
563+
def _encode_remaining_length(
564+
self, fixed_header: bytearray, remaining_length: int
565+
) -> None:
566+
"""Encode Remaining Length [2.2.3]"""
567+
if remaining_length > 268_435_455:
568+
raise MMQTTException("invalid remaining length")
569+
570+
# Remaining length calculation
571+
if remaining_length > 0x7F:
572+
while remaining_length > 0:
573+
encoded_byte = remaining_length % 0x80
574+
remaining_length = remaining_length // 0x80
575+
# if there is more data to encode, set the top bit of the byte
576+
if remaining_length > 0:
577+
encoded_byte |= 0x80
578+
fixed_header.append(encoded_byte)
579+
else:
580+
fixed_header.append(remaining_length)
581+
580582
def disconnect(self) -> None:
581583
"""Disconnects the MiniMQTT client from the MQTT broker."""
582584
self._connected()
@@ -663,16 +665,7 @@ def publish(
663665
pub_hdr_var.append(self._pid >> 8)
664666
pub_hdr_var.append(self._pid & 0xFF)
665667

666-
# Calculate remaining length [2.2.3]
667-
if remaining_length > 0x7F:
668-
while remaining_length > 0:
669-
encoded_byte = remaining_length % 0x80
670-
remaining_length = remaining_length // 0x80
671-
if remaining_length > 0:
672-
encoded_byte |= 0x80
673-
pub_hdr_fixed.append(encoded_byte)
674-
else:
675-
pub_hdr_fixed.append(remaining_length)
668+
self._encode_remaining_length(pub_hdr_fixed, remaining_length)
676669

677670
self.logger.debug(
678671
"Sending PUBLISH\nTopic: %s\nMsg: %s\
@@ -707,9 +700,9 @@ def publish(
707700
f"No data received from broker for {self._recv_timeout} seconds."
708701
)
709702

710-
def subscribe(self, topic: str, qos: int = 0) -> None:
703+
def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> None:
711704
"""Subscribes to a topic on the MQTT Broker.
712-
This method can subscribe to one topics or multiple topics.
705+
This method can subscribe to one topic or multiple topics.
713706
714707
:param str|tuple|list topic: Unique MQTT topic identifier string. If
715708
this is a `tuple`, then the tuple should
@@ -739,21 +732,28 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
739732
self._valid_topic(t)
740733
topics.append((t, q))
741734
# Assemble packet
735+
self.logger.debug("Sending SUBSCRIBE to broker...")
736+
fixed_header = bytearray([MQTT_SUB])
742737
packet_length = 2 + (2 * len(topics)) + (1 * len(topics))
743738
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
744-
packet_length_byte = packet_length.to_bytes(1, "big")
739+
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
740+
self.logger.debug(f"Fixed Header: {fixed_header}")
741+
self._sock.send(fixed_header)
745742
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
746743
packet_id_bytes = self._pid.to_bytes(2, "big")
747-
# Packet with variable and fixed headers
748-
packet = MQTT_SUB + packet_length_byte + packet_id_bytes
744+
var_header = packet_id_bytes
745+
self.logger.debug(f"Variable Header: {var_header}")
746+
self._sock.send(var_header)
749747
# attaching topic and QOS level to the packet
748+
payload = bytes()
750749
for t, q in topics:
751750
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
752751
qos_byte = q.to_bytes(1, "big")
753-
packet += topic_size + t.encode() + qos_byte
752+
payload += topic_size + t.encode() + qos_byte
754753
for t, q in topics:
755-
self.logger.debug("SUBSCRIBING to topic %s with QoS %d", t, q)
756-
self._sock.send(packet)
754+
self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}")
755+
self.logger.debug(f"payload: {payload}")
756+
self._sock.send(payload)
757757
stamp = self.get_monotonic_time()
758758
while True:
759759
op = self._wait_for_msg()
@@ -764,13 +764,13 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
764764
)
765765
else:
766766
if op == 0x90:
767-
rc = self._sock_exact_recv(3)
768-
# Check packet identifier.
769-
assert rc[1] == packet[2] and rc[2] == packet[3]
770-
remaining_len = rc[0] - 2
767+
remaining_len = self._decode_remaining_length()
771768
assert remaining_len > 0
772-
rc = self._sock_exact_recv(remaining_len)
773-
for i in range(0, remaining_len):
769+
rc = self._sock_exact_recv(2)
770+
# Check packet identifier.
771+
assert rc[0] == var_header[0] and rc[1] == var_header[1]
772+
rc = self._sock_exact_recv(remaining_len - 2)
773+
for i in range(0, remaining_len - 2):
774774
if rc[i] not in [0, 1, 2]:
775775
raise MMQTTException(
776776
f"SUBACK Failure for topic {topics[i][0]}: {hex(rc[i])}"
@@ -780,13 +780,17 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
780780
if self.on_subscribe is not None:
781781
self.on_subscribe(self, self.user_data, t, q)
782782
self._subscribed_topics.append(t)
783+
783784
return
784785

785-
raise MMQTTException(
786-
f"invalid message received as response to SUBSCRIBE: {hex(op)}"
787-
)
786+
if op != MQTT_PUBLISH:
787+
# [3.8.4] The Server is permitted to start sending PUBLISH packets
788+
# matching the Subscription before the Server sends the SUBACK Packet.
789+
raise MMQTTException(
790+
f"invalid message received as response to SUBSCRIBE: {hex(op)}"
791+
)
788792

789-
def unsubscribe(self, topic: str) -> None:
793+
def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
790794
"""Unsubscribes from a MQTT topic.
791795
792796
:param str|list topic: Unique MQTT topic identifier string or list.
@@ -807,18 +811,25 @@ def unsubscribe(self, topic: str) -> None:
807811
"Topic must be subscribed to before attempting unsubscribe."
808812
)
809813
# Assemble packet
814+
self.logger.debug("Sending UNSUBSCRIBE to broker...")
815+
fixed_header = bytearray([MQTT_UNSUB])
810816
packet_length = 2 + (2 * len(topics))
811817
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
812-
packet_length_byte = packet_length.to_bytes(1, "big")
818+
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
819+
self.logger.debug(f"Fixed Header: {fixed_header}")
820+
self._sock.send(fixed_header)
813821
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
814822
packet_id_bytes = self._pid.to_bytes(2, "big")
815-
packet = MQTT_UNSUB + packet_length_byte + packet_id_bytes
823+
var_header = packet_id_bytes
824+
self.logger.debug(f"Variable Header: {var_header}")
825+
self._sock.send(var_header)
826+
payload = bytes()
816827
for t in topics:
817828
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
818-
packet += topic_size + t.encode()
829+
payload += topic_size + t.encode()
819830
for t in topics:
820-
self.logger.debug("UNSUBSCRIBING from topic %s", t)
821-
self._sock.send(packet)
831+
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
832+
self._sock.send(payload)
822833
self.logger.debug("Waiting for UNSUBACK...")
823834
while True:
824835
stamp = self.get_monotonic_time()
@@ -979,7 +990,7 @@ def _wait_for_msg(self) -> Optional[int]:
979990
return pkt_type
980991

981992
# Handle only the PUBLISH packet type from now on.
982-
sz = self._recv_len()
993+
sz = self._decode_remaining_length()
983994
# topic length MSB & LSB
984995
topic_len_buf = self._sock_exact_recv(2)
985996
topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1])
@@ -1012,11 +1023,13 @@ def _wait_for_msg(self) -> Optional[int]:
10121023

10131024
return pkt_type
10141025

1015-
def _recv_len(self) -> int:
1016-
"""Unpack MQTT message length."""
1026+
def _decode_remaining_length(self) -> int:
1027+
"""Decode Remaining Length [2.2.3]"""
10171028
n = 0
10181029
sh = 0
10191030
while True:
1031+
if sh > 28:
1032+
raise MMQTTException("invalid remaining length encoding")
10201033
b = self._sock_exact_recv(1)[0]
10211034
n |= (b & 0x7F) << sh
10221035
if not b & 0x80:

tests/mocket.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
"""fake socket class for protocol level testing"""
6+
7+
from unittest import mock
8+
9+
10+
class Mocket:
11+
"""
12+
Mock Socket tailored for MiniMQTT testing. Records sent data,
13+
hands out pre-recorded reply.
14+
15+
Inspired by the Mocket class from Adafruit_CircuitPython_Requests
16+
"""
17+
18+
def __init__(self, to_send):
19+
self._to_send = to_send
20+
21+
self.sent = bytearray()
22+
23+
self.timeout = mock.Mock()
24+
self.connect = mock.Mock()
25+
self.close = mock.Mock()
26+
27+
def send(self, bytes_to_send):
28+
"""merely record the bytes. return the length of this bytearray."""
29+
self.sent.extend(bytes_to_send)
30+
return len(bytes_to_send)
31+
32+
# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
33+
def recv_into(self, retbuf, bufsize):
34+
"""return data from internal buffer"""
35+
size = min(bufsize, len(self._to_send))
36+
if size == 0:
37+
return size
38+
chop = self._to_send[0:size]
39+
retbuf[0:] = chop
40+
self._to_send = self._to_send[size:]
41+
return size

0 commit comments

Comments
 (0)