diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 81630e4a..4d6e5080 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -32,6 +32,16 @@ import time from random import randint +try: + from typing import List, Optional, Tuple, Type, Union +except ImportError: + pass + +try: + from types import TracebackType +except ImportError: + pass + from micropython import const from .matcher import MQTTMatcher @@ -84,7 +94,7 @@ class TemporaryError(Exception): # Legacy ESP32SPI Socket API -def set_socket(sock, iface=None): +def set_socket(sock, iface=None) -> None: """Legacy API for setting the socket and network interface. :param sock: socket object. @@ -100,7 +110,7 @@ def set_socket(sock, iface=None): class _FakeSSLSocket: - def __init__(self, socket, tls_mode): + def __init__(self, socket, tls_mode) -> None: self._socket = socket self._mode = tls_mode self.settimeout = socket.settimeout @@ -117,10 +127,10 @@ def connect(self, address): class _FakeSSLContext: - def __init__(self, iface): + def __init__(self, iface) -> None: self._iface = iface - def wrap_socket(self, socket, server_hostname=None): + def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket: """Return the same socket""" # pylint: disable=unused-argument return _FakeSSLSocket(socket, self._iface.TLS_MODE) @@ -134,7 +144,7 @@ def nothing(self, msg: str, *args) -> None: """no action""" pass - def __init__(self): + def __init__(self) -> None: for log_level in ["debug", "info", "warning", "error", "critical"]: setattr(NullLogger, log_level, self.nothing) @@ -166,21 +176,21 @@ class MQTT: def __init__( self, *, - broker, - port=None, - username=None, - password=None, - client_id=None, - is_ssl=None, - keep_alive=60, - recv_timeout=10, + broker: str, + port: Optional[int] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_id: Optional[str] = None, + is_ssl: Optional[bool] = None, + keep_alive: int = 60, + recv_timeout: int = 10, socket_pool=None, ssl_context=None, - use_binary_mode=False, - socket_timeout=1, - connect_retries=5, + use_binary_mode: bool = False, + socket_timeout: int = 1, + connect_retries: int = 5, user_data=None, - ): + ) -> None: self._socket_pool = socket_pool self._ssl_context = ssl_context @@ -200,7 +210,7 @@ def __init__( self._is_connected = False self._msg_size_lim = MQTT_MSG_SZ_LIM self._pid = 0 - self._timestamp = 0 + self._timestamp: float = 0 self.logger = NullLogger() """An optional logging attribute that can be set with with a Logger to enable debug logging.""" @@ -253,7 +263,7 @@ def __init__( self._lw_retain = False # List of subscribed topics, used for tracking - self._subscribed_topics = [] + self._subscribed_topics: List[str] = [] self._on_message_filtered = MQTTMatcher() # Default topic callback methods @@ -265,7 +275,7 @@ def __init__( self.on_unsubscribe = None # pylint: disable=too-many-branches - def _get_connect_socket(self, host, port, *, timeout=1): + def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): """Obtains a new socket and connects to a broker. :param str host: Desired broker hostname @@ -338,20 +348,25 @@ def _get_connect_socket(self, host, port, *, timeout=1): def __enter__(self): return self - def __exit__(self, exception_type, exception_value, traceback): + def __exit__( + self, + exception_type: Optional[Type[BaseException]], + exception_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: self.deinit() - def deinit(self): + def deinit(self) -> None: """De-initializes the MQTT client and disconnects from the mqtt broker.""" self.disconnect() @property - def mqtt_msg(self): + def mqtt_msg(self) -> Tuple[int, int]: """Returns maximum MQTT payload and topic size.""" return self._msg_size_lim, MQTT_TOPIC_LENGTH_LIMIT @mqtt_msg.setter - def mqtt_msg(self, msg_size): + def mqtt_msg(self, msg_size: int) -> None: """Sets the maximum MQTT message payload size. :param int msg_size: Maximum MQTT payload size. @@ -359,7 +374,13 @@ def mqtt_msg(self, msg_size): if msg_size < MQTT_MSG_MAX_SZ: self._msg_size_lim = msg_size - def will_set(self, topic=None, payload=None, qos=0, retain=False): + def will_set( + self, + topic: Optional[str] = None, + payload: Optional[Union[int, float, str]] = None, + qos: int = 0, + retain: bool = False, + ) -> None: """Sets the last will and testament properties. MUST be called before `connect()`. :param str topic: MQTT Broker topic. @@ -388,7 +409,7 @@ def will_set(self, topic=None, payload=None, qos=0, retain=False): self._lw_msg = payload self._lw_retain = retain - def add_topic_callback(self, mqtt_topic, callback_method): + def add_topic_callback(self, mqtt_topic: str, callback_method) -> None: """Registers a callback_method for a specific MQTT topic. :param str mqtt_topic: MQTT topic identifier. @@ -398,7 +419,7 @@ def add_topic_callback(self, mqtt_topic, callback_method): raise ValueError("MQTT topic and callback method must both be defined.") self._on_message_filtered[mqtt_topic] = callback_method - def remove_topic_callback(self, mqtt_topic): + def remove_topic_callback(self, mqtt_topic: str) -> None: """Removes a registered callback method. :param str mqtt_topic: MQTT topic identifier string. @@ -421,10 +442,10 @@ def on_message(self): return self._on_message @on_message.setter - def on_message(self, method): + def on_message(self, method) -> None: self._on_message = method - def _handle_on_message(self, client, topic, message): + def _handle_on_message(self, client, topic: str, message: str): matched = False if topic is not None: for callback in self._on_message_filtered.iter_match(topic): @@ -434,7 +455,7 @@ def _handle_on_message(self, client, topic, message): if not matched and self.on_message: # regular on_message self.on_message(client, topic, message) - def username_pw_set(self, username, password=None): + def username_pw_set(self, username: str, password: Optional[str] = None) -> None: """Set client's username and an optional password. :param str username: Username to use with your MQTT broker. @@ -447,7 +468,13 @@ def username_pw_set(self, username, password=None): if password is not None: self._password = password - def connect(self, clean_session=True, host=None, port=None, keep_alive=None): + def connect( + self, + clean_session: bool = True, + host: Optional[str] = None, + port: Optional[int] = None, + keep_alive: Optional[int] = None, + ) -> int: """Initiates connection with the MQTT Broker. Will perform exponential back-off on connect failures. @@ -503,7 +530,13 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None): raise MMQTTException(exc_msg) # pylint: disable=too-many-branches, too-many-statements, too-many-locals - def _connect(self, clean_session=True, host=None, port=None, keep_alive=None): + def _connect( + self, + clean_session: bool = True, + host: Optional[str] = None, + port: Optional[int] = None, + keep_alive: Optional[int] = None, + ) -> int: """Initiates connection with the MQTT Broker. :param bool clean_session: Establishes a persistent session. @@ -616,7 +649,7 @@ def _connect(self, clean_session=True, host=None, port=None, keep_alive=None): f"No data received from broker for {self._recv_timeout} seconds." ) - def disconnect(self): + def disconnect(self) -> None: """Disconnects the MiniMQTT client from the MQTT broker.""" self._connected() self.logger.debug("Sending DISCONNECT packet to broker") @@ -631,7 +664,7 @@ def disconnect(self): if self.on_disconnect is not None: self.on_disconnect(self, self._user_data, 0) - def ping(self): + def ping(self) -> list[int]: """Pings the MQTT Broker to confirm if the broker is alive or if there is an active network connection. Returns response codes of any messages received while waiting for PINGRESP. @@ -651,7 +684,13 @@ def ping(self): return rcs # pylint: disable=too-many-branches, too-many-statements - def publish(self, topic, msg, retain=False, qos=0): + def publish( + self, + topic: str, + msg: Union[str, int, float, bytes], + retain: bool = False, + qos: int = 0, + ) -> None: """Publishes a message to a topic provided. :param str topic: Unique topic identifier. @@ -727,8 +766,8 @@ def publish(self, topic, msg, retain=False, qos=0): if op == 0x40: sz = self._sock_exact_recv(1) assert sz == b"\x02" - rcv_pid = self._sock_exact_recv(2) - rcv_pid = rcv_pid[0] << 0x08 | rcv_pid[1] + rcv_pid_buf = self._sock_exact_recv(2) + rcv_pid = rcv_pid_buf[0] << 0x08 | rcv_pid_buf[1] if self._pid == rcv_pid: if self.on_publish is not None: self.on_publish(self, self._user_data, topic, rcv_pid) @@ -740,7 +779,7 @@ def publish(self, topic, msg, retain=False, qos=0): f"No data received from broker for {self._recv_timeout} seconds." ) - def subscribe(self, topic, qos=0): + def subscribe(self, topic: str, qos: int = 0) -> None: """Subscribes to a topic on the MQTT Broker. This method can subscribe to one topics or multiple topics. @@ -807,7 +846,7 @@ def subscribe(self, topic, qos=0): f"No data received from broker for {self._recv_timeout} seconds." ) - def unsubscribe(self, topic): + def unsubscribe(self, topic: str) -> None: """Unsubscribes from a MQTT topic. :param str|list topic: Unique MQTT topic identifier string or list. @@ -861,7 +900,7 @@ def unsubscribe(self, topic): f"No data received from broker for {self._recv_timeout} seconds." ) - def _recompute_reconnect_backoff(self): + def _recompute_reconnect_backoff(self) -> None: """ Recompute the reconnection timeout. The self._reconnect_timeout will be used in self._connect() to perform the actual sleep. @@ -891,7 +930,7 @@ def _recompute_reconnect_backoff(self): ) self._reconnect_timeout += jitter - def _reset_reconnect_backoff(self): + def _reset_reconnect_backoff(self) -> None: """ Reset reconnect back-off to the initial state. @@ -900,7 +939,7 @@ def _reset_reconnect_backoff(self): self._reconnect_attempt = 0 self._reconnect_timeout = float(0) - def reconnect(self, resub_topics=True): + def reconnect(self, resub_topics: bool = True) -> int: """Attempts to reconnect to the MQTT broker. Return the value from connect() if successful. Will disconnect first if already connected. Will perform exponential back-off on connect failures. @@ -924,13 +963,13 @@ def reconnect(self, resub_topics=True): return ret - def loop(self, timeout=0): + def loop(self, timeout: float = 0) -> Optional[list[int]]: # pylint: disable = too-many-return-statements """Non-blocking message loop. Use this method to check incoming subscription messages. Returns response codes of any messages received. - :param int timeout: Socket timeout, in seconds. + :param float timeout: Socket timeout, in seconds. """ @@ -964,7 +1003,7 @@ def loop(self, timeout=0): return rcs if rcs else None - def _wait_for_msg(self, timeout=0.1): + def _wait_for_msg(self, timeout: float = 0.1) -> Optional[int]: # pylint: disable = too-many-return-statements """Reads and processes network events. @@ -1003,21 +1042,21 @@ def _wait_for_msg(self, timeout=0.1): # Handle only the PUBLISH packet type from now on. sz = self._recv_len() # topic length MSB & LSB - topic_len = self._sock_exact_recv(2) - topic_len = (topic_len[0] << 8) | topic_len[1] + topic_len_buf = self._sock_exact_recv(2) + topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1]) if topic_len > sz - 2: raise MMQTTException( f"Topic length {topic_len} in PUBLISH packet exceeds remaining length {sz} - 2" ) - topic = self._sock_exact_recv(topic_len) - topic = str(topic, "utf-8") + topic_buf = self._sock_exact_recv(topic_len) + topic = str(topic_buf, "utf-8") sz -= topic_len + 2 pid = 0 if res[0] & 0x06: - pid = self._sock_exact_recv(2) - pid = pid[0] << 0x08 | pid[1] + pid_buf = self._sock_exact_recv(2) + pid = pid_buf[0] << 0x08 | pid_buf[1] sz -= 0x02 # read message contents @@ -1034,11 +1073,10 @@ def _wait_for_msg(self, timeout=0.1): return res[0] - def _recv_len(self): + def _recv_len(self) -> int: """Unpack MQTT message length.""" n = 0 sh = 0 - b = bytearray(1) while True: b = self._sock_exact_recv(1)[0] n |= (b & 0x7F) << sh @@ -1046,9 +1084,9 @@ def _recv_len(self): return n sh += 7 - def _sock_exact_recv(self, bufsize): + def _sock_exact_recv(self, bufsize: int) -> bytearray: """Reads _exact_ number of bytes from the connected socket. Will only return - string with the exact number of bytes requested. + bytearray with the exact number of bytes requested. The semantics of native socket receive is that it returns no more than the specified number of bytes (i.e. max size). However, it makes no guarantees in @@ -1100,7 +1138,7 @@ def _sock_exact_recv(self, bufsize): ) return rc - def _send_str(self, string): + def _send_str(self, string: str) -> None: """Encodes a string and sends it to a socket. :param str string: String to write to the socket. @@ -1114,7 +1152,7 @@ def _send_str(self, string): self._sock.send(string) @staticmethod - def _valid_topic(topic): + def _valid_topic(topic: str) -> None: """Validates if topic provided is proper MQTT topic format. :param str topic: Topic identifier @@ -1130,7 +1168,7 @@ def _valid_topic(topic): raise MMQTTException("Topic length is too large.") @staticmethod - def _valid_qos(qos_level): + def _valid_qos(qos_level: int) -> None: """Validates if the QoS level is supported by this library :param int qos_level: Desired QoS level. @@ -1142,21 +1180,21 @@ def _valid_qos(qos_level): else: raise MMQTTException("QoS must be an integer.") - def _connected(self): + def _connected(self) -> None: """Returns MQTT client session status as True if connected, raises a `MMQTTException` if `False`. """ if not self.is_connected(): raise MMQTTException("MiniMQTT is not connected") - def is_connected(self): + def is_connected(self) -> bool: """Returns MQTT client session status as True if connected, False if not. """ return self._is_connected and self._sock is not None # Logging - def enable_logger(self, log_pkg, log_level=20, logger_name="log"): + def enable_logger(self, log_pkg, log_level: int = 20, logger_name: str = "log"): """Enables library logging by getting logger from the specified logging package and setting its log level. @@ -1173,6 +1211,6 @@ def enable_logger(self, log_pkg, log_level=20, logger_name="log"): return self.logger - def disable_logger(self): + def disable_logger(self) -> None: """Disables logging.""" self.logger = NullLogger()