Skip to content

Commit 1c2d8f4

Browse files
author
Mark Kuhn
committed
add socket eos detection
1 parent dd7bffa commit 1c2d8f4

File tree

2 files changed

+98
-3
lines changed

2 files changed

+98
-3
lines changed

src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.io.OutputStream;
2121
import java.net.InetSocketAddress;
2222
import java.net.Socket;
23+
import java.util.concurrent.*;
24+
2325
import lombok.extern.slf4j.Slf4j;
2426

2527
/** A client that would connect to a TCP socket. */
@@ -51,9 +53,7 @@ protected Socket createSocket() {
5153

5254
@Override
5355
public synchronized void sendMessage(String message) {
54-
if (socket == null || socket.isClosed() || shouldConnect) {
55-
connect();
56-
}
56+
checkConnection();
5757

5858
OutputStream os;
5959
try {
@@ -72,6 +72,48 @@ public synchronized void sendMessage(String message) {
7272
}
7373
}
7474

75+
/**
76+
* Performs checks to see if the socket is connected.
77+
* If not, it will attempt to connect.
78+
*/
79+
private void checkConnection() {
80+
if (socket == null || socket.isClosed() || shouldConnect) {
81+
connect();
82+
return;
83+
}
84+
85+
try {
86+
if (streamIsClosed()) {
87+
connect();
88+
}
89+
} catch (RuntimeException | IOException e) {
90+
connect();
91+
}
92+
}
93+
94+
/**
95+
* Checks if the socket's input stream has reached EOS.
96+
*
97+
* @return true if the input stream is closed, false otherwise.
98+
*/
99+
private boolean streamIsClosed() throws IOException {
100+
Integer result;
101+
Callable<Integer> readTask = socket.getInputStream()::read;
102+
103+
ExecutorService executor = Executors.newCachedThreadPool();
104+
Future<Integer> readFuture = executor.submit(readTask);
105+
106+
try {
107+
result = readFuture.get(1, TimeUnit.MILLISECONDS);
108+
} catch (ExecutionException | InterruptedException e) {
109+
throw new RuntimeException(e);
110+
} catch (TimeoutException ignored) {
111+
return false;
112+
}
113+
114+
return result == -1;
115+
}
116+
75117
@Override
76118
public void close() throws IOException {
77119
if (socket != null) {

src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
import java.io.ByteArrayOutputStream;
2424
import java.io.IOException;
25+
import java.io.InputStream;
26+
import java.net.ServerSocket;
2527
import java.net.Socket;
2628
import org.junit.Test;
2729

@@ -48,4 +50,55 @@ protected Socket createSocket() {
4850

4951
assertEquals(bos.toString(), message);
5052
}
53+
54+
@Test(timeout = 5000)
55+
public void testSendMessageAfterAgentReconnect() throws InterruptedException {
56+
String[] messages = new String[2];
57+
58+
Thread server = new Thread(() -> {
59+
try {
60+
ServerSocket serverSocket = new ServerSocket(Endpoint.DEFAULT_TCP_ENDPOINT.getPort());
61+
Socket socket = serverSocket.accept();
62+
InputStream is = socket.getInputStream();
63+
byte[] bytes = new byte[1024];
64+
int read = is.read(bytes);
65+
messages[0] = new String(bytes, 0, read);
66+
67+
// Disconnect
68+
socket.close();
69+
serverSocket.close();
70+
71+
// Reconnect
72+
serverSocket = new ServerSocket(Endpoint.DEFAULT_TCP_ENDPOINT.getPort());
73+
socket = serverSocket.accept();
74+
is = socket.getInputStream();
75+
bytes = new byte[1024];
76+
read = is.read(bytes);
77+
messages[1] = new String(bytes, 0, read);
78+
79+
socket.close();
80+
serverSocket.close();
81+
} catch (IOException e) {
82+
throw new RuntimeException(e);
83+
}
84+
});
85+
86+
Thread client = new Thread(() -> {
87+
try (TCPClient tcpClient = new TCPClient(Endpoint.DEFAULT_TCP_ENDPOINT)) {
88+
tcpClient.sendMessage("Test message 1");
89+
Thread.sleep(1000);
90+
tcpClient.sendMessage("Test message 2");
91+
} catch (InterruptedException | IOException e) {
92+
throw new RuntimeException(e);
93+
}
94+
});
95+
96+
server.start();
97+
client.start();
98+
server.join();
99+
client.join();
100+
101+
assertEquals("Test message 1", messages[0]);
102+
assertEquals("Test message 2", messages[1]);
103+
}
51104
}

0 commit comments

Comments
 (0)