diff --git a/log4j-core-test/src/test/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistryTest.java b/log4j-core-test/src/test/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistryTest.java new file mode 100644 index 00000000000..81df39b24b9 --- /dev/null +++ b/log4j-core-test/src/test/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistryTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.logging.log4j.core.util.internal; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.ref.WeakReference; +import java.lang.reflect.Field; +import java.util.Map; +import org.apache.logging.log4j.core.Logger; +import org.apache.logging.log4j.core.LoggerContext; +import org.apache.logging.log4j.message.MessageFactory; +import org.apache.logging.log4j.message.SimpleMessageFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +class InternalLoggerRegistryTest { + private LoggerContext loggerContext; + private InternalLoggerRegistry registry; + private MessageFactory messageFactory; + + @BeforeEach + void setUp(TestInfo testInfo) throws NoSuchFieldException, IllegalAccessException { + loggerContext = new LoggerContext(testInfo.getDisplayName()); + final Field registryField = loggerContext.getClass().getDeclaredField("loggerRegistry"); + registryField.setAccessible(true); + registry = (InternalLoggerRegistry) registryField.get(loggerContext); + messageFactory = SimpleMessageFactory.INSTANCE; + } + + @AfterEach + void tearDown() { + if (loggerContext != null) { + loggerContext.stop(); + } + } + + @Test + void testGetLoggerReturnsNullForNonExistentLogger() { + assertNull(registry.getLogger("nonExistent", messageFactory)); + } + + @Test + void testComputeIfAbsentCreatesLogger() { + final Logger logger = registry.computeIfAbsent( + "testLogger", messageFactory, (name, factory) -> new Logger(loggerContext, name, factory) {}); + assertNotNull(logger); + assertEquals("testLogger", logger.getName()); + } + + @Test + void testGetLoggerRetrievesExistingLogger() { + final Logger logger = registry.computeIfAbsent( + "testLogger", messageFactory, (name, factory) -> new Logger(loggerContext, name, factory) {}); + assertSame(logger, registry.getLogger("testLogger", messageFactory)); + } + + @Test + void testHasLoggerReturnsCorrectStatus() { + assertFalse(registry.hasLogger("testLogger", messageFactory)); + registry.computeIfAbsent( + "testLogger", messageFactory, (name, factory) -> new Logger(loggerContext, name, factory) {}); + assertTrue(registry.hasLogger("testLogger", messageFactory)); + } + + @Test + void testExpungeStaleWeakReferenceEntries() { + final String loggerNamePrefix = "testLogger_"; + final int numberOfLoggers = 1000; + + for (int i = 0; i < numberOfLoggers; i++) { + final Logger logger = registry.computeIfAbsent( + loggerNamePrefix + i, + messageFactory, + (name, factory) -> new Logger(loggerContext, name, factory) {}); + logger.info("Using logger {}", logger.getName()); + } + + await().atMost(10, SECONDS).pollInterval(100, MILLISECONDS).untilAsserted(() -> { + System.gc(); + registry.computeIfAbsent( + "triggerExpunge", messageFactory, (name, factory) -> new Logger(loggerContext, name, factory) {}); + + final Map>> loggerRefByNameByMessageFactory = + reflectAndGetLoggerMapFromRegistry(); + final Map> loggerRefByName = + loggerRefByNameByMessageFactory.get(messageFactory); + + int unexpectedCount = 0; + for (int i = 0; i < numberOfLoggers; i++) { + if (loggerRefByName.containsKey(loggerNamePrefix + i)) { + unexpectedCount++; + } + } + assertEquals( + 0, unexpectedCount, "Found " + unexpectedCount + " unexpected stale entries for MessageFactory"); + }); + } + + @Test + void testExpungeStaleMessageFactoryEntry() { + final SimpleMessageFactory mockMessageFactory = new SimpleMessageFactory(); + Logger logger = registry.computeIfAbsent( + "testLogger", mockMessageFactory, (name, factory) -> new Logger(loggerContext, name, factory) {}); + logger.info("Using logger {}", logger.getName()); + logger = null; + + await().atMost(10, SECONDS).pollInterval(100, MILLISECONDS).untilAsserted(() -> { + System.gc(); + registry.getLogger("triggerExpunge", mockMessageFactory); + + final Map>> loggerRefByNameByMessageFactory = + reflectAndGetLoggerMapFromRegistry(); + assertNull( + loggerRefByNameByMessageFactory.get(mockMessageFactory), + "Stale MessageFactory entry was not removed from the outer map"); + }); + } + + private Map>> reflectAndGetLoggerMapFromRegistry() + throws NoSuchFieldException, IllegalAccessException { + final Field loggerMapField = registry.getClass().getDeclaredField("loggerRefByNameByMessageFactory"); + loggerMapField.setAccessible(true); + @SuppressWarnings("unchecked") + final Map>> loggerMap = + (Map>>) loggerMapField.get(registry); + return loggerMap; + } +} diff --git a/log4j-core/src/main/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistry.java b/log4j-core/src/main/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistry.java index eff1d46b77a..d75c47e978e 100644 --- a/log4j-core/src/main/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistry.java +++ b/log4j-core/src/main/java/org/apache/logging/log4j/core/util/internal/InternalLoggerRegistry.java @@ -18,9 +18,12 @@ import static java.util.Objects.requireNonNull; +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; import java.lang.ref.WeakReference; import java.util.Collection; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import java.util.WeakHashMap; import java.util.concurrent.locks.Lock; @@ -40,7 +43,9 @@ * A registry of {@link Logger}s namespaced by name and message factory. * This class is internally used by {@link LoggerContext}. *

- * We don't use {@linkplain org.apache.logging.log4j.spi.LoggerRegistry the registry from Log4j API} to keep Log4j Core independent from the version of Log4j API at runtime. + * We don't use {@linkplain org.apache.logging.log4j.spi.LoggerRegistry the + * registry from Log4j API} to keep Log4j Core independent from the version of + * Log4j API at runtime. * This also allows Log4j Core to evolve independently from Log4j API. *

* @@ -53,13 +58,46 @@ public final class InternalLoggerRegistry { new WeakHashMap<>(); private final ReadWriteLock lock = new ReentrantReadWriteLock(); - private final Lock readLock = lock.readLock(); - private final Lock writeLock = lock.writeLock(); + // ReferenceQueue to track stale WeakReferences + private final ReferenceQueue staleLoggerRefs = new ReferenceQueue<>(); + public InternalLoggerRegistry() {} + /** + * Expunges stale entries for logger references and message factories. + */ + private void expungeStaleEntries() { + final Reference loggerRef = staleLoggerRefs.poll(); + + if (loggerRef != null) { + writeLock.lock(); + try { + while (staleLoggerRefs.poll() != null) { + // Clear refQueue + } + + final Iterator>>> + loggerRefByNameByMessageFactoryEntryIt = + loggerRefByNameByMessageFactory.entrySet().iterator(); + while (loggerRefByNameByMessageFactoryEntryIt.hasNext()) { + final Map.Entry>> + loggerRefByNameByMessageFactoryEntry = loggerRefByNameByMessageFactoryEntryIt.next(); + final Map> loggerRefByName = + loggerRefByNameByMessageFactoryEntry.getValue(); + loggerRefByName.values().removeIf(weakRef -> weakRef.get() == null); + if (loggerRefByName.isEmpty()) { + loggerRefByNameByMessageFactoryEntryIt.remove(); + } + } + } finally { + writeLock.unlock(); + } + } + } + /** * Returns the logger associated with the given name and message factory. * @@ -70,6 +108,8 @@ public InternalLoggerRegistry() {} public @Nullable Logger getLogger(final String name, final MessageFactory messageFactory) { requireNonNull(name, "name"); requireNonNull(messageFactory, "messageFactory"); + expungeStaleEntries(); + readLock.lock(); try { final Map> loggerRefByName = @@ -87,6 +127,8 @@ public InternalLoggerRegistry() {} } public Collection getLoggers() { + expungeStaleEntries(); + readLock.lock(); try { // Return a new collection to allow concurrent iteration over the loggers @@ -127,6 +169,8 @@ public boolean hasLogger(final String name, final MessageFactory messageFactory) public boolean hasLogger(final String name, final Class messageFactoryClass) { requireNonNull(name, "name"); requireNonNull(messageFactoryClass, "messageFactoryClass"); + expungeStaleEntries(); + readLock.lock(); try { return loggerRefByNameByMessageFactory.entrySet().stream() @@ -146,6 +190,7 @@ public Logger computeIfAbsent( requireNonNull(name, "name"); requireNonNull(messageFactory, "messageFactory"); requireNonNull(loggerSupplier, "loggerSupplier"); + // Skipping `expungeStaleEntries()`, it will be invoked by the `getLogger()` invocation below // Read lock fast path: See if logger already exists @Nullable Logger logger = getLogger(name, messageFactory); @@ -194,9 +239,10 @@ public Logger computeIfAbsent( if (loggerRefByName == null) { loggerRefByNameByMessageFactory.put(messageFactory, loggerRefByName = new HashMap<>()); } + final WeakReference loggerRef = loggerRefByName.get(name); if (loggerRef == null || (logger = loggerRef.get()) == null) { - loggerRefByName.put(name, new WeakReference<>(logger = newLogger)); + loggerRefByName.put(name, new WeakReference<>(logger = newLogger, staleLoggerRefs)); } return logger; } finally { diff --git a/src/changelog/.2.x.x/3430_InternalLoggerRegistry_stale_entry_expunge.xml b/src/changelog/.2.x.x/3430_InternalLoggerRegistry_stale_entry_expunge.xml new file mode 100644 index 00000000000..75c40209783 --- /dev/null +++ b/src/changelog/.2.x.x/3430_InternalLoggerRegistry_stale_entry_expunge.xml @@ -0,0 +1,13 @@ + + + + + + Improved expunging of stale entries in `InternalLoggerRegistry` to prevent potential memory leaks + +