Skip to content

Commit 16b6c6d

Browse files
committed
Make sure socket is closed on exception
1 parent 906e676 commit 16b6c6d

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def connect(
438438
self.logger.warning(f"Socket error when connecting: {e}")
439439
backoff = False
440440
except MMQTTException as e:
441+
self._close_socket()
441442
self.logger.info(f"MMQT error: {e}")
442443
if e.code in [
443444
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD,
@@ -452,9 +453,9 @@ def connect(
452453
exc_msg = "Repeated connect failures"
453454
else:
454455
exc_msg = "Connect failure"
456+
455457
if last_exception:
456458
raise MMQTTException(exc_msg) from last_exception
457-
458459
raise MMQTTException(exc_msg)
459460

460461
# pylint: disable=too-many-branches, too-many-statements, too-many-locals
@@ -565,6 +566,12 @@ def _connect(
565566
f"No data received from broker for {self._recv_timeout} seconds."
566567
)
567568

569+
def _close_socket(self):
570+
if self._sock:
571+
self.logger.debug("Closing socket")
572+
self._connection_manager.close_socket(self._sock)
573+
self._sock = None
574+
568575
# pylint: disable=no-self-use
569576
def _encode_remaining_length(
570577
self, fixed_header: bytearray, remaining_length: int
@@ -593,8 +600,7 @@ def disconnect(self) -> None:
593600
self._sock.send(MQTT_DISCONNECT)
594601
except RuntimeError as e:
595602
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
596-
self.logger.debug("Closing socket")
597-
self._connection_manager.close_socket(self._sock)
603+
self._close_socket()
598604
self._is_connected = False
599605
self._subscribed_topics = []
600606
self._last_msg_sent_timestamp = 0

tests/test_backoff.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def test_failing_connect(self) -> None:
5151
print("connecting")
5252
with pytest.raises(MQTT.MMQTTException) as context:
5353
mqtt_client.connect()
54+
assert mqtt_client._sock is None
5455
assert "Repeated connect failures" in str(context)
5556

5657
mock_method.assert_called()
@@ -86,6 +87,7 @@ def test_unauthorized(self) -> None:
8687
print("connecting")
8788
with pytest.raises(MQTT.MMQTTException) as context:
8889
mqtt_client.connect()
90+
assert mqtt_client._sock is None
8991
assert "Connection Refused - Unauthorized" in str(context)
9092

9193
mock_method.assert_called()

0 commit comments

Comments
 (0)