Skip to content

Commit 3648e02

Browse files
author
brentru
committed
remove socket reuse from requests, set context within set_interface instead, add a global ssl context
1 parent cafcb05 commit 3648e02

File tree

1 file changed

+39
-93
lines changed

1 file changed

+39
-93
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 39 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@
6161
const(0x05): "Connection Refused - Unauthorized",
6262
}
6363

64-
_the_interface = None # pylint: disable=invalid-name
65-
_the_sock = None # pylint: disable=invalid-name
66-
64+
_default_sock = None # pylint: disable=invalid-name
65+
_fake_context = None # pylint: disable=invalid-name
6766

6867
class MMQTTException(Exception):
6968
"""MiniMQTT Exception class."""
@@ -74,18 +73,17 @@ class MMQTTException(Exception):
7473

7574
# Legacy ESP32SPI Socket API
7675
def set_socket(sock, iface=None):
77-
"""Legacy API for setting the socket and network interface, use a Session instead.
78-
76+
"""Legacy API for setting the socket and network interface.
7977
:param sock: socket object.
8078
:param iface: internet interface object
79+
8180
"""
82-
global _the_sock # pylint: disable=invalid-name, global-statement
83-
_the_sock = sock
81+
global _default_sock # pylint: disable=invalid-name, global-statement
82+
global _fake_context # pylint: disable=invalid-name, global-statement
83+
_default_sock = sock
8484
if iface:
85-
global _the_interface # pylint: disable=invalid-name, global-statement
86-
_the_interface = iface
87-
_the_sock.set_interface(iface)
88-
85+
_default_sock.set_interface(iface)
86+
_fake_context = _FakeSSLContext(iface)
8987

9088
class _FakeSSLSocket:
9189
def __init__(self, socket, tls_mode):
@@ -103,7 +101,6 @@ def connect(self, address):
103101
except RuntimeError as error:
104102
raise OSError(errno.ENOMEM) from error
105103

106-
107104
class _FakeSSLContext:
108105
def __init__(self, iface):
109106
self._iface = iface
@@ -144,18 +141,7 @@ def __init__(
144141
):
145142

146143
self._socket_pool = socket_pool
147-
# Legacy API - if we do not have a socket pool, use default socket
148-
if self._socket_pool is None:
149-
self._socket_pool = _the_sock
150-
151144
self._ssl_context = ssl_context
152-
# Legacy API - if we do not have SSL context, fake it
153-
if self._ssl_context is None:
154-
self._ssl_context = _FakeSSLContext(_the_interface)
155-
156-
# Hang onto open sockets so that we can reuse them
157-
self._socket_free = {}
158-
self._open_sockets = {}
159145
self._sock = None
160146
self._backwards_compatible_sock = False
161147

@@ -214,93 +200,53 @@ def __init__(
214200
self.on_subscribe = None
215201
self.on_unsubscribe = None
216202

217-
# Socket helpers
218-
def _free_socket(self, socket):
219-
"""Frees a socket for re-use."""
220-
if socket not in self._open_sockets.values():
221-
raise RuntimeError("Socket not from MQTT client.")
222-
self._socket_free[socket] = True
223-
224-
def _close_socket(self, socket):
225-
"""Closes a slocket."""
226-
socket.close()
227-
del self._socket_free[socket]
228-
key = None
229-
for k in self._open_sockets:
230-
if self._open_sockets[k] == socket:
231-
key = k
232-
break
233-
if key:
234-
del self._open_sockets[key]
235-
236-
def _free_sockets(self):
237-
"""Closes all free sockets."""
238-
free_sockets = []
239-
for sock in self._socket_free:
240-
if self._socket_free[sock]:
241-
free_sockets.append(sock)
242-
for sock in free_sockets:
243-
self._close_socket(sock)
244203

245204
# pylint: disable=too-many-branches
246205
def _get_socket(self, host, port, *, timeout=1):
247-
key = (host, port)
248-
if key in self._open_sockets:
249-
sock = self._open_sockets[key]
250-
if self._socket_free[sock]:
251-
self._socket_free[sock] = False
252-
return sock
253-
if port == 8883 and not self._ssl_context:
254-
raise RuntimeError(
255-
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
256-
)
206+
# For reconnections - check if we're using a socket already and close it
207+
if self._sock:
208+
self._sock.close()
257209

258210
# Legacy API - use a default socket instead of socket pool
259211
if self._socket_pool is None:
260-
self._socket_pool = _the_sock
212+
self._socket_pool = _default_sock
213+
214+
# Legacy API - fake the ssl context
215+
if self._ssl_context is None:
216+
self._ssl_context = _fake_context
217+
218+
if port == 8883 and self._ssl_context is None:
219+
raise RuntimeError(
220+
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
221+
)
261222

262223
addr_info = self._socket_pool.getaddrinfo(
263224
host, port, 0, self._socket_pool.SOCK_STREAM
264225
)[0]
226+
265227
retry_count = 0
266228
sock = None
267-
while retry_count < 5 and sock is None:
268-
if retry_count > 0:
269-
if any(self._socket_free.items()):
270-
self._free_sockets()
271-
else:
272-
raise RuntimeError("Sending request failed")
273-
retry_count += 1
274229

275-
try:
276-
sock = self._socket_pool.socket(
277-
addr_info[0], addr_info[1], addr_info[2]
278-
)
279-
except OSError:
280-
continue
281-
282-
connect_host = addr_info[-1][0]
283-
if port == 8883:
284-
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
285-
connect_host = host
286-
sock.settimeout(timeout)
230+
sock = self._socket_pool.socket(
231+
addr_info[0], addr_info[1], addr_info[2]
232+
)
287233

288-
try:
289-
sock.connect((connect_host, port))
290-
except MemoryError:
291-
sock.close()
292-
sock = None
293-
except OSError:
294-
sock.close()
295-
sock = None
234+
connect_host = addr_info[-1][0]
235+
if port == 8883:
236+
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
237+
connect_host = host
238+
sock.settimeout(timeout)
296239

297-
if sock is None:
298-
raise RuntimeError("Repeated socket failures")
240+
try:
241+
sock.connect((connect_host, port))
242+
except MemoryError as err:
243+
sock.close()
244+
raise MemoryError(err)
245+
except OSError as err:
246+
sock.close()
247+
raise OSError(err)
299248

300249
self._backwards_compatible_sock = not hasattr(sock, "recv_into")
301-
302-
self._open_sockets[key] = sock
303-
self._socket_free[sock] = False
304250
return sock
305251

306252
def __enter__(self):

0 commit comments

Comments
 (0)