From 1c819a4753606a5140898fae43c27aba03ec4771 Mon Sep 17 00:00:00 2001 From: Brandon Dahler Date: Wed, 20 Apr 2022 13:51:20 -0400 Subject: [PATCH] Update TCPClient to detect and handle TCP socket closures. --- .../cloudwatchlogs/emf/sinks/TCPClient.java | 40 ++++++------ .../emf/sinks/TCPClientTest.java | 65 +++++++++++++------ 2 files changed, 65 insertions(+), 40 deletions(-) diff --git a/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java b/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java index 3ae7f43a..e523fe7e 100644 --- a/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java +++ b/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java @@ -17,9 +17,9 @@ package software.amazon.cloudwatchlogs.emf.sinks; import java.io.IOException; -import java.io.OutputStream; import java.net.InetSocketAddress; -import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; import lombok.extern.slf4j.Slf4j; /** A client that would connect to a TCP socket. */ @@ -27,17 +27,19 @@ public class TCPClient implements SocketClient { private final Endpoint endpoint; - private Socket socket; + private SocketChannel socketChannel; private boolean shouldConnect = true; + private final ByteBuffer readBuffer = ByteBuffer.allocate(1); + public TCPClient(Endpoint endpoint) { this.endpoint = endpoint; } private void connect() { try { - socket = createSocket(); - socket.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort())); + socketChannel = SocketChannel.open(); + socketChannel.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort())); shouldConnect = false; } catch (Exception e) { shouldConnect = true; @@ -45,27 +47,23 @@ private void connect() { } } - protected Socket createSocket() { - return new Socket(); - } - @Override public synchronized void sendMessage(String message) { - if (socket == null || socket.isClosed() || shouldConnect) { + if (socketChannel == null || !socketChannel.isConnected() || shouldConnect) { connect(); } - OutputStream os; try { - os = socket.getOutputStream(); - } catch (IOException e) { - shouldConnect = true; - throw new RuntimeException( - "Failed to write message to the socket. Failed to open output stream.", e); - } + socketChannel.configureBlocking(true); + socketChannel.write(ByteBuffer.wrap(message.getBytes())); + + // Execute a non-blocking, single-byte read to detect if there was a connection closure. + // No actual data is expected to be read. + readBuffer.clear(); + + socketChannel.configureBlocking(false); + socketChannel.read(readBuffer); - try { - os.write(message.getBytes()); } catch (Exception e) { shouldConnect = true; throw new RuntimeException("Failed to write message to the socket.", e); @@ -74,8 +72,8 @@ public synchronized void sendMessage(String message) { @Override public void close() throws IOException { - if (socket != null) { - socket.close(); + if (socketChannel != null) { + socketChannel.close(); } } } diff --git a/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java b/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java index ccbf4a44..28177811 100644 --- a/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java +++ b/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java @@ -16,36 +16,63 @@ package software.amazon.cloudwatchlogs.emf.sinks; -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThrows; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.net.Socket; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; import org.junit.Test; public class TCPClientTest { @Test public void testSendMessage() throws IOException { - Socket socket = mock(Socket.class); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - when(socket.getOutputStream()).thenReturn(bos); - doNothing().when(socket).connect(any()); Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT; + InetSocketAddress socketAddress = + new InetSocketAddress(endpoint.getHost(), endpoint.getPort()); - TCPClient client = - new TCPClient(endpoint) { - @Override - protected Socket createSocket() { - return socket; - } - }; + try (ServerSocketChannel serverListener = ServerSocketChannel.open()) { + serverListener.bind(socketAddress); - String message = "Test message"; - client.sendMessage(message); + try (TCPClient client = new TCPClient(endpoint)) { + String message = "Test message"; + client.sendMessage(message); - assertEquals(bos.toString(), message); + byte[] messageBytes = message.getBytes(StandardCharsets.UTF_8); + ByteBuffer receiveBuffer = ByteBuffer.allocate(messageBytes.length); + + try (SocketChannel serverChannel = serverListener.accept()) { + serverChannel.read(receiveBuffer); + } + + assertArrayEquals(receiveBuffer.array(), messageBytes); + } + } + } + + @Test + public void testDetectSocketClosure() throws IOException { + Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT; + InetSocketAddress socketAddress = + new InetSocketAddress(endpoint.getHost(), endpoint.getPort()); + + try (ServerSocketChannel serverListener = ServerSocketChannel.open()) { + serverListener.bind(socketAddress); + + try (TCPClient client = new TCPClient(endpoint)) { + + String message = "Test message"; + client.sendMessage(message); + + SocketChannel serverChannel = serverListener.accept(); + serverChannel.close(); + + assertThrows(RuntimeException.class, () -> client.sendMessage(message)); + } + } } }