Skip to content

Commit 001967a

Browse files
committed
Merge branch 'main' into esp32spi-and-wiznet5k-socketpool
2 parents 778b78f + 1531496 commit 001967a

8 files changed

+381
-141
lines changed

adafruit_connection_manager.py

Lines changed: 122 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
if not sys.implementation.name == "circuitpython":
38-
from typing import Optional, Tuple
38+
from typing import List, Optional, Tuple
3939

4040
from circuitpython_typing.socket import (
4141
CircuitPythonSocketType,
@@ -68,15 +68,14 @@ def connect(self, address: Tuple[str, int]) -> None:
6868
try:
6969
return self._socket.connect(address, self._mode)
7070
except RuntimeError as error:
71-
raise OSError(errno.ENOMEM) from error
71+
raise OSError(errno.ENOMEM, str(error)) from error
7272

7373

7474
class _FakeSSLContext:
7575
def __init__(self, iface: InterfaceType) -> None:
7676
self._iface = iface
7777

78-
# pylint: disable=unused-argument
79-
def wrap_socket(
78+
def wrap_socket( # pylint: disable=unused-argument
8079
self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None
8180
) -> _FakeSSLSocket:
8281
"""Return the same socket"""
@@ -106,7 +105,8 @@ def create_fake_ssl_context(
106105
return _FakeSSLContext(iface)
107106

108107

109-
_global_socketpool = {}
108+
_global_connection_managers = {}
109+
_global_socketpools = {}
110110
_global_ssl_contexts = {}
111111

112112

@@ -127,7 +127,7 @@ def get_radio_socketpool(radio):
127127
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
128128
"""
129129
key = _get_radio_hash_key(radio)
130-
if key not in _global_socketpool:
130+
if key not in _global_socketpools:
131131
class_name = radio.__class__.__name__
132132
if class_name == "Radio":
133133
import ssl # pylint: disable=import-outside-toplevel
@@ -168,10 +168,10 @@ def get_radio_socketpool(radio):
168168
else:
169169
raise AttributeError(f"Unsupported radio class: {class_name}")
170170

171-
_global_socketpool[key] = pool
171+
_global_socketpools[key] = pool
172172
_global_ssl_contexts[key] = ssl_context
173173

174-
return _global_socketpool[key]
174+
return _global_socketpools[key]
175175

176176

177177
def get_radio_ssl_context(radio):
@@ -199,42 +199,75 @@ def __init__(
199199
) -> None:
200200
self._socket_pool = socket_pool
201201
# Hang onto open sockets so that we can reuse them.
202-
self._available_socket = {}
203-
self._open_sockets = {}
204-
205-
def _free_sockets(self) -> None:
206-
available_sockets = []
207-
for socket, free in self._available_socket.items():
208-
if free:
209-
available_sockets.append(socket)
202+
self._available_sockets = set()
203+
self._key_by_managed_socket = {}
204+
self._managed_socket_by_key = {}
210205

206+
def _free_sockets(self, force: bool = False) -> None:
207+
# cloning lists since items are being removed
208+
available_sockets = list(self._available_sockets)
211209
for socket in available_sockets:
212210
self.close_socket(socket)
211+
if force:
212+
open_sockets = list(self._managed_socket_by_key.values())
213+
for socket in open_sockets:
214+
self.close_socket(socket)
213215

214-
def _get_key_for_socket(self, socket):
216+
def _get_connected_socket( # pylint: disable=too-many-arguments
217+
self,
218+
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
219+
host: str,
220+
port: int,
221+
timeout: float,
222+
is_ssl: bool,
223+
ssl_context: Optional[SSLContextType] = None,
224+
):
215225
try:
216-
return next(
217-
key for key, value in self._open_sockets.items() if value == socket
218-
)
219-
except StopIteration:
220-
return None
226+
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
227+
except (OSError, RuntimeError) as exc:
228+
return exc
229+
230+
if is_ssl:
231+
socket = ssl_context.wrap_socket(socket, server_hostname=host)
232+
connect_host = host
233+
else:
234+
connect_host = addr_info[-1][0]
235+
socket.settimeout(timeout) # socket read timeout
236+
237+
try:
238+
socket.connect((connect_host, port))
239+
except (MemoryError, OSError) as exc:
240+
socket.close()
241+
return exc
242+
243+
return socket
244+
245+
@property
246+
def available_socket_count(self) -> int:
247+
"""Get the count of freeable open sockets"""
248+
return len(self._available_sockets)
249+
250+
@property
251+
def managed_socket_count(self) -> int:
252+
"""Get the count of open sockets"""
253+
return len(self._managed_socket_by_key)
221254

222255
def close_socket(self, socket: SocketType) -> None:
223256
"""Close a previously opened socket."""
224-
if socket not in self._open_sockets.values():
257+
if socket not in self._managed_socket_by_key.values():
225258
raise RuntimeError("Socket not managed")
226-
key = self._get_key_for_socket(socket)
227259
socket.close()
228-
del self._available_socket[socket]
229-
del self._open_sockets[key]
260+
key = self._key_by_managed_socket.pop(socket)
261+
del self._managed_socket_by_key[key]
262+
if socket in self._available_sockets:
263+
self._available_sockets.remove(socket)
230264

231265
def free_socket(self, socket: SocketType) -> None:
232266
"""Mark a previously opened socket as available so it can be reused if needed."""
233-
if socket not in self._open_sockets.values():
267+
if socket not in self._managed_socket_by_key.values():
234268
raise RuntimeError("Socket not managed")
235-
self._available_socket[socket] = True
269+
self._available_sockets.add(socket)
236270

237-
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
238271
def get_socket(
239272
self,
240273
host: str,
@@ -250,10 +283,10 @@ def get_socket(
250283
if session_id:
251284
session_id = str(session_id)
252285
key = (host, port, proto, session_id)
253-
if key in self._open_sockets:
254-
socket = self._open_sockets[key]
255-
if self._available_socket[socket]:
256-
self._available_socket[socket] = False
286+
if key in self._managed_socket_by_key:
287+
socket = self._managed_socket_by_key[key]
288+
if socket in self._available_sockets:
289+
self._available_sockets.remove(socket)
257290
return socket
258291

259292
raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
@@ -269,64 +302,68 @@ def get_socket(
269302
host, port, 0, self._socket_pool.SOCK_STREAM
270303
)[0]
271304

272-
try_count = 0
273-
socket = None
274-
last_exc = None
275-
while try_count < 2 and socket is None:
276-
try_count += 1
277-
if try_count > 1:
278-
if any(
279-
socket
280-
for socket, free in self._available_socket.items()
281-
if free is True
282-
):
283-
self._free_sockets()
284-
else:
285-
break
286-
287-
try:
288-
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
289-
except OSError as exc:
290-
last_exc = exc
291-
continue
292-
except RuntimeError as exc:
293-
last_exc = exc
294-
continue
295-
296-
if is_ssl:
297-
socket = ssl_context.wrap_socket(socket, server_hostname=host)
298-
connect_host = host
299-
else:
300-
connect_host = addr_info[-1][0]
301-
socket.settimeout(timeout) # socket read timeout
302-
303-
try:
304-
socket.connect((connect_host, port))
305-
except MemoryError as exc:
306-
last_exc = exc
307-
socket.close()
308-
socket = None
309-
except OSError as exc:
310-
last_exc = exc
311-
socket.close()
312-
socket = None
313-
314-
if socket is None:
315-
raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc
316-
317-
self._available_socket[socket] = False
318-
self._open_sockets[key] = socket
319-
return socket
305+
first_exception = None
306+
result = self._get_connected_socket(
307+
addr_info, host, port, timeout, is_ssl, ssl_context
308+
)
309+
if isinstance(result, Exception):
310+
# Got an error, if there are any available sockets, free them and try again
311+
if self.available_socket_count:
312+
first_exception = result
313+
self._free_sockets()
314+
result = self._get_connected_socket(
315+
addr_info, host, port, timeout, is_ssl, ssl_context
316+
)
317+
if isinstance(result, Exception):
318+
last_result = f", first error: {first_exception}" if first_exception else ""
319+
raise RuntimeError(
320+
f"Error connecting socket: {result}{last_result}"
321+
) from result
322+
323+
self._key_by_managed_socket[result] = key
324+
self._managed_socket_by_key[key] = result
325+
return result
320326

321327

322328
# global helpers
323329

324330

325-
_global_connection_manager = {}
331+
def connection_manager_close_all(
332+
socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False
333+
) -> None:
334+
"""Close all open sockets for pool"""
335+
if socket_pool:
336+
socket_pools = [socket_pool]
337+
else:
338+
socket_pools = _global_connection_managers.keys()
339+
340+
for pool in socket_pools:
341+
connection_manager = _global_connection_managers.get(pool, None)
342+
if connection_manager is None:
343+
raise RuntimeError("SocketPool not managed")
344+
345+
connection_manager._free_sockets(force=True) # pylint: disable=protected-access
346+
347+
if release_references:
348+
radio_key = None
349+
for radio_check, pool_check in _global_socketpools.items():
350+
if pool == pool_check:
351+
radio_key = radio_check
352+
break
353+
354+
if radio_key:
355+
if radio_key in _global_socketpools:
356+
del _global_socketpools[radio_key]
357+
358+
if radio_key in _global_ssl_contexts:
359+
del _global_ssl_contexts[radio_key]
360+
361+
if pool in _global_connection_managers:
362+
del _global_connection_managers[pool]
326363

327364

328365
def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
329366
"""Get the ConnectionManager singleton for the given pool"""
330-
if socket_pool not in _global_connection_manager:
331-
_global_connection_manager[socket_pool] = ConnectionManager(socket_pool)
332-
return _global_connection_manager[socket_pool]
367+
if socket_pool not in _global_connection_managers:
368+
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
369+
return _global_connection_managers[socket_pool]

examples/connectionmanager_helpers.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,38 @@
2424

2525
# get request session
2626
requests = adafruit_requests.Session(pool, ssl_context)
27+
connection_manager = adafruit_connection_manager.get_connection_manager(pool)
28+
print("-" * 40)
29+
print("Nothing yet opened")
30+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
31+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
2732

2833
# make request
2934
print("-" * 40)
30-
print(f"Fetching from {TEXT_URL}")
35+
print(f"Fetching from {TEXT_URL} in a context handler")
36+
with requests.get(TEXT_URL) as response:
37+
response_text = response.text
38+
print(f"Text Response {response_text}")
39+
40+
print("-" * 40)
41+
print("1 request, opened and freed")
42+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
43+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
3144

45+
print("-" * 40)
46+
print(f"Fetching from {TEXT_URL} not in a context handler")
3247
response = requests.get(TEXT_URL)
33-
response_text = response.text
34-
response.close()
3548

36-
print(f"Text Response {response_text}")
3749
print("-" * 40)
50+
print("1 request, opened but not freed")
51+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
52+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
53+
54+
print("-" * 40)
55+
print("Closing everything in the pool")
56+
adafruit_connection_manager.connection_manager_close_all(pool)
57+
58+
print("-" * 40)
59+
print("Everything closed")
60+
print(f"Open Sockets: {connection_manager.managed_socket_count}")
61+
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

tests/close_socket_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ def test_close_socket():
2121
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
2222
key = (mocket.MOCK_HOST_1, 80, "http:", None)
2323
assert socket == mock_socket_1
24-
assert socket in connection_manager._available_socket
25-
assert key in connection_manager._open_sockets
24+
assert socket not in connection_manager._available_sockets
25+
assert key in connection_manager._managed_socket_by_key
2626

2727
# validate socket is no longer tracked
2828
connection_manager.close_socket(socket)
29-
assert socket not in connection_manager._available_socket
30-
assert key not in connection_manager._open_sockets
29+
assert socket not in connection_manager._available_sockets
30+
assert key not in connection_manager._managed_socket_by_key
3131

3232

3333
def test_close_socket_not_managed():

tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ def adafruit_wiznet5k_with_ssl_socketpool_module():
9191
@pytest.fixture(autouse=True)
9292
def reset_connection_manager(monkeypatch):
9393
monkeypatch.setattr(
94-
"adafruit_connection_manager._global_socketpool",
94+
"adafruit_connection_manager._global_connection_managers",
95+
{},
96+
)
97+
monkeypatch.setattr(
98+
"adafruit_connection_manager._global_socketpools",
9599
{},
96100
)
97101
monkeypatch.setattr(

0 commit comments

Comments
 (0)