35
35
36
36
37
37
if not sys .implementation .name == "circuitpython" :
38
- from typing import Optional , Tuple
38
+ from typing import List , Optional , Tuple
39
39
40
40
from circuitpython_typing .socket import (
41
41
CircuitPythonSocketType ,
@@ -68,15 +68,14 @@ def connect(self, address: Tuple[str, int]) -> None:
68
68
try :
69
69
return self ._socket .connect (address , self ._mode )
70
70
except RuntimeError as error :
71
- raise OSError (errno .ENOMEM ) from error
71
+ raise OSError (errno .ENOMEM , str ( error ) ) from error
72
72
73
73
74
74
class _FakeSSLContext :
75
75
def __init__ (self , iface : InterfaceType ) -> None :
76
76
self ._iface = iface
77
77
78
- # pylint: disable=unused-argument
79
- def wrap_socket (
78
+ def wrap_socket ( # pylint: disable=unused-argument
80
79
self , socket : CircuitPythonSocketType , server_hostname : Optional [str ] = None
81
80
) -> _FakeSSLSocket :
82
81
"""Return the same socket"""
@@ -106,7 +105,8 @@ def create_fake_ssl_context(
106
105
return _FakeSSLContext (iface )
107
106
108
107
109
- _global_socketpool = {}
108
+ _global_connection_managers = {}
109
+ _global_socketpools = {}
110
110
_global_ssl_contexts = {}
111
111
112
112
@@ -127,7 +127,7 @@ def get_radio_socketpool(radio):
127
127
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
128
128
"""
129
129
key = _get_radio_hash_key (radio )
130
- if key not in _global_socketpool :
130
+ if key not in _global_socketpools :
131
131
class_name = radio .__class__ .__name__
132
132
if class_name == "Radio" :
133
133
import ssl # pylint: disable=import-outside-toplevel
@@ -168,10 +168,10 @@ def get_radio_socketpool(radio):
168
168
else :
169
169
raise AttributeError (f"Unsupported radio class: { class_name } " )
170
170
171
- _global_socketpool [key ] = pool
171
+ _global_socketpools [key ] = pool
172
172
_global_ssl_contexts [key ] = ssl_context
173
173
174
- return _global_socketpool [key ]
174
+ return _global_socketpools [key ]
175
175
176
176
177
177
def get_radio_ssl_context (radio ):
@@ -199,42 +199,75 @@ def __init__(
199
199
) -> None :
200
200
self ._socket_pool = socket_pool
201
201
# Hang onto open sockets so that we can reuse them.
202
- self ._available_socket = {}
203
- self ._open_sockets = {}
204
-
205
- def _free_sockets (self ) -> None :
206
- available_sockets = []
207
- for socket , free in self ._available_socket .items ():
208
- if free :
209
- available_sockets .append (socket )
202
+ self ._available_sockets = set ()
203
+ self ._key_by_managed_socket = {}
204
+ self ._managed_socket_by_key = {}
210
205
206
+ def _free_sockets (self , force : bool = False ) -> None :
207
+ # cloning lists since items are being removed
208
+ available_sockets = list (self ._available_sockets )
211
209
for socket in available_sockets :
212
210
self .close_socket (socket )
211
+ if force :
212
+ open_sockets = list (self ._managed_socket_by_key .values ())
213
+ for socket in open_sockets :
214
+ self .close_socket (socket )
213
215
214
- def _get_key_for_socket (self , socket ):
216
+ def _get_connected_socket ( # pylint: disable=too-many-arguments
217
+ self ,
218
+ addr_info : List [Tuple [int , int , int , str , Tuple [str , int ]]],
219
+ host : str ,
220
+ port : int ,
221
+ timeout : float ,
222
+ is_ssl : bool ,
223
+ ssl_context : Optional [SSLContextType ] = None ,
224
+ ):
215
225
try :
216
- return next (
217
- key for key , value in self ._open_sockets .items () if value == socket
218
- )
219
- except StopIteration :
220
- return None
226
+ socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
227
+ except (OSError , RuntimeError ) as exc :
228
+ return exc
229
+
230
+ if is_ssl :
231
+ socket = ssl_context .wrap_socket (socket , server_hostname = host )
232
+ connect_host = host
233
+ else :
234
+ connect_host = addr_info [- 1 ][0 ]
235
+ socket .settimeout (timeout ) # socket read timeout
236
+
237
+ try :
238
+ socket .connect ((connect_host , port ))
239
+ except (MemoryError , OSError ) as exc :
240
+ socket .close ()
241
+ return exc
242
+
243
+ return socket
244
+
245
+ @property
246
+ def available_socket_count (self ) -> int :
247
+ """Get the count of freeable open sockets"""
248
+ return len (self ._available_sockets )
249
+
250
+ @property
251
+ def managed_socket_count (self ) -> int :
252
+ """Get the count of open sockets"""
253
+ return len (self ._managed_socket_by_key )
221
254
222
255
def close_socket (self , socket : SocketType ) -> None :
223
256
"""Close a previously opened socket."""
224
- if socket not in self ._open_sockets .values ():
257
+ if socket not in self ._managed_socket_by_key .values ():
225
258
raise RuntimeError ("Socket not managed" )
226
- key = self ._get_key_for_socket (socket )
227
259
socket .close ()
228
- del self ._available_socket [socket ]
229
- del self ._open_sockets [key ]
260
+ key = self ._key_by_managed_socket .pop (socket )
261
+ del self ._managed_socket_by_key [key ]
262
+ if socket in self ._available_sockets :
263
+ self ._available_sockets .remove (socket )
230
264
231
265
def free_socket (self , socket : SocketType ) -> None :
232
266
"""Mark a previously opened socket as available so it can be reused if needed."""
233
- if socket not in self ._open_sockets .values ():
267
+ if socket not in self ._managed_socket_by_key .values ():
234
268
raise RuntimeError ("Socket not managed" )
235
- self ._available_socket [ socket ] = True
269
+ self ._available_sockets . add ( socket )
236
270
237
- # pylint: disable=too-many-branches,too-many-locals,too-many-statements
238
271
def get_socket (
239
272
self ,
240
273
host : str ,
@@ -250,10 +283,10 @@ def get_socket(
250
283
if session_id :
251
284
session_id = str (session_id )
252
285
key = (host , port , proto , session_id )
253
- if key in self ._open_sockets :
254
- socket = self ._open_sockets [key ]
255
- if self ._available_socket [ socket ] :
256
- self ._available_socket [ socket ] = False
286
+ if key in self ._managed_socket_by_key :
287
+ socket = self ._managed_socket_by_key [key ]
288
+ if socket in self ._available_sockets :
289
+ self ._available_sockets . remove ( socket )
257
290
return socket
258
291
259
292
raise RuntimeError (f"Socket already connected to { proto } //{ host } :{ port } " )
@@ -269,64 +302,68 @@ def get_socket(
269
302
host , port , 0 , self ._socket_pool .SOCK_STREAM
270
303
)[0 ]
271
304
272
- try_count = 0
273
- socket = None
274
- last_exc = None
275
- while try_count < 2 and socket is None :
276
- try_count += 1
277
- if try_count > 1 :
278
- if any (
279
- socket
280
- for socket , free in self ._available_socket .items ()
281
- if free is True
282
- ):
283
- self ._free_sockets ()
284
- else :
285
- break
286
-
287
- try :
288
- socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
289
- except OSError as exc :
290
- last_exc = exc
291
- continue
292
- except RuntimeError as exc :
293
- last_exc = exc
294
- continue
295
-
296
- if is_ssl :
297
- socket = ssl_context .wrap_socket (socket , server_hostname = host )
298
- connect_host = host
299
- else :
300
- connect_host = addr_info [- 1 ][0 ]
301
- socket .settimeout (timeout ) # socket read timeout
302
-
303
- try :
304
- socket .connect ((connect_host , port ))
305
- except MemoryError as exc :
306
- last_exc = exc
307
- socket .close ()
308
- socket = None
309
- except OSError as exc :
310
- last_exc = exc
311
- socket .close ()
312
- socket = None
313
-
314
- if socket is None :
315
- raise RuntimeError (f"Error connecting socket: { last_exc } " ) from last_exc
316
-
317
- self ._available_socket [socket ] = False
318
- self ._open_sockets [key ] = socket
319
- return socket
305
+ first_exception = None
306
+ result = self ._get_connected_socket (
307
+ addr_info , host , port , timeout , is_ssl , ssl_context
308
+ )
309
+ if isinstance (result , Exception ):
310
+ # Got an error, if there are any available sockets, free them and try again
311
+ if self .available_socket_count :
312
+ first_exception = result
313
+ self ._free_sockets ()
314
+ result = self ._get_connected_socket (
315
+ addr_info , host , port , timeout , is_ssl , ssl_context
316
+ )
317
+ if isinstance (result , Exception ):
318
+ last_result = f", first error: { first_exception } " if first_exception else ""
319
+ raise RuntimeError (
320
+ f"Error connecting socket: { result } { last_result } "
321
+ ) from result
322
+
323
+ self ._key_by_managed_socket [result ] = key
324
+ self ._managed_socket_by_key [key ] = result
325
+ return result
320
326
321
327
322
328
# global helpers
323
329
324
330
325
- _global_connection_manager = {}
331
+ def connection_manager_close_all (
332
+ socket_pool : Optional [SocketpoolModuleType ] = None , release_references : bool = False
333
+ ) -> None :
334
+ """Close all open sockets for pool"""
335
+ if socket_pool :
336
+ socket_pools = [socket_pool ]
337
+ else :
338
+ socket_pools = _global_connection_managers .keys ()
339
+
340
+ for pool in socket_pools :
341
+ connection_manager = _global_connection_managers .get (pool , None )
342
+ if connection_manager is None :
343
+ raise RuntimeError ("SocketPool not managed" )
344
+
345
+ connection_manager ._free_sockets (force = True ) # pylint: disable=protected-access
346
+
347
+ if release_references :
348
+ radio_key = None
349
+ for radio_check , pool_check in _global_socketpools .items ():
350
+ if pool == pool_check :
351
+ radio_key = radio_check
352
+ break
353
+
354
+ if radio_key :
355
+ if radio_key in _global_socketpools :
356
+ del _global_socketpools [radio_key ]
357
+
358
+ if radio_key in _global_ssl_contexts :
359
+ del _global_ssl_contexts [radio_key ]
360
+
361
+ if pool in _global_connection_managers :
362
+ del _global_connection_managers [pool ]
326
363
327
364
328
365
def get_connection_manager (socket_pool : SocketpoolModuleType ) -> ConnectionManager :
329
366
"""Get the ConnectionManager singleton for the given pool"""
330
- if socket_pool not in _global_connection_manager :
331
- _global_connection_manager [socket_pool ] = ConnectionManager (socket_pool )
332
- return _global_connection_manager [socket_pool ]
367
+ if socket_pool not in _global_connection_managers :
368
+ _global_connection_managers [socket_pool ] = ConnectionManager (socket_pool )
369
+ return _global_connection_managers [socket_pool ]
0 commit comments