diff --git a/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java b/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java index 3ae7f43a..1805f551 100644 --- a/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java +++ b/src/main/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClient.java @@ -51,9 +51,7 @@ protected Socket createSocket() { @Override public synchronized void sendMessage(String message) { - if (socket == null || socket.isClosed() || shouldConnect) { - connect(); - } + checkConnection(); OutputStream os; try { @@ -65,13 +63,22 @@ public synchronized void sendMessage(String message) { } try { + // Write a space to the socket to verify connection before sending event + os.write(32); + os.write(message.getBytes()); - } catch (Exception e) { + } catch (IOException e) { shouldConnect = true; throw new RuntimeException("Failed to write message to the socket.", e); } } + private void checkConnection() { + if (socket == null || socket.isClosed() || shouldConnect) { + connect(); + } + } + @Override public void close() throws IOException { if (socket != null) { diff --git a/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java b/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java index ccbf4a44..e22048cc 100644 --- a/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java +++ b/src/test/java/software/amazon/cloudwatchlogs/emf/sinks/TCPClientTest.java @@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.net.ServerSocket; import java.net.Socket; import org.junit.Test; @@ -45,7 +46,24 @@ protected Socket createSocket() { String message = "Test message"; client.sendMessage(message); + client.close(); - assertEquals(bos.toString(), message); + assertEquals(message, bos.toString().trim()); + } + + @Test(timeout = 5000) + public void testSendMessageWithSocketServer() throws IOException { + TCPClient client = new TCPClient(new Endpoint("0.0.0.0", 9999, Protocol.TCP)); + ServerSocket server = new ServerSocket(9999); + client.sendMessage("Test message"); + Socket socket = server.accept(); + + byte[] bytes = new byte[1024]; + int read = socket.getInputStream().read(bytes); + String message = new String(bytes, 0, read); + socket.close(); + server.close(); + + assertEquals("Test message", message.trim()); } }