Skip to content

Commit 6ca130f

Browse files
committed
Fix socket reconnect issue
1 parent ac27573 commit 6ca130f

File tree

2 files changed

+165
-15
lines changed

2 files changed

+165
-15
lines changed

aws_embedded_metrics/sinks/tcp_client.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import socket
1717
import threading
18+
import errno
1819
from urllib.parse import ParseResult
1920

2021
log = logging.getLogger(__name__)
@@ -25,24 +26,44 @@
2526
class TcpClient(SocketClient):
2627
def __init__(self, endpoint: ParseResult):
2728
self._endpoint = endpoint
28-
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
29-
self._write_lock = threading.Lock()
29+
# using reentrant lock so that we can retry through recursion
30+
self._write_lock = threading.RLock()
31+
self._connect_lock = threading.RLock()
3032
self._should_connect = True
3133

3234
def connect(self) -> "TcpClient":
33-
try:
34-
self._sock.connect((self._endpoint.hostname, self._endpoint.port))
35-
self._should_connect = False
36-
except socket.timeout as e:
37-
log.error("Socket timeout durring connect %s" % (e,))
38-
self._should_connect = True
39-
except Exception as e:
40-
log.error("Failed to connect to the socket. %s" % (e,))
41-
self._should_connect = True
42-
return self
43-
44-
def send_message(self, message: bytes) -> None:
45-
if self._sock._closed or self._should_connect: # type: ignore
35+
with self._connect_lock:
36+
try:
37+
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
38+
self._sock.connect((self._endpoint.hostname, self._endpoint.port))
39+
self._should_connect = False
40+
except socket.timeout as e:
41+
log.error("Socket timeout durring connect %s" % (e,))
42+
except OSError as e:
43+
if e.errno == errno.EISCONN:
44+
log.debug("Socket is already connected.")
45+
self._should_connect = False
46+
else:
47+
log.error("Failed to connect to the socket. %s" % (e,))
48+
self._should_connect = True
49+
except Exception as e:
50+
log.error("Failed to connect to the socket. %s" % (e,))
51+
self._should_connect = True
52+
return self
53+
54+
# TODO: once #21 lands, we should increase the max retries
55+
# the reason this is only 1 is to allow for a single
56+
# reconnect attempt in case the agent disconnects
57+
# additional retries and backoff would impose back
58+
# pressure on the caller that may not be accounted
59+
# for. Before we do that, we need to run the I/O
60+
# operations on a background thread.s
61+
def send_message(self, message: bytes, retry: int = 1) -> None:
62+
if retry < 0:
63+
log.error("Max retries exhausted, dropping message")
64+
return
65+
66+
if self._sock is None or self._sock._closed or self._should_connect: # type: ignore
4667
self.connect()
4768

4869
with self._write_lock:
@@ -52,9 +73,12 @@ def send_message(self, message: bytes) -> None:
5273
except socket.timeout as e:
5374
log.error("Socket timeout durring send %s" % (e,))
5475
self.connect()
76+
self.send_message(message, retry - 1)
5577
except socket.error as e:
5678
log.error("Failed to write metrics to the socket due to socket.error. %s" % (e,))
5779
self.connect()
80+
self.send_message(message, retry - 1)
5881
except Exception as e:
5982
log.error("Failed to write metrics to the socket due to exception. %s" % (e,))
6083
self.connect()
84+
self.send_message(message, retry - 1)

tests/sinks/test_tcp_client.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from aws_embedded_metrics.sinks.tcp_client import TcpClient
2+
from urllib.parse import urlparse
3+
import socket
4+
import threading
5+
import time
6+
import logging
7+
8+
log = logging.getLogger(__name__)
9+
10+
test_host = '0.0.0.0'
11+
test_port = 9999
12+
endpoint = urlparse("tcp://0.0.0.0:9999")
13+
message = "_16-Byte-String_".encode('utf-8')
14+
15+
16+
def test_can_send_message():
17+
# arrange
18+
agent = InProcessAgent().start()
19+
client = TcpClient(endpoint)
20+
21+
# act
22+
client.connect()
23+
client.send_message(message)
24+
25+
# assert
26+
time.sleep(1)
27+
messages = agent.messages
28+
assert 1 == len(messages)
29+
assert message == messages[0]
30+
agent.shutdown()
31+
32+
33+
def test_can_connect_concurrently_from_threads():
34+
# arrange
35+
concurrency = 10
36+
agent = InProcessAgent().start()
37+
client = TcpClient(endpoint)
38+
barrier = threading.Barrier(concurrency, timeout=5)
39+
40+
def run():
41+
barrier.wait()
42+
client.connect()
43+
client.send_message(message)
44+
45+
def start_thread():
46+
thread = threading.Thread(target=run, args=())
47+
thread.daemon = True
48+
thread.start()
49+
50+
# act
51+
for _ in range(concurrency):
52+
start_thread()
53+
54+
# assert
55+
time.sleep(1)
56+
messages = agent.messages
57+
assert concurrency == len(messages)
58+
for i in range(concurrency):
59+
assert message == messages[i]
60+
agent.shutdown()
61+
62+
63+
def test_can_recover_from_agent_shutdown():
64+
# arrange
65+
agent = InProcessAgent().start()
66+
client = TcpClient(endpoint)
67+
68+
# act
69+
client.connect()
70+
client.send_message(message)
71+
agent.shutdown()
72+
time.sleep(5)
73+
client.send_message(message)
74+
agent = InProcessAgent().start()
75+
client.send_message(message)
76+
77+
# assert
78+
time.sleep(1)
79+
messages = agent.messages
80+
assert 1 == len(messages)
81+
assert message == messages[0]
82+
agent.shutdown()
83+
84+
85+
class InProcessAgent(object):
86+
""" Agent that runs on a background thread and collects
87+
messages in memory.
88+
"""
89+
90+
def __init__(self):
91+
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
92+
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
93+
self.sock.bind((test_host, test_port))
94+
self.sock.listen()
95+
self.is_shutdown = False
96+
self.messages = []
97+
98+
def start(self) -> "InProcessAgent":
99+
thread = threading.Thread(target=self.run, args=())
100+
thread.daemon = True
101+
thread.start()
102+
return self
103+
104+
def run(self):
105+
while not self.is_shutdown:
106+
connection, client_address = self.sock.accept()
107+
self.connection = connection
108+
109+
try:
110+
while not self.is_shutdown:
111+
data = self.connection.recv(16)
112+
if data:
113+
self.messages.append(data)
114+
else:
115+
break
116+
finally:
117+
log.error("Exited the recv loop")
118+
119+
def shutdown(self):
120+
try:
121+
self.is_shutdown = True
122+
self.connection.shutdown(socket.SHUT_RDWR)
123+
self.connection.close()
124+
self.sock.close()
125+
except Exception as e:
126+
log.error("Failed to shutdown %s" % (e,))

0 commit comments

Comments
 (0)