From 2f0ba6f8fcf9fe925463dc36270bab4879c4d923 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= Date: Fri, 22 Sep 2017 17:53:30 +0200 Subject: [PATCH 1/2] Clean-up of TypeFactory * Removing undue global thread sync and concurrency checks, the concurrent dictionary handles already that. * Removing useless singleton pattern, likely a remnant of Java port. --- src/NHibernate/SqlTypes/SqlTypeFactory.cs | 2 - src/NHibernate/Type/TypeFactory.cs | 322 ++++++---------------- 2 files changed, 84 insertions(+), 240 deletions(-) diff --git a/src/NHibernate/SqlTypes/SqlTypeFactory.cs b/src/NHibernate/SqlTypes/SqlTypeFactory.cs index 23983c57a53..3f12c5224e2 100644 --- a/src/NHibernate/SqlTypes/SqlTypeFactory.cs +++ b/src/NHibernate/SqlTypes/SqlTypeFactory.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Concurrent; using System.Data; -using System.Runtime.CompilerServices; namespace NHibernate.SqlTypes { @@ -99,7 +98,6 @@ public static TimeSqlType GetTime(byte fractionalSecondsPrecision) return GetTypeWithLenOrScale(fractionalSecondsPrecision, l => new TimeSqlType(l)); } - [MethodImpl(MethodImplOptions.Synchronized)] public static SqlType GetSqlType(DbType dbType, byte precision, byte scale) { return GetTypeWithPrecision(dbType, precision, scale); diff --git a/src/NHibernate/Type/TypeFactory.cs b/src/NHibernate/Type/TypeFactory.cs index da1166f8dd2..9d14ce6812f 100644 --- a/src/NHibernate/Type/TypeFactory.cs +++ b/src/NHibernate/Type/TypeFactory.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Globalization; using System.Reflection; -using System.Runtime.CompilerServices; using System.Xml; using System.Xml.Linq; using NHibernate.Bytecode; @@ -11,6 +10,7 @@ using NHibernate.SqlTypes; using NHibernate.UserTypes; using NHibernate.Util; +using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Type { @@ -24,7 +24,7 @@ namespace NHibernate.Type /// correct IType. Instead use TypeFactory.GetString(300) and keep a local variable that holds /// a reference to the IType. /// - public sealed class TypeFactory + public static class TypeFactory { private enum TypeClassification { @@ -35,9 +35,8 @@ private enum TypeClassification private static readonly INHibernateLogger _log = NHibernateLogger.For(typeof(TypeFactory)); private static readonly string[] EmptyAliases= System.Array.Empty(); - private static readonly char[] PrecisionScaleSplit = new[] { '(', ')', ',' }; - private static readonly char[] LengthSplit = new[] { '(', ')' }; - private static readonly TypeFactory Instance; + private static readonly char[] PrecisionScaleSplit = { '(', ')', ',' }; + private static readonly char[] LengthSplit = { '(', ')' }; private static readonly MethodInfo BagDefinition = ReflectHelper.GetMethodDefinition( f => f.Bag(null, null)); @@ -124,11 +123,12 @@ private static void RegisterType(System.Type systemType, IType nhibernateType, private static IEnumerable GetClrTypeAliases(System.Type systemType) { - var typeAliases = new List - { - systemType.FullName, - systemType.AssemblyQualifiedName, - }; + var typeAliases = + new List + { + systemType.FullName, + systemType.AssemblyQualifiedName + }; if (systemType.IsValueType) { // Also register Nullable for ValueTypes @@ -194,8 +194,6 @@ private static void RegisterTypeAlias(IType nhibernateType, string alias) /// static TypeFactory() { - Instance = new TypeFactory(); - // set up the mappings of .NET Classes/Structs to their NHibernate types. RegisterDefaultNetTypes(); @@ -211,7 +209,7 @@ static TypeFactory() /// private static void RegisterDefaultNetTypes() { - // NOTE : each .NET type mut appear only one time + // NOTE: each .NET type should appear only one time RegisterType(typeof (Byte[]), NHibernateUtil.Binary, new[] {"binary"}, l => GetType(NHibernateUtil.Binary, l, len => new BinaryType(SqlTypeFactory.GetBinary(len)))); @@ -315,14 +313,8 @@ private static void RegisterBuiltInTypes() len => new SerializableType(typeof (object), SqlTypeFactory.GetBinary(len)))); } - public ICollectionTypeFactory CollectionTypeFactory - { - get { return Cfg.Environment.BytecodeProvider.CollectionTypeFactory; } - } - - private TypeFactory() - { - } + private static ICollectionTypeFactory CollectionTypeFactory => + Environment.BytecodeProvider.CollectionTypeFactory; /// /// Gets the classification of the Type based on the string. @@ -349,11 +341,11 @@ private TypeFactory() /// private static TypeClassification GetTypeClassification(string typeName) { - int indexOfOpenParen = typeName.IndexOf("("); - int indexOfComma = 0; + var indexOfOpenParen = typeName.IndexOf("(", StringComparison.Ordinal); + var indexOfComma = 0; if (indexOfOpenParen >= 0) { - indexOfComma = typeName.IndexOf(",", indexOfOpenParen); + indexOfComma = typeName.IndexOf(",", indexOfOpenParen, StringComparison.Ordinal); } if (indexOfOpenParen >= 0) @@ -409,8 +401,8 @@ public static IType Basic(string name) string[] parsedName = name.Split(PrecisionScaleSplit); if (parsedName.Length < 4) { - throw new ArgumentOutOfRangeException("TypeClassification.PrecisionScale", name, - "It is not a valid Precision/Scale name"); + throw new ArgumentOutOfRangeException( + "TypeClassification.PrecisionScale", name, "It is not a valid Precision/Scale name"); } typeName = parsedName[0].Trim(); @@ -426,7 +418,8 @@ public static IType Basic(string name) string[] parsedName = name.Split(LengthSplit); if (parsedName.Length < 3) { - throw new ArgumentOutOfRangeException("TypeClassification.LengthOrScale", name, "It is not a valid Length or Scale name"); + throw new ArgumentOutOfRangeException( + "TypeClassification.LengthOrScale", name, "It is not a valid Length or Scale name"); } typeName = parsedName[0].Trim(); @@ -460,34 +453,6 @@ internal static IType BuiltInType(string typeName, byte precision, byte scale) : precisionDelegate(precision, scale); } - private static void AddToTypeOfName(string key, IType type) - { - if (!typeByTypeOfName.TryAdd(key, type)) - { - throw new HibernateException("An item with the same key has already been added to typeByTypeOfName."); - } - if (!typeByTypeOfName.TryAdd(type.Name, type)) - { - throw new HibernateException("An item with the same key has already been added to typeByTypeOfName."); - } - } - - private static void AddToTypeOfNameWithLength(string key, IType type) - { - if (!typeByTypeOfName.TryAdd(key, type)) - { - throw new HibernateException("An item with the same key has already been added to typeByTypeOfName."); - } - } - - private static void AddToTypeOfNameWithPrecision(string key, IType type) - { - if (!typeByTypeOfName.TryAdd(key, type)) - { - throw new HibernateException("An item with the same key has already been added to typeByTypeOfName."); - } - } - private static string GetKeyForLengthOrScaleBased(string name, int lengthOrScale) { return name + "(" + lengthOrScale + ")"; @@ -567,7 +532,7 @@ public static IType HeuristicType(string typeName, IDictionary p { try { - type = (IType) Cfg.Environment.BytecodeProvider.ObjectsFactory.CreateInstance(typeClass); + type = (IType) Environment.BytecodeProvider.ObjectsFactory.CreateInstance(typeClass); } catch (Exception e) { @@ -623,18 +588,10 @@ public static IType GetDefaultTypeFor(System.Type type) return typeByTypeOfName.TryGetValue(type.FullName, out var nhType) ? nhType : null; } - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetAnsiStringType(int length) { - string key = GetKeyForLengthOrScaleBased(NHibernateUtil.AnsiString.Name, length); - - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = new AnsiStringType(SqlTypeFactory.GetAnsiString(length)); - AddToTypeOfNameWithLength(key, returnType); - } - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(NHibernateUtil.AnsiString.Name, length); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new AnsiStringType(SqlTypeFactory.GetAnsiString(length))); } /// @@ -647,7 +604,6 @@ public static NullableType GetAnsiStringType(int length) /// been added to the basicNameMap with the keys Byte[](length) and /// NHibernate.Type.BinaryType(length). /// - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetBinaryType(int length) { //HACK: don't understand why SerializableType calls this with length=0 @@ -656,43 +612,20 @@ public static NullableType GetBinaryType(int length) return NHibernateUtil.Binary; } - string key = GetKeyForLengthOrScaleBased(NHibernateUtil.Binary.Name, length); - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = new BinaryType(SqlTypeFactory.GetBinary(length)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(NHibernateUtil.Binary.Name, length); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new BinaryType(SqlTypeFactory.GetBinary(length))); } - [MethodImpl(MethodImplOptions.Synchronized)] private static NullableType GetType(NullableType defaultUnqualifiedType, int lengthOrScale, GetNullableTypeWithLengthOrScale ctorDelegate) { - string key = GetKeyForLengthOrScaleBased(defaultUnqualifiedType.Name, lengthOrScale); - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = ctorDelegate(lengthOrScale); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(defaultUnqualifiedType.Name, lengthOrScale); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => ctorDelegate(lengthOrScale)); } - [MethodImpl(MethodImplOptions.Synchronized)] private static NullableType GetType(NullableType defaultUnqualifiedType, byte precision, byte scale, NullableTypeCreatorDelegate ctor) { - string key = GetKeyForPrecisionScaleBased(defaultUnqualifiedType.Name, precision, scale); - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = ctor(SqlTypeFactory.GetSqlType(defaultUnqualifiedType.SqlType.DbType, precision, scale)); - AddToTypeOfNameWithPrecision(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForPrecisionScaleBased(defaultUnqualifiedType.Name, precision, scale); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => ctor(SqlTypeFactory.GetSqlType(defaultUnqualifiedType.SqlType.DbType, precision, scale))); } /// @@ -708,85 +641,53 @@ private static NullableType GetType(NullableType defaultUnqualifiedType, byte pr /// from the other items put in the basicNameMap because it is uses the AQN and the /// FQN as opposed to the short name used in the maps and the FQN. /// - /// - /// Since this method calls the method - /// GetSerializableType(System.Type, Int32) - /// with the default length, those keys will also be added. - /// /// - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetSerializableType(System.Type serializableType) { - string key = serializableType.AssemblyQualifiedName; + var key = serializableType.AssemblyQualifiedName; - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) + // The value factory may be run concurrently, but only one resulting value will be yielded to all threads. + // So we should add the type with its other key in a later operation in order to ensure we cache the same + // instance for both keys. + var added = false; + var type = (NullableType)typeByTypeOfName.GetOrAdd( + key, + k => + { + var returnType = new SerializableType(serializableType); + added = true; + return returnType; + }); + if (added && typeByTypeOfName.GetOrAdd(type.Name, type) != type) { - returnType = new SerializableType(serializableType); - AddToTypeOfName(key, returnType); + throw new HibernateException($"Another item with the key {type.Name} has already been added to typeByTypeOfName."); } - return (NullableType)returnType; + return type; } - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetSerializableType(System.Type serializableType, int length) { - string key = GetKeyForLengthOrScaleBased(serializableType.AssemblyQualifiedName, length); - - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = new SerializableType(serializableType, SqlTypeFactory.GetBinary(length)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(serializableType.AssemblyQualifiedName, length); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new SerializableType(serializableType, SqlTypeFactory.GetBinary(length))); } - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetSerializableType(int length) { - string key = GetKeyForLengthOrScaleBased(NHibernateUtil.Serializable.Name, length); - - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = new SerializableType(typeof(object), SqlTypeFactory.GetBinary(length)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(NHibernateUtil.Serializable.Name, length); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new SerializableType(typeof(object), SqlTypeFactory.GetBinary(length))); } - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetStringType(int length) { - string key = GetKeyForLengthOrScaleBased(NHibernateUtil.String.Name, length); - - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = new StringType(SqlTypeFactory.GetString(length)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(NHibernateUtil.String.Name, length); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new StringType(SqlTypeFactory.GetString(length))); } - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetTypeType(int length) { - string key = GetKeyForLengthOrScaleBased(typeof(TypeType).FullName, length); - - IType returnType; - if (!typeByTypeOfName.TryGetValue(key, out returnType)) - { - returnType = new TypeType(SqlTypeFactory.GetString(length)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + var key = GetKeyForLengthOrScaleBased(typeof(TypeType).FullName, length); + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new TypeType(SqlTypeFactory.GetString(length))); } /// @@ -794,18 +695,10 @@ public static NullableType GetTypeType(int length) /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetDateTimeType(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.DateTime.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new DateTimeType(SqlTypeFactory.GetDateTime(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new DateTimeType(SqlTypeFactory.GetDateTime(fractionalSecondsPrecision))); } /// @@ -813,19 +706,12 @@ public static NullableType GetDateTimeType(byte fractionalSecondsPrecision) /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] + // Since v5.0 [Obsolete("Use GetDateTimeType instead, it uses DateTime2 with dialects supporting it.")] public static NullableType GetDateTime2Type(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.DateTime2.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new DateTime2Type(SqlTypeFactory.GetDateTime2(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new DateTime2Type(SqlTypeFactory.GetDateTime2(fractionalSecondsPrecision))); } /// @@ -833,18 +719,10 @@ public static NullableType GetDateTime2Type(byte fractionalSecondsPrecision) /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetLocalDateTimeType(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.LocalDateTime.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new LocalDateTimeType(SqlTypeFactory.GetDateTime(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new LocalDateTimeType(SqlTypeFactory.GetDateTime(fractionalSecondsPrecision))); } /// @@ -852,18 +730,10 @@ public static NullableType GetLocalDateTimeType(byte fractionalSecondsPrecision) /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetUtcDateTimeType(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.UtcDateTime.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new UtcDateTimeType(SqlTypeFactory.GetDateTime(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new UtcDateTimeType(SqlTypeFactory.GetDateTime(fractionalSecondsPrecision))); } /// @@ -871,18 +741,10 @@ public static NullableType GetUtcDateTimeType(byte fractionalSecondsPrecision) /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetDateTimeOffsetType(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.DateTimeOffset.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new DateTimeOffsetType(SqlTypeFactory.GetDateTimeOffset(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new DateTimeOffsetType(SqlTypeFactory.GetDateTimeOffset(fractionalSecondsPrecision))); } /// @@ -890,18 +752,10 @@ public static NullableType GetDateTimeOffsetType(byte fractionalSecondsPrecision /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetTimeAsTimeSpanType(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.TimeAsTimeSpan.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new TimeAsTimeSpanType(SqlTypeFactory.GetTime(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new TimeAsTimeSpanType(SqlTypeFactory.GetTime(fractionalSecondsPrecision))); } /// @@ -909,18 +763,10 @@ public static NullableType GetTimeAsTimeSpanType(byte fractionalSecondsPrecision /// /// The fractional seconds precision. /// The NHibernate type. - [MethodImpl(MethodImplOptions.Synchronized)] public static NullableType GetTimeType(byte fractionalSecondsPrecision) { var key = GetKeyForLengthOrScaleBased(NHibernateUtil.Time.Name, fractionalSecondsPrecision); - - if (!typeByTypeOfName.TryGetValue(key, out var returnType)) - { - returnType = new TimeType(SqlTypeFactory.GetTime(fractionalSecondsPrecision)); - AddToTypeOfNameWithLength(key, returnType); - } - - return (NullableType)returnType; + return (NullableType)typeByTypeOfName.GetOrAdd(key, k => new TimeType(SqlTypeFactory.GetTime(fractionalSecondsPrecision))); } // Association Types @@ -932,8 +778,8 @@ public static EntityType OneToOne(string persistentClass, ForeignKeyDirection fo bool lazy, bool unwrapProxy, string entityName, string propertyName) { return - new OneToOneType(persistentClass, foreignKeyType, uniqueKeyPropertyName, lazy, unwrapProxy, - entityName, propertyName); + new OneToOneType( + persistentClass, foreignKeyType, uniqueKeyPropertyName, lazy, unwrapProxy, entityName, propertyName); } /// @@ -964,71 +810,71 @@ public static EntityType ManyToOne(string persistentClass, string uniqueKeyPrope public static CollectionType Array(string role, string propertyRef, System.Type elementClass) { - return Instance.CollectionTypeFactory.Array(role, propertyRef, elementClass); + return CollectionTypeFactory.Array(role, propertyRef, elementClass); } public static CollectionType GenericBag(string role, string propertyRef, System.Type elementClass) { - MethodInfo mi = BagDefinition.MakeGenericMethod(new[] { elementClass }); + MethodInfo mi = BagDefinition.MakeGenericMethod(elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new object[] { role, propertyRef }); } public static CollectionType GenericIdBag(string role, string propertyRef, System.Type elementClass) { - MethodInfo mi = IdBagDefinition.MakeGenericMethod(new[] { elementClass }); + MethodInfo mi = IdBagDefinition.MakeGenericMethod(elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new object[] { role, propertyRef }); } public static CollectionType GenericList(string role, string propertyRef, System.Type elementClass) { - MethodInfo mi = ListDefinition.MakeGenericMethod(new[] { elementClass }); + MethodInfo mi = ListDefinition.MakeGenericMethod(elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new object[] { role, propertyRef }); } public static CollectionType GenericMap(string role, string propertyRef, System.Type indexClass, System.Type elementClass) { - MethodInfo mi = MapDefinition.MakeGenericMethod(new[] { indexClass, elementClass }); + MethodInfo mi = MapDefinition.MakeGenericMethod(indexClass, elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new object[] { role, propertyRef }); } public static CollectionType GenericSortedList(string role, string propertyRef, object comparer, System.Type indexClass, System.Type elementClass) { - MethodInfo mi = SortedListDefinition.MakeGenericMethod(new[] { indexClass, elementClass }); + MethodInfo mi = SortedListDefinition.MakeGenericMethod(indexClass, elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef, comparer }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new[] { role, propertyRef, comparer }); } public static CollectionType GenericSortedDictionary(string role, string propertyRef, object comparer, System.Type indexClass, System.Type elementClass) { - MethodInfo mi = SortedDictionaryDefinition.MakeGenericMethod(new[] { indexClass, elementClass }); + MethodInfo mi = SortedDictionaryDefinition.MakeGenericMethod(indexClass, elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef, comparer }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new[] { role, propertyRef, comparer }); } public static CollectionType GenericSet(string role, string propertyRef, System.Type elementClass) { - MethodInfo mi = SetDefinition.MakeGenericMethod(new[] { elementClass }); + MethodInfo mi = SetDefinition.MakeGenericMethod(elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new object[] { role, propertyRef }); } public static CollectionType GenericSortedSet(string role, string propertyRef, object comparer, System.Type elementClass) { - MethodInfo mi = SortedSetDefinition.MakeGenericMethod(new[] { elementClass }); + MethodInfo mi = SortedSetDefinition.MakeGenericMethod(elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef, comparer }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new[] { role, propertyRef, comparer }); } public static CollectionType GenericOrderedSet(string role, string propertyRef, System.Type elementClass) { - MethodInfo mi = OrderedSetDefinition.MakeGenericMethod(new[] { elementClass }); + MethodInfo mi = OrderedSetDefinition.MakeGenericMethod(elementClass); - return (CollectionType)mi.Invoke(Instance.CollectionTypeFactory, new object[] { role, propertyRef }); + return (CollectionType)mi.Invoke(CollectionTypeFactory, new object[] { role, propertyRef }); } public static CollectionType CustomCollection(string typeName, IDictionary typeParameters, string role, string propertyRef) @@ -1052,9 +898,9 @@ public static CollectionType CustomCollection(string typeName, IDictionary parameters) { - if (type is IParameterizedType) + if (type is IParameterizedType parameterizedType) { - ((IParameterizedType) type).SetParameterValues(parameters); + parameterizedType.SetParameterValues(parameters); } else if (parameters != null && parameters.Count != 0) { From e904ab71d2dd70b58169a26475aa3bc847fc9448 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= Date: Sat, 30 Sep 2017 16:46:48 +0200 Subject: [PATCH 2/2] Enabling thread safety test of type factory, to be squashed. --- .../DistributedSystemTransactionFixture.cs | 2 +- src/NHibernate.Test/MultiThreadRunner.cs | 130 +++++++++++------- .../DistributedSystemTransactionFixture.cs | 47 ++----- .../TypesTest/TypeFactoryFixture.cs | 46 +++---- .../ThreadSafeDictionaryFixture.cs | 103 -------------- 5 files changed, 115 insertions(+), 213 deletions(-) delete mode 100644 src/NHibernate.Test/UtilityTest/ThreadSafeDictionaryFixture.cs diff --git a/src/NHibernate.Test/Async/SystemTransactions/DistributedSystemTransactionFixture.cs b/src/NHibernate.Test/Async/SystemTransactions/DistributedSystemTransactionFixture.cs index ca7c64e5a9d..d8cafe793a3 100644 --- a/src/NHibernate.Test/Async/SystemTransactions/DistributedSystemTransactionFixture.cs +++ b/src/NHibernate.Test/Async/SystemTransactions/DistributedSystemTransactionFixture.cs @@ -780,4 +780,4 @@ protected override void Configure(Configuration configuration) protected override bool AppliesTo(ISessionFactoryImplementor factory) => base.AppliesTo(factory) && factory.ConnectionProvider.Driver.SupportsEnlistmentWhenAutoEnlistmentIsDisabled; } -} \ No newline at end of file +} diff --git a/src/NHibernate.Test/MultiThreadRunner.cs b/src/NHibernate.Test/MultiThreadRunner.cs index 0132d8a8885..3145b3c9515 100644 --- a/src/NHibernate.Test/MultiThreadRunner.cs +++ b/src/NHibernate.Test/MultiThreadRunner.cs @@ -1,4 +1,7 @@ using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; using System.Threading; namespace NHibernate.Test @@ -6,76 +9,109 @@ namespace NHibernate.Test public class MultiThreadRunner { public delegate void ExecuteAction(T subject); - private readonly int numThreads; - private readonly ExecuteAction[] actions; - private readonly Random rnd = new Random(); - private bool running; - private int timeout = 1000; - private int timeoutBetweenThreadStart = 30; - public MultiThreadRunner(int numThreads, ExecuteAction[] actions) + private readonly int _numThreads; + private readonly ExecuteAction[] _actions; + private readonly Random _rnd = new Random(); + private volatile bool _running; + private ConcurrentQueue _errors = new ConcurrentQueue(); + + public MultiThreadRunner(int numThreads, params ExecuteAction[] actions) { - if(numThreads < 1) + if (numThreads < 1) { - throw new ArgumentOutOfRangeException("numThreads",numThreads,"Must be GT 1"); + throw new ArgumentOutOfRangeException(nameof(numThreads), numThreads, "Must be GTE 1"); } if (actions == null || actions.Length == 0) { - throw new ArgumentNullException("actions"); + throw new ArgumentNullException(nameof(actions)); } - foreach (ExecuteAction action in actions) + if (actions.Any(action => action == null)) { - if(action==null) - throw new ArgumentNullException("actions", "null delegate"); + throw new ArgumentNullException(nameof(actions), "null delegate"); } - this.numThreads = numThreads; - this.actions = actions; + _numThreads = numThreads; + _actions = actions; } - public int EndTimeout - { - get { return timeout; } - set { timeout = value; } - } + public int EndTimeout { get; set; } = 1000; - public int TimeoutBetweenThreadStart - { - get { return timeoutBetweenThreadStart; } - set { timeoutBetweenThreadStart = value; } - } + public int TimeoutBetweenThreadStart { get; set; } = 30; - public void Run(T subjectInstance) - { - running = true; - Thread[] t = new Thread[numThreads]; - for (int i = 0; i < numThreads; i++) - { - t[i] = new Thread(ThreadProc); - t[i].Name = i.ToString(); - t[i].Start(subjectInstance); - if (i > 2) - Thread.Sleep(timeoutBetweenThreadStart); - } + public Exception[] GetErrors() => _errors.ToArray(); + public void ClearErrors() => _errors = new ConcurrentQueue(); - Thread.Sleep(timeout); + public int Run(T subjectInstance) + { + var allThreads = new List>(); - // Tell the threads to shut down, then wait until they all - // finish. - running = false; - for (int i = 0; i < numThreads; i++) + var launcher = new Thread( + () => + { + try + { + for (var i = 0; i < _numThreads; i++) + { + var threadHolder = new ThreadHolder + { + Thread = new Thread(ThreadProc) { Name = i.ToString() }, + Subject = subjectInstance + }; + threadHolder.Thread.Start(threadHolder); + allThreads.Add(threadHolder); + if (i > 2 && TimeoutBetweenThreadStart > 0) + Thread.Sleep(TimeoutBetweenThreadStart); + } + } + catch (Exception e) + { + _errors.Enqueue(e); + throw; + } + }); + var totalLoops = 0; + _running = true; + // Use a separated thread for launching in case too many threads are asked: the inner Start will freeze + // but would be able to resume once _running would have been set to false, causing first threads to stop. + launcher.Start(); + // Sleep for the required timeout, taking into account the start delay (if all threads are launchable without + // having to wait due to thread starvation). + Thread.Sleep(TimeoutBetweenThreadStart * _numThreads + EndTimeout); + // Tell the threads to shut down, then wait until they all finish. + _running = false; + launcher.Join(); + foreach (var threadHolder in allThreads.Where(t => t != null)) { - t[i].Join(); + threadHolder.Thread.Join(); + totalLoops += threadHolder.LoopsDone; } + return totalLoops; } private void ThreadProc(object arg) { - T subjectInstance = (T) arg; - while (running) + try + { + var holder = (ThreadHolder) arg; + while (_running) + { + var actionIdx = _rnd.Next(0, _actions.Length); + _actions[actionIdx](holder.Subject); + holder.LoopsDone++; + } + } + catch (Exception e) { - int actionIdx = rnd.Next(0, actions.Length); - actions[actionIdx](subjectInstance); + _errors.Enqueue(e); + throw; } } + + private class ThreadHolder + { + public Thread Thread { get; set; } + public int LoopsDone { get; set; } + public TH Subject { get; set; } + } } } diff --git a/src/NHibernate.Test/SystemTransactions/DistributedSystemTransactionFixture.cs b/src/NHibernate.Test/SystemTransactions/DistributedSystemTransactionFixture.cs index b9b8e089f0e..78e23702e7b 100644 --- a/src/NHibernate.Test/SystemTransactions/DistributedSystemTransactionFixture.cs +++ b/src/NHibernate.Test/SystemTransactions/DistributedSystemTransactionFixture.cs @@ -345,44 +345,27 @@ public void TransactionInsertLoadWithRollBackTask(bool explicitFlush) } } - private int _totalCall; - - [Test, Explicit("Test added for NH-1709 (trying to recreate the issue... without luck). If one thread break the test, you can see the result in the console.")] + [Test, Explicit("Test added for NH-1709 (trying to recreate the issue... without luck).")] public void MultiThreadedTransaction() { // Test added for NH-1709 (trying to recreate the issue... without luck) - // If one thread break the test, you can see the result in the console. - ((Logger)_log.Logger).Level = log4net.Core.Level.Debug; - var actions = new MultiThreadRunner.ExecuteAction[] - { - delegate - { - CanRollbackTransaction(false); - _totalCall++; - }, - delegate - { - RollbackOutsideNh(false); - _totalCall++; - }, - delegate - { - TransactionInsertWithRollBackTask(false); - _totalCall++; - }, - delegate - { - TransactionInsertLoadWithRollBackTask(false); - _totalCall++; - }, - }; - var mtr = new MultiThreadRunner(20, actions) + var mtr = new MultiThreadRunner( + 20, + o => CanRollbackTransaction(false), + o => RollbackOutsideNh(false), + o => TransactionInsertWithRollBackTask(false), + o => TransactionInsertLoadWithRollBackTask(false)) { EndTimeout = 5000, TimeoutBetweenThreadStart = 5 }; - mtr.Run(null); - _log.DebugFormat("{0} calls", _totalCall); + var totalCalls = mtr.Run(null); + _log.DebugFormat("{0} calls", totalCalls); + var errors = mtr.GetErrors(); + if (errors.Length > 0) + { + Assert.Fail("One or more thread failed, found {0} errors. First exception: {1}", errors.Length, errors[0]); + } } [Theory] @@ -820,4 +803,4 @@ public void SessionIsNotEnlisted() } } } -} \ No newline at end of file +} diff --git a/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs b/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs index 5b233f74d5c..33890ae7585 100644 --- a/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs +++ b/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs @@ -73,42 +73,28 @@ public long? GenericInt64 } private readonly Random rnd = new Random(); - private int totalCall; - [Test, Explicit] + [Test] public void MultiThreadAccess() { // Test added for NH-1251 - // If one thread break the test you can see the result in the console. - ((Logger) log.Logger).Level = log4net.Core.Level.Debug; - MultiThreadRunner.ExecuteAction[] actions = new MultiThreadRunner.ExecuteAction[] + var mtr = new MultiThreadRunner( + 100, + o => TypeFactory.GetStringType(rnd.Next(1, 50)), + o => TypeFactory.GetBinaryType(rnd.Next(1, 50)), + o => TypeFactory.GetSerializableType(rnd.Next(1, 50)), + o => TypeFactory.GetTypeType(rnd.Next(1, 20))) { - delegate(object o) - { - TypeFactory.GetStringType(rnd.Next(1, 50)); - totalCall++; - }, - delegate(object o) - { - TypeFactory.GetBinaryType(rnd.Next(1, 50)); - totalCall++; - }, - delegate(object o) - { - TypeFactory.GetSerializableType(rnd.Next(1, 50)); - totalCall++; - }, - delegate(object o) - { - TypeFactory.GetTypeType(rnd.Next(1, 20)); - totalCall++; - }, + EndTimeout = 2000, + TimeoutBetweenThreadStart = 2 }; - MultiThreadRunner mtr = new MultiThreadRunner(100, actions); - mtr.EndTimeout = 2000; - mtr.TimeoutBetweenThreadStart = 2; - mtr.Run(null); - log.DebugFormat("{0} calls", totalCall); + var totalCalls = mtr.Run(null); + log.DebugFormat("{0} calls", totalCalls); + var errors = mtr.GetErrors(); + if (errors.Length > 0) + { + Assert.Fail("One or more thread failed, found {0} errors. First exception: {1}", errors.Length, errors[0]); + } } [Test] diff --git a/src/NHibernate.Test/UtilityTest/ThreadSafeDictionaryFixture.cs b/src/NHibernate.Test/UtilityTest/ThreadSafeDictionaryFixture.cs deleted file mode 100644 index 877b2d37a04..00000000000 --- a/src/NHibernate.Test/UtilityTest/ThreadSafeDictionaryFixture.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Threading; -using log4net; -using NUnit.Framework; - -namespace NHibernate.Test.UtilityTest -{ - [TestFixture] - public class ThreadSafeDictionaryFixture - { - public ThreadSafeDictionaryFixture() - { - log4net.Config.XmlConfigurator.Configure(LogManager.GetRepository(typeof(ThreadSafeDictionaryFixture).Assembly)); - } - - private static readonly ILog log = LogManager.GetLogger(typeof(ThreadSafeDictionaryFixture)); - - private readonly Random rnd = new Random(); - private int read, write; - - [Test, Explicit] - public void MultiThreadAccess() - { - MultiThreadRunner>.ExecuteAction[] actions = - new MultiThreadRunner>.ExecuteAction[] - { - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} Add", Thread.CurrentThread.Name); - write++; - d.TryAdd(rnd.Next(), rnd.Next()); - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} ContainsKey", Thread.CurrentThread.Name); - read++; - d.ContainsKey(rnd.Next()); - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} Remove", Thread.CurrentThread.Name); - write++; - int value; - d.TryRemove(rnd.Next(), out value); - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} TryGetValue", Thread.CurrentThread.Name); - read++; - int val; - d.TryGetValue(rnd.Next(), out val); - }, - delegate(ConcurrentDictionary d) - { - try - { - log.DebugFormat("T{0} get_this[]", Thread.CurrentThread.Name); - read++; - int val = d[rnd.Next()]; - } - catch (KeyNotFoundException) - { - // not foud key - } - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} set_this[]", Thread.CurrentThread.Name); - write++; - d[rnd.Next()] = rnd.Next(); - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} Keys", Thread.CurrentThread.Name); - read++; - IEnumerable e = d.Keys; - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} Values", Thread.CurrentThread.Name); - read++; - IEnumerable e = d.Values; - }, - delegate(ConcurrentDictionary d) - { - log.DebugFormat("T{0} GetEnumerator", Thread.CurrentThread.Name); - read++; - foreach (KeyValuePair pair in d) - { - - } - }, - }; - MultiThreadRunner> mtr = new MultiThreadRunner>(20, actions); - ConcurrentDictionary wrapper = new ConcurrentDictionary(); - mtr.EndTimeout = 2000; - mtr.Run(wrapper); - log.DebugFormat("{0} reads, {1} writes -- elements {2}", read, write, wrapper.Count); - } - } -}