Skip to content

Commit ffe08db

Browse files
authored
Merge pull request #151 from vladak/exp_backoff_pr
exponential backoff for (re)connect
2 parents 52c52c3 + 7cab7d8 commit ffe08db

File tree

1 file changed

+160
-37
lines changed

1 file changed

+160
-37
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 160 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ class MMQTTException(Exception):
7777
# pass
7878

7979

80+
class TemporaryError(Exception):
81+
"""Temporary error class used for handling reconnects."""
82+
83+
8084
# Legacy ESP32SPI Socket API
8185
def set_socket(sock, iface=None):
8286
"""Legacy API for setting the socket and network interface.
@@ -137,12 +141,13 @@ class MQTT:
137141
:param bool use_binary_mode: Messages are passed as bytearray instead of string to callbacks.
138142
:param int socket_timeout: How often to check socket state for read/write/connect operations,
139143
in seconds.
140-
:param int connect_retries: How many times to try to connect to broker before giving up.
144+
:param int connect_retries: How many times to try to connect to the broker before giving up
145+
on connect or reconnect. Exponential backoff will be used for the retries.
141146
:param class user_data: arbitrary data to pass as a second argument to the callbacks.
142147
143148
"""
144149

145-
# pylint: disable=too-many-arguments,too-many-instance-attributes, not-callable, invalid-name, no-member
150+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-statements, not-callable, invalid-name, no-member
146151
def __init__(
147152
self,
148153
*,
@@ -174,7 +179,6 @@ def __init__(
174179
)
175180
self._socket_timeout = socket_timeout
176181
self._recv_timeout = recv_timeout
177-
self._connect_retries = connect_retries
178182

179183
self.keep_alive = keep_alive
180184
self._user_data = user_data
@@ -184,6 +188,13 @@ def __init__(
184188
self._timestamp = 0
185189
self.logger = None
186190

191+
self._reconnect_attempt = 0
192+
self._reconnect_timeout = float(0)
193+
self._reconnect_maximum_backoff = 32
194+
if connect_retries <= 0:
195+
raise MMQTTException("connect_retries must be positive")
196+
self._reconnect_attempts_max = connect_retries
197+
187198
self.broker = broker
188199
self._username = username
189200
self._password = password
@@ -268,39 +279,37 @@ def _get_connect_socket(self, host, port, *, timeout=1):
268279
host, port, 0, self._socket_pool.SOCK_STREAM
269280
)[0]
270281

271-
sock = None
272-
retry_count = 0
273-
last_exception = None
274-
while retry_count < self._connect_retries and sock is None:
275-
retry_count += 1
282+
try:
283+
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
284+
except OSError as exc:
285+
# Do not consider this for back-off.
286+
if self.logger is not None:
287+
self.logger.warning(
288+
f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}"
289+
)
290+
raise TemporaryError from exc
276291

277-
try:
278-
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
279-
except OSError:
280-
continue
292+
connect_host = addr_info[-1][0]
293+
if port == MQTT_TLS_PORT:
294+
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
295+
connect_host = host
296+
sock.settimeout(timeout)
281297

282-
connect_host = addr_info[-1][0]
283-
if port == MQTT_TLS_PORT:
284-
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
285-
connect_host = host
286-
sock.settimeout(timeout)
298+
last_exception = None
299+
try:
300+
sock.connect((connect_host, port))
301+
except MemoryError as exc:
302+
sock.close()
303+
if self.logger is not None:
304+
self.logger.warning(f"Failed to allocate memory for connect: {exc}")
305+
# Do not consider this for back-off.
306+
raise TemporaryError from exc
307+
except OSError as exc:
308+
sock.close()
309+
last_exception = exc
287310

288-
try:
289-
sock.connect((connect_host, port))
290-
except MemoryError as exc:
291-
sock.close()
292-
sock = None
293-
last_exception = exc
294-
except OSError as exc:
295-
sock.close()
296-
sock = None
297-
last_exception = exc
298-
299-
if sock is None:
300-
if last_exception:
301-
raise RuntimeError("Repeated socket failures") from last_exception
302-
303-
raise RuntimeError("Repeated socket failures")
311+
if last_exception:
312+
raise last_exception
304313

305314
self._backwards_compatible_sock = not hasattr(sock, "recv_into")
306315
return sock
@@ -418,8 +427,66 @@ def username_pw_set(self, username, password=None):
418427
if password is not None:
419428
self._password = password
420429

421-
# pylint: disable=too-many-branches, too-many-statements, too-many-locals
422430
def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
431+
"""Initiates connection with the MQTT Broker. Will perform exponential back-off
432+
on connect failures.
433+
434+
:param bool clean_session: Establishes a persistent session.
435+
:param str host: Hostname or IP address of the remote broker.
436+
:param int port: Network port of the remote broker.
437+
:param int keep_alive: Maximum period allowed for communication
438+
within single connection attempt, in seconds.
439+
440+
"""
441+
442+
last_exception = None
443+
backoff = False
444+
for i in range(0, self._reconnect_attempts_max):
445+
if i > 0:
446+
if backoff:
447+
self._recompute_reconnect_backoff()
448+
else:
449+
self._reset_reconnect_backoff()
450+
if self.logger is not None:
451+
self.logger.debug(
452+
f"Attempting to connect to MQTT broker (attempt #{self._reconnect_attempt})"
453+
)
454+
455+
try:
456+
ret = self._connect(
457+
clean_session=clean_session,
458+
host=host,
459+
port=port,
460+
keep_alive=keep_alive,
461+
)
462+
self._reset_reconnect_backoff()
463+
return ret
464+
except TemporaryError as e:
465+
if self.logger is not None:
466+
self.logger.warning(f"temporary error when connecting: {e}")
467+
backoff = False
468+
except OSError as e:
469+
last_exception = e
470+
if self.logger is not None:
471+
self.logger.info(f"failed to connect: {e}")
472+
backoff = True
473+
except MMQTTException as e:
474+
last_exception = e
475+
if self.logger is not None:
476+
self.logger.info(f"MMQT error: {e}")
477+
backoff = True
478+
479+
if self._reconnect_attempts_max > 1:
480+
exc_msg = "Repeated connect failures"
481+
else:
482+
exc_msg = "Connect failure"
483+
if last_exception:
484+
raise MMQTTException(exc_msg) from last_exception
485+
486+
raise MMQTTException(exc_msg)
487+
488+
# pylint: disable=too-many-branches, too-many-statements, too-many-locals
489+
def _connect(self, clean_session=True, host=None, port=None, keep_alive=None):
423490
"""Initiates connection with the MQTT Broker.
424491
425492
:param bool clean_session: Establishes a persistent session.
@@ -438,6 +505,12 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
438505
if self.logger is not None:
439506
self.logger.debug("Attempting to establish MQTT connection...")
440507

508+
if self._reconnect_attempt > 0:
509+
self.logger.debug(
510+
f"Sleeping for {self._reconnect_timeout:.3} seconds due to connect back-off"
511+
)
512+
time.sleep(self._reconnect_timeout)
513+
441514
# Get a new socket
442515
self._sock = self._get_connect_socket(
443516
self.broker, self.port, timeout=self._socket_timeout
@@ -492,7 +565,7 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
492565
fixed_header.append(0x00)
493566

494567
if self.logger is not None:
495-
self.logger.debug("Sending CONNECT to broker...")
568+
self.logger.debug("Sending CONNECT packet to broker...")
496569
self.logger.debug(
497570
"Fixed Header: %s\nVariable Header: %s", fixed_header, var_header
498571
)
@@ -521,6 +594,7 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
521594
result = rc[0] & 1
522595
if self.on_connect is not None:
523596
self.on_connect(self, self._user_data, result, rc[2])
597+
524598
return result
525599

526600
if op is None:
@@ -782,15 +856,62 @@ def unsubscribe(self, topic):
782856
f"No data received from broker for {self._recv_timeout} seconds."
783857
)
784858

859+
def _recompute_reconnect_backoff(self):
860+
"""
861+
Recompute the reconnection timeout. The self._reconnect_timeout will be used
862+
in self._connect() to perform the actual sleep.
863+
864+
"""
865+
self._reconnect_attempt = self._reconnect_attempt + 1
866+
self._reconnect_timeout = 2**self._reconnect_attempt
867+
if self.logger is not None:
868+
# pylint: disable=consider-using-f-string
869+
self.logger.debug(
870+
"Reconnect timeout computed to {:.2f}".format(self._reconnect_timeout)
871+
)
872+
873+
if self._reconnect_timeout > self._reconnect_maximum_backoff:
874+
if self.logger is not None:
875+
self.logger.debug(
876+
f"Truncating reconnect timeout to {self._reconnect_maximum_backoff} seconds"
877+
)
878+
self._reconnect_timeout = float(self._reconnect_maximum_backoff)
879+
880+
# Add a sub-second jitter.
881+
# Even truncated timeout should have jitter added to it. This is why it is added here.
882+
jitter = randint(0, 1000) / 1000
883+
if self.logger is not None:
884+
# pylint: disable=consider-using-f-string
885+
self.logger.debug(
886+
"adding jitter {:.2f} to {:.2f} seconds".format(
887+
jitter, self._reconnect_timeout
888+
)
889+
)
890+
self._reconnect_timeout += jitter
891+
892+
def _reset_reconnect_backoff(self):
893+
"""
894+
Reset reconnect back-off to the initial state.
895+
896+
"""
897+
if self.logger is not None:
898+
self.logger.debug("Resetting reconnect backoff")
899+
self._reconnect_attempt = 0
900+
self._reconnect_timeout = float(0)
901+
785902
def reconnect(self, resub_topics=True):
786903
"""Attempts to reconnect to the MQTT broker.
904+
Return the value from connect() if successful. Will disconnect first if already connected.
905+
Will perform exponential back-off on connect failures.
787906
788-
:param bool resub_topics: Resubscribe to previously subscribed topics.
907+
:param bool resub_topics: Whether to resubscribe to previously subscribed topics.
789908
790909
"""
910+
791911
if self.logger is not None:
792912
self.logger.debug("Attempting to reconnect with MQTT broker")
793-
self.connect()
913+
914+
ret = self.connect()
794915
if self.logger is not None:
795916
self.logger.debug("Reconnected with broker")
796917
if resub_topics:
@@ -804,6 +925,8 @@ def reconnect(self, resub_topics=True):
804925
feed = subscribed_topics.pop()
805926
self.subscribe(feed)
806927

928+
return ret
929+
807930
def loop(self, timeout=0):
808931
# pylint: disable = too-many-return-statements
809932
"""Non-blocking message loop. Use this method to

0 commit comments

Comments
 (0)