Skip to content

Commit f79486f

Browse files
author
Brandon Dahler
committed
Update TCPClient to detect and handle TCP socket closures.
1 parent 11369ac commit f79486f

File tree

2 files changed

+61
-38
lines changed

2 files changed

+61
-38
lines changed

src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,55 +17,53 @@
1717
package software.amazon.cloudwatchlogs.emf.sinks;
1818

1919
import java.io.IOException;
20-
import java.io.OutputStream;
2120
import java.net.InetSocketAddress;
22-
import java.net.Socket;
21+
import java.nio.ByteBuffer;
22+
import java.nio.channels.SocketChannel;
2323
import lombok.extern.slf4j.Slf4j;
2424

2525
/** A client that would connect to a TCP socket. */
2626
@Slf4j
2727
public class TCPClient implements SocketClient {
2828

2929
private final Endpoint endpoint;
30-
private Socket socket;
30+
private SocketChannel socketChannel;
3131
private boolean shouldConnect = true;
3232

33+
private final ByteBuffer readBuffer = ByteBuffer.allocate(1);
34+
3335
public TCPClient(Endpoint endpoint) {
3436
this.endpoint = endpoint;
3537
}
3638

3739
private void connect() {
3840
try {
39-
socket = createSocket();
40-
socket.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort()));
41+
socketChannel = SocketChannel.open();
42+
socketChannel.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort()));
4143
shouldConnect = false;
4244
} catch (Exception e) {
4345
shouldConnect = true;
4446
throw new RuntimeException("Failed to connect to the socket.", e);
4547
}
4648
}
4749

48-
protected Socket createSocket() {
49-
return new Socket();
50-
}
51-
5250
@Override
5351
public synchronized void sendMessage(String message) {
54-
if (socket == null || socket.isClosed() || shouldConnect) {
52+
if (socketChannel == null || !socketChannel.isConnected() || shouldConnect) {
5553
connect();
5654
}
5755

58-
OutputStream os;
5956
try {
60-
os = socket.getOutputStream();
61-
} catch (IOException e) {
62-
shouldConnect = true;
63-
throw new RuntimeException(
64-
"Failed to write message to the socket. Failed to open output stream.", e);
65-
}
57+
socketChannel.configureBlocking(true);
58+
socketChannel.write(ByteBuffer.wrap(message.getBytes()));
59+
60+
// Execute a non-blocking, single-byte read to detect if there was a connection closure.
61+
// No actual data is expected to be read.
62+
readBuffer.clear();
63+
64+
socketChannel.configureBlocking(false);
65+
socketChannel.read(readBuffer);
6666

67-
try {
68-
os.write(message.getBytes());
6967
} catch (Exception e) {
7068
shouldConnect = true;
7169
throw new RuntimeException("Failed to write message to the socket.", e);
@@ -74,8 +72,8 @@ public synchronized void sendMessage(String message) {
7472

7573
@Override
7674
public void close() throws IOException {
77-
if (socket != null) {
78-
socket.close();
75+
if (socketChannel != null) {
76+
socketChannel.close();
7977
}
8078
}
8179
}

src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,61 @@
1616

1717
package software.amazon.cloudwatchlogs.emf.sinks;
1818

19-
import static org.junit.Assert.assertEquals;
20-
import static org.mockito.ArgumentMatchers.any;
21-
import static org.mockito.Mockito.*;
19+
import static org.junit.Assert.assertArrayEquals;
20+
import static org.junit.Assert.assertThrows;
2221

23-
import java.io.ByteArrayOutputStream;
2422
import java.io.IOException;
25-
import java.net.Socket;
23+
import java.net.InetSocketAddress;
24+
import java.nio.ByteBuffer;
25+
import java.nio.channels.ServerSocketChannel;
26+
import java.nio.channels.SocketChannel;
27+
import java.nio.charset.StandardCharsets;
2628
import org.junit.Test;
2729

2830
public class TCPClientTest {
2931

3032
@Test
3133
public void testSendMessage() throws IOException {
32-
Socket socket = mock(Socket.class);
33-
ByteArrayOutputStream bos = new ByteArrayOutputStream();
34-
when(socket.getOutputStream()).thenReturn(bos);
35-
doNothing().when(socket).connect(any());
3634
Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT;
35+
InetSocketAddress socketAddress =
36+
new InetSocketAddress(endpoint.getHost(), endpoint.getPort());
3737

38-
TCPClient client =
39-
new TCPClient(endpoint) {
40-
@Override
41-
protected Socket createSocket() {
42-
return socket;
43-
}
44-
};
38+
ServerSocketChannel serverListener = ServerSocketChannel.open();
39+
serverListener.bind(socketAddress);
40+
41+
TCPClient client = new TCPClient(endpoint);
4542

4643
String message = "Test message";
4744
client.sendMessage(message);
4845

49-
assertEquals(bos.toString(), message);
46+
byte[] messageBytes = message.getBytes(StandardCharsets.UTF_8);
47+
ByteBuffer receiveBuffer = ByteBuffer.allocate(messageBytes.length);
48+
49+
SocketChannel serverChannel = serverListener.accept();
50+
serverChannel.read(receiveBuffer);
51+
52+
assertArrayEquals(receiveBuffer.array(), messageBytes);
5053
}
54+
55+
56+
@Test
57+
public void testDetectSocketClosure() throws IOException {
58+
Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT;
59+
InetSocketAddress socketAddress =
60+
new InetSocketAddress(endpoint.getHost(), endpoint.getPort());
61+
62+
ServerSocketChannel serverListener = ServerSocketChannel.open();
63+
serverListener.bind(socketAddress);
64+
65+
TCPClient client = new TCPClient(endpoint);
66+
67+
String message = "Test message";
68+
client.sendMessage(message);
69+
70+
SocketChannel serverChannel = serverListener.accept();
71+
serverChannel.close();
72+
73+
assertThrows(RuntimeException.class, () -> client.sendMessage(message));
74+
}
75+
5176
}

0 commit comments

Comments
 (0)