diff --git a/src/NHibernate/Async/Engine/BatchFetchQueue.cs b/src/NHibernate/Async/Engine/BatchFetchQueue.cs index 1587efd648e..470619da3d5 100644 --- a/src/NHibernate/Async/Engine/BatchFetchQueue.cs +++ b/src/NHibernate/Async/Engine/BatchFetchQueue.cs @@ -15,6 +15,7 @@ using NHibernate.Persister.Entity; using NHibernate.Util; using System.Collections.Generic; +using Iesi.Collections.Generic; namespace NHibernate.Engine { @@ -40,22 +41,33 @@ public async Task GetCollectionBatchAsync(ICollectionPersister collect int end = -1; bool checkForEnd = false; - // this only works because collection entries are kept in a sequenced - // map by persistence context (maybe we should do like entities and - // keep a separate sequences set...) - foreach (DictionaryEntry me in context.CollectionEntries) + if (batchLoadableCollections.TryGetValue(collectionPersister.Role, out var map)) { - CollectionEntry ce = (CollectionEntry) me.Value; - IPersistentCollection collection = (IPersistentCollection) me.Key; - if (!collection.WasInitialized && ce.LoadedPersister == collectionPersister) + foreach (KeyValuePair me in map) { + var ce = me.Key; + var collection = me.Value; + if (ce.LoadedKey == null) + { + // the LoadedKey of the CollectionEntry might be null as it might have been reset to null + // (see for example Collections.ProcessDereferencedCollection() + // and CollectionEntry.AfterAction()) + // though we clear the queue on flush, it seems like a good idea to guard + // against potentially null LoadedKey:s + continue; + } + + if (collection.WasInitialized) + { + log.Warn("Encountered initialized collection in BatchFetchQueue, this should not happen."); + continue; + } + if (checkForEnd && i == end) { return keys; //the first key found after the given key } - //if ( end == -1 && count > batchSize*10 ) return keys; //try out ten batches, max - bool isEqual = collectionPersister.KeyType.IsEqual(id, ce.LoadedKey, collectionPersister.Factory); if (isEqual) @@ -79,6 +91,7 @@ public async Task GetCollectionBatchAsync(ICollectionPersister collect } } } + return keys; //we ran out of keys to try } @@ -92,7 +105,7 @@ public async Task GetCollectionBatchAsync(ICollectionPersister collect /// The maximum number of keys to return /// A cancellation token that can be used to cancel the work /// an array of identifiers, of length batchSize (possibly padded with nulls) - public async Task GetEntityBatchAsync(IEntityPersister persister,object id,int batchSize, CancellationToken cancellationToken) + public async Task GetEntityBatchAsync(IEntityPersister persister, object id, int batchSize, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); object[] ids = new object[batchSize]; @@ -101,9 +114,9 @@ public async Task GetEntityBatchAsync(IEntityPersister persister,objec int end = -1; bool checkForEnd = false; - foreach (EntityKey key in batchLoadableEntityKeys.Keys) + if (batchLoadableEntityKeys.TryGetValue(persister.EntityName, out var set)) { - if (key.EntityName.Equals(persister.EntityName)) + foreach (var key in set) { //TODO: this needn't exclude subclasses... if (checkForEnd && i == end) diff --git a/src/NHibernate/Async/Engine/Loading/CollectionLoadContext.cs b/src/NHibernate/Async/Engine/Loading/CollectionLoadContext.cs index b4f81ed2c74..0ab01e61d22 100644 --- a/src/NHibernate/Async/Engine/Loading/CollectionLoadContext.cs +++ b/src/NHibernate/Async/Engine/Loading/CollectionLoadContext.cs @@ -128,7 +128,9 @@ private async Task EndLoadingCollectionAsync(LoadingCollectionEntry lce, ICollec { log.Debug("ending loading collection [{0}]", lce); } - ISessionImplementor session = LoadContext.PersistenceContext.Session; + + var persistenceContext = LoadContext.PersistenceContext; + var session = persistenceContext.Session; bool statsEnabled = session.Factory.Statistics.IsStatisticsEnabled; var stopWath = new Stopwatch(); @@ -141,17 +143,17 @@ private async Task EndLoadingCollectionAsync(LoadingCollectionEntry lce, ICollec if (persister.CollectionType.HasHolder()) { - LoadContext.PersistenceContext.AddCollectionHolder(lce.Collection); + persistenceContext.AddCollectionHolder(lce.Collection); } - CollectionEntry ce = LoadContext.PersistenceContext.GetCollectionEntry(lce.Collection); + CollectionEntry ce = persistenceContext.GetCollectionEntry(lce.Collection); if (ce == null) { - ce = LoadContext.PersistenceContext.AddInitializedCollection(persister, lce.Collection, lce.Key); + ce = persistenceContext.AddInitializedCollection(persister, lce.Collection, lce.Key); } else { - ce.PostInitialize(lce.Collection); + ce.PostInitialize(lce.Collection, persistenceContext); } bool addToCache = hasNoQueuedAdds && persister.HasCache && diff --git a/src/NHibernate/Async/Event/Default/DefaultInitializeCollectionEventListener.cs b/src/NHibernate/Async/Event/Default/DefaultInitializeCollectionEventListener.cs index d62cac434c8..1f0d706bed1 100644 --- a/src/NHibernate/Async/Event/Default/DefaultInitializeCollectionEventListener.cs +++ b/src/NHibernate/Async/Event/Default/DefaultInitializeCollectionEventListener.cs @@ -127,7 +127,7 @@ private async Task InitializeCollectionFromCacheAsync(object id, ICollecti CollectionCacheEntry cacheEntry = (CollectionCacheEntry)persister.CacheEntryStructure.Destructure(ce, factory); await (cacheEntry.AssembleAsync(collection, persister, persistenceContext.GetCollectionOwner(id, persister), cancellationToken)).ConfigureAwait(false); - persistenceContext.GetCollectionEntry(collection).PostInitialize(collection); + persistenceContext.GetCollectionEntry(collection).PostInitialize(collection, persistenceContext); return true; } } diff --git a/src/NHibernate/Engine/BatchFetchQueue.cs b/src/NHibernate/Engine/BatchFetchQueue.cs index bca90051da5..dd668788379 100644 --- a/src/NHibernate/Engine/BatchFetchQueue.cs +++ b/src/NHibernate/Engine/BatchFetchQueue.cs @@ -5,23 +5,23 @@ using NHibernate.Persister.Entity; using NHibernate.Util; using System.Collections.Generic; +using Iesi.Collections.Generic; namespace NHibernate.Engine { public partial class BatchFetchQueue { - private static readonly object Marker = new object(); + private static readonly INHibernateLogger log = NHibernateLogger.For(typeof(BatchFetchQueue)); /// - /// Defines a sequence of elements that are currently - /// eligible for batch fetching. + /// Used to hold information about the entities that are currently eligible for batch-fetching. Ultimately + /// used by to build entity load batches. /// /// - /// Even though this is a map, we only use the keys. A map was chosen in - /// order to utilize a to maintain sequencing - /// as well as uniqueness. + /// A Map structure is used to segment the keys by entity type since loading can only be done for a particular entity + /// type at a time. /// - private readonly IDictionary batchLoadableEntityKeys = new LinkedHashMap(8); + private readonly IDictionary> batchLoadableEntityKeys = new Dictionary>(8); /// /// A map of subselect-fetch descriptors @@ -30,6 +30,7 @@ public partial class BatchFetchQueue /// private readonly IDictionary subselectsByEntityKey = new Dictionary(8); + private readonly IDictionary> batchLoadableCollections = new Dictionary>(8); /// /// The owning persistence context. /// @@ -50,6 +51,7 @@ public BatchFetchQueue(IPersistenceContext context) public void Clear() { batchLoadableEntityKeys.Clear(); + batchLoadableCollections.Clear(); subselectsByEntityKey.Clear(); } @@ -113,7 +115,12 @@ public void AddBatchLoadableEntityKey(EntityKey key) { if (key.IsBatchLoadable) { - batchLoadableEntityKeys[key] = Marker; + if (!batchLoadableEntityKeys.TryGetValue(key.EntityName, out var set)) + { + set = new LinkedHashSet(); + batchLoadableEntityKeys.Add(key.EntityName, set); + } + set.Add(key); } } @@ -125,7 +132,44 @@ public void AddBatchLoadableEntityKey(EntityKey key) public void RemoveBatchLoadableEntityKey(EntityKey key) { if (key.IsBatchLoadable) - batchLoadableEntityKeys.Remove(key); + { + if (batchLoadableEntityKeys.TryGetValue(key.EntityName, out var set)) + { + set.Remove(key); + } + } + } + + /// + /// If a CollectionEntry represents a batch loadable collection, add + /// it to the queue. + /// + /// + /// + public void AddBatchLoadableCollection(IPersistentCollection collection, CollectionEntry ce) + { + var persister = ce.LoadedPersister; + + if (!batchLoadableCollections.TryGetValue(persister.Role, out var map)) + { + map = new LinkedHashMap(); + batchLoadableCollections.Add(persister.Role, map); + } + map[ce] = collection; + } + + /// + /// After a collection was initialized or evicted, we don't + /// need to batch fetch it anymore, remove it from the queue + /// if necessary + /// + /// + public void RemoveBatchLoadableCollection(CollectionEntry ce) + { + if (batchLoadableCollections.TryGetValue(ce.LoadedPersister.Role, out var map)) + { + map.Remove(ce); + } } /// @@ -143,22 +187,33 @@ public object[] GetCollectionBatch(ICollectionPersister collectionPersister, obj int end = -1; bool checkForEnd = false; - // this only works because collection entries are kept in a sequenced - // map by persistence context (maybe we should do like entities and - // keep a separate sequences set...) - foreach (DictionaryEntry me in context.CollectionEntries) + if (batchLoadableCollections.TryGetValue(collectionPersister.Role, out var map)) { - CollectionEntry ce = (CollectionEntry) me.Value; - IPersistentCollection collection = (IPersistentCollection) me.Key; - if (!collection.WasInitialized && ce.LoadedPersister == collectionPersister) + foreach (KeyValuePair me in map) { + var ce = me.Key; + var collection = me.Value; + if (ce.LoadedKey == null) + { + // the LoadedKey of the CollectionEntry might be null as it might have been reset to null + // (see for example Collections.ProcessDereferencedCollection() + // and CollectionEntry.AfterAction()) + // though we clear the queue on flush, it seems like a good idea to guard + // against potentially null LoadedKey:s + continue; + } + + if (collection.WasInitialized) + { + log.Warn("Encountered initialized collection in BatchFetchQueue, this should not happen."); + continue; + } + if (checkForEnd && i == end) { return keys; //the first key found after the given key } - //if ( end == -1 && count > batchSize*10 ) return keys; //try out ten batches, max - bool isEqual = collectionPersister.KeyType.IsEqual(id, ce.LoadedKey, collectionPersister.Factory); if (isEqual) @@ -182,6 +237,7 @@ public object[] GetCollectionBatch(ICollectionPersister collectionPersister, obj } } } + return keys; //we ran out of keys to try } @@ -194,7 +250,7 @@ public object[] GetCollectionBatch(ICollectionPersister collectionPersister, obj /// The identifier of the entity currently demanding load. /// The maximum number of keys to return /// an array of identifiers, of length batchSize (possibly padded with nulls) - public object[] GetEntityBatch(IEntityPersister persister,object id,int batchSize) + public object[] GetEntityBatch(IEntityPersister persister, object id, int batchSize) { object[] ids = new object[batchSize]; ids[0] = id; //first element of array is reserved for the actual instance we are loading! @@ -202,9 +258,9 @@ public object[] GetEntityBatch(IEntityPersister persister,object id,int batchSiz int end = -1; bool checkForEnd = false; - foreach (EntityKey key in batchLoadableEntityKeys.Keys) + if (batchLoadableEntityKeys.TryGetValue(persister.EntityName, out var set)) { - if (key.EntityName.Equals(persister.EntityName)) + foreach (var key in set) { //TODO: this needn't exclude subclasses... if (checkForEnd && i == end) diff --git a/src/NHibernate/Engine/CollectionEntry.cs b/src/NHibernate/Engine/CollectionEntry.cs index 7d8ee06aa45..131ec2e365d 100644 --- a/src/NHibernate/Engine/CollectionEntry.cs +++ b/src/NHibernate/Engine/CollectionEntry.cs @@ -298,12 +298,32 @@ public void PreFlush(IPersistentCollection collection) /// has been initialized. /// /// The initialized that this Entry is for. + //Since v5.1 + [Obsolete("Please use PostInitialize(collection, persistenceContext) instead.")] public void PostInitialize(IPersistentCollection collection) { snapshot = LoadedPersister.IsMutable ? collection.GetSnapshot(LoadedPersister) : null; collection.SetSnapshot(loadedKey, role, snapshot); } + /// + /// Updates the CollectionEntry to reflect that the + /// has been initialized. + /// + /// The initialized that this Entry is for. + /// + public void PostInitialize(IPersistentCollection collection, IPersistenceContext persistenceContext) + { +#pragma warning disable 618 + //6.0 TODO: Inline PostInitialize here. + PostInitialize(collection); +#pragma warning restore 618 + if (LoadedPersister.GetBatchSize() > 1) + { + persistenceContext.BatchFetchQueue.RemoveBatchLoadableCollection(this); + } + } + /// /// Updates the CollectionEntry to reflect that it is has been successfully flushed to the database. /// diff --git a/src/NHibernate/Engine/Loading/CollectionLoadContext.cs b/src/NHibernate/Engine/Loading/CollectionLoadContext.cs index 9d87fc4d044..9dd97e4351e 100644 --- a/src/NHibernate/Engine/Loading/CollectionLoadContext.cs +++ b/src/NHibernate/Engine/Loading/CollectionLoadContext.cs @@ -234,7 +234,9 @@ private void EndLoadingCollection(LoadingCollectionEntry lce, ICollectionPersist { log.Debug("ending loading collection [{0}]", lce); } - ISessionImplementor session = LoadContext.PersistenceContext.Session; + + var persistenceContext = LoadContext.PersistenceContext; + var session = persistenceContext.Session; bool statsEnabled = session.Factory.Statistics.IsStatisticsEnabled; var stopWath = new Stopwatch(); @@ -247,17 +249,17 @@ private void EndLoadingCollection(LoadingCollectionEntry lce, ICollectionPersist if (persister.CollectionType.HasHolder()) { - LoadContext.PersistenceContext.AddCollectionHolder(lce.Collection); + persistenceContext.AddCollectionHolder(lce.Collection); } - CollectionEntry ce = LoadContext.PersistenceContext.GetCollectionEntry(lce.Collection); + CollectionEntry ce = persistenceContext.GetCollectionEntry(lce.Collection); if (ce == null) { - ce = LoadContext.PersistenceContext.AddInitializedCollection(persister, lce.Collection, lce.Key); + ce = persistenceContext.AddInitializedCollection(persister, lce.Collection, lce.Key); } else { - ce.PostInitialize(lce.Collection); + ce.PostInitialize(lce.Collection, persistenceContext); } bool addToCache = hasNoQueuedAdds && persister.HasCache && diff --git a/src/NHibernate/Engine/StatefulPersistenceContext.cs b/src/NHibernate/Engine/StatefulPersistenceContext.cs index cdb75ca6a7d..709e8a0de01 100644 --- a/src/NHibernate/Engine/StatefulPersistenceContext.cs +++ b/src/NHibernate/Engine/StatefulPersistenceContext.cs @@ -835,6 +835,10 @@ public void AddUninitializedCollection(ICollectionPersister persister, IPersiste { CollectionEntry ce = new CollectionEntry(collection, persister, id, flushing); AddCollection(collection, ce, id); + if (persister.GetBatchSize() > 1) + { + batchFetchQueue.AddBatchLoadableCollection(collection, ce); + } } /// add a detached uninitialized collection @@ -913,7 +917,7 @@ public CollectionEntry AddInitializedCollection(ICollectionPersister persister, object id) { CollectionEntry ce = new CollectionEntry(collection, persister, id, flushing); - ce.PostInitialize(collection); + ce.PostInitialize(collection, this); AddCollection(collection, ce, id); return ce; } diff --git a/src/NHibernate/Event/Default/DefaultInitializeCollectionEventListener.cs b/src/NHibernate/Event/Default/DefaultInitializeCollectionEventListener.cs index c2d65277a60..2f37f28d1ea 100644 --- a/src/NHibernate/Event/Default/DefaultInitializeCollectionEventListener.cs +++ b/src/NHibernate/Event/Default/DefaultInitializeCollectionEventListener.cs @@ -115,7 +115,7 @@ private bool InitializeCollectionFromCache(object id, ICollectionPersister persi CollectionCacheEntry cacheEntry = (CollectionCacheEntry)persister.CacheEntryStructure.Destructure(ce, factory); cacheEntry.Assemble(collection, persister, persistenceContext.GetCollectionOwner(id, persister)); - persistenceContext.GetCollectionEntry(collection).PostInitialize(collection); + persistenceContext.GetCollectionEntry(collection).PostInitialize(collection, persistenceContext); return true; } } diff --git a/src/NHibernate/Event/Default/EvictVisitor.cs b/src/NHibernate/Event/Default/EvictVisitor.cs index 82501c24d07..ce4a3c89a95 100644 --- a/src/NHibernate/Event/Default/EvictVisitor.cs +++ b/src/NHibernate/Event/Default/EvictVisitor.cs @@ -2,6 +2,7 @@ using NHibernate.Collection; using NHibernate.Engine; using NHibernate.Impl; +using NHibernate.Persister.Collection; using NHibernate.Type; namespace NHibernate.Event.Default @@ -53,6 +54,10 @@ private void EvictCollection(IPersistentCollection collection) Session.PersistenceContext.CollectionEntries.Remove(collection); if (log.IsDebugEnabled()) log.Debug("evicting collection: {0}", MessageHelper.CollectionInfoString(ce.LoadedPersister, collection, ce.LoadedKey, Session)); + if (ce.LoadedPersister?.GetBatchSize() > 1) + { + Session.PersistenceContext.BatchFetchQueue.RemoveBatchLoadableCollection(ce); + } if (ce.LoadedPersister != null && ce.LoadedKey != null) { //TODO: is this 100% correct? diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 27da5e4bc11..2611e99474d 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -2043,6 +2043,14 @@ public string[] RootTableKeyColumnNames get { return new string[] {IdentifierColumnName}; } } + /// + /// Get the batch size of a collection persister. + /// + public int GetBatchSize() + { + return batchSize; + } + public SqlString GetSelectByUniqueKeyString(string propertyName) { return diff --git a/src/NHibernate/Persister/Collection/ICollectionPersister.cs b/src/NHibernate/Persister/Collection/ICollectionPersister.cs index dc75daf4626..4d87ea158b3 100644 --- a/src/NHibernate/Persister/Collection/ICollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/ICollectionPersister.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Data.Common; using NHibernate.Cache; @@ -283,4 +284,25 @@ public partial interface ICollectionPersister /// object NotFoundObject { get; } } + + public static class CollectionPersisterExtensions + { + /// + /// Get the batch size of a collection persister. + /// + //6.0 TODO: Merge into ICollectionPersister and convert to a property. + public static int GetBatchSize(this ICollectionPersister persister) + { + if (persister is AbstractCollectionPersister acp) + { + return acp.GetBatchSize(); + } + + NHibernateLogger + .For(typeof(CollectionPersisterExtensions)) + .Warn("Collection persister of {0} type is not supported, returning 1 as a batch size.", persister?.GetType()); + + return 1; + } + } } diff --git a/src/NHibernate/Type/ManyToOneType.cs b/src/NHibernate/Type/ManyToOneType.cs index 5f130dc8c35..4ac52cd54f2 100644 --- a/src/NHibernate/Type/ManyToOneType.cs +++ b/src/NHibernate/Type/ManyToOneType.cs @@ -101,7 +101,7 @@ private void ScheduleBatchLoadIfNeeded(object id, ISessionImplementor session) { IEntityPersister persister = session.Factory.GetEntityPersister(GetAssociatedEntityName()); EntityKey entityKey = session.GenerateEntityKey(id, persister); - if (!session.PersistenceContext.ContainsEntity(entityKey)) + if (entityKey.IsBatchLoadable && !session.PersistenceContext.ContainsEntity(entityKey)) { session.PersistenceContext.BatchFetchQueue.AddBatchLoadableEntityKey(entityKey); }