Skip to content

Commit 1c819a4

Browse files
author
Brandon Dahler
committed
Update TCPClient to detect and handle TCP socket closures.
1 parent 11369ac commit 1c819a4

File tree

2 files changed

+65
-40
lines changed

2 files changed

+65
-40
lines changed

src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,55 +17,53 @@
1717
package software.amazon.cloudwatchlogs.emf.sinks;
1818

1919
import java.io.IOException;
20-
import java.io.OutputStream;
2120
import java.net.InetSocketAddress;
22-
import java.net.Socket;
21+
import java.nio.ByteBuffer;
22+
import java.nio.channels.SocketChannel;
2323
import lombok.extern.slf4j.Slf4j;
2424

2525
/** A client that would connect to a TCP socket. */
2626
@Slf4j
2727
public class TCPClient implements SocketClient {
2828

2929
private final Endpoint endpoint;
30-
private Socket socket;
30+
private SocketChannel socketChannel;
3131
private boolean shouldConnect = true;
3232

33+
private final ByteBuffer readBuffer = ByteBuffer.allocate(1);
34+
3335
public TCPClient(Endpoint endpoint) {
3436
this.endpoint = endpoint;
3537
}
3638

3739
private void connect() {
3840
try {
39-
socket = createSocket();
40-
socket.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort()));
41+
socketChannel = SocketChannel.open();
42+
socketChannel.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort()));
4143
shouldConnect = false;
4244
} catch (Exception e) {
4345
shouldConnect = true;
4446
throw new RuntimeException("Failed to connect to the socket.", e);
4547
}
4648
}
4749

48-
protected Socket createSocket() {
49-
return new Socket();
50-
}
51-
5250
@Override
5351
public synchronized void sendMessage(String message) {
54-
if (socket == null || socket.isClosed() || shouldConnect) {
52+
if (socketChannel == null || !socketChannel.isConnected() || shouldConnect) {
5553
connect();
5654
}
5755

58-
OutputStream os;
5956
try {
60-
os = socket.getOutputStream();
61-
} catch (IOException e) {
62-
shouldConnect = true;
63-
throw new RuntimeException(
64-
"Failed to write message to the socket. Failed to open output stream.", e);
65-
}
57+
socketChannel.configureBlocking(true);
58+
socketChannel.write(ByteBuffer.wrap(message.getBytes()));
59+
60+
// Execute a non-blocking, single-byte read to detect if there was a connection closure.
61+
// No actual data is expected to be read.
62+
readBuffer.clear();
63+
64+
socketChannel.configureBlocking(false);
65+
socketChannel.read(readBuffer);
6666

67-
try {
68-
os.write(message.getBytes());
6967
} catch (Exception e) {
7068
shouldConnect = true;
7169
throw new RuntimeException("Failed to write message to the socket.", e);
@@ -74,8 +72,8 @@ public synchronized void sendMessage(String message) {
7472

7573
@Override
7674
public void close() throws IOException {
77-
if (socket != null) {
78-
socket.close();
75+
if (socketChannel != null) {
76+
socketChannel.close();
7977
}
8078
}
8179
}

src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,63 @@
1616

1717
package software.amazon.cloudwatchlogs.emf.sinks;
1818

19-
import static org.junit.Assert.assertEquals;
20-
import static org.mockito.ArgumentMatchers.any;
21-
import static org.mockito.Mockito.*;
19+
import static org.junit.Assert.assertArrayEquals;
20+
import static org.junit.Assert.assertThrows;
2221

23-
import java.io.ByteArrayOutputStream;
2422
import java.io.IOException;
25-
import java.net.Socket;
23+
import java.net.InetSocketAddress;
24+
import java.nio.ByteBuffer;
25+
import java.nio.channels.ServerSocketChannel;
26+
import java.nio.channels.SocketChannel;
27+
import java.nio.charset.StandardCharsets;
2628
import org.junit.Test;
2729

2830
public class TCPClientTest {
2931

3032
@Test
3133
public void testSendMessage() throws IOException {
32-
Socket socket = mock(Socket.class);
33-
ByteArrayOutputStream bos = new ByteArrayOutputStream();
34-
when(socket.getOutputStream()).thenReturn(bos);
35-
doNothing().when(socket).connect(any());
3634
Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT;
35+
InetSocketAddress socketAddress =
36+
new InetSocketAddress(endpoint.getHost(), endpoint.getPort());
3737

38-
TCPClient client =
39-
new TCPClient(endpoint) {
40-
@Override
41-
protected Socket createSocket() {
42-
return socket;
43-
}
44-
};
38+
try (ServerSocketChannel serverListener = ServerSocketChannel.open()) {
39+
serverListener.bind(socketAddress);
4540

46-
String message = "Test message";
47-
client.sendMessage(message);
41+
try (TCPClient client = new TCPClient(endpoint)) {
42+
String message = "Test message";
43+
client.sendMessage(message);
4844

49-
assertEquals(bos.toString(), message);
45+
byte[] messageBytes = message.getBytes(StandardCharsets.UTF_8);
46+
ByteBuffer receiveBuffer = ByteBuffer.allocate(messageBytes.length);
47+
48+
try (SocketChannel serverChannel = serverListener.accept()) {
49+
serverChannel.read(receiveBuffer);
50+
}
51+
52+
assertArrayEquals(receiveBuffer.array(), messageBytes);
53+
}
54+
}
55+
}
56+
57+
@Test
58+
public void testDetectSocketClosure() throws IOException {
59+
Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT;
60+
InetSocketAddress socketAddress =
61+
new InetSocketAddress(endpoint.getHost(), endpoint.getPort());
62+
63+
try (ServerSocketChannel serverListener = ServerSocketChannel.open()) {
64+
serverListener.bind(socketAddress);
65+
66+
try (TCPClient client = new TCPClient(endpoint)) {
67+
68+
String message = "Test message";
69+
client.sendMessage(message);
70+
71+
SocketChannel serverChannel = serverListener.accept();
72+
serverChannel.close();
73+
74+
assertThrows(RuntimeException.class, () -> client.sendMessage(message));
75+
}
76+
}
5077
}
5178
}

0 commit comments

Comments
 (0)