From 7cdeeb534433d3549da8bc55bc269240c5247f3a Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 2 Mar 2019 17:07:44 +0100 Subject: [PATCH 01/29] Reduce cast usage for SUM aggregate function --- .../NHSpecificTest/GH2029/Fixture.cs | 175 ++++++++++++++++++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 16 +- 2 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs new file mode 100644 index 00000000000..b24f3186d2f --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs @@ -0,0 +1,175 @@ +using System; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2029 +{ + public class TestClass + { + public virtual int Id { get; set; } + public virtual int? NullableInt32Prop { get; set; } + public virtual int Int32Prop { get; set; } + public virtual long? NullableInt64Prop { get; set; } + public virtual long Int64Prop { get; set; } + } + + [TestFixture] + public class Fixture : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.Native)); + rc.Property(x => x.NullableInt32Prop); + rc.Property(x => x.Int32Prop); + rc.Property(x => x.NullableInt64Prop); + rc.Property(x => x.Int64Prop); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = null, + Int64Prop = int.MaxValue, + NullableInt64Prop = null + }); + + tx.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.CreateQuery("delete from TestClass").ExecuteUpdate(); + + tx.Commit(); + } + } + + [Test] + public void NullableIntOverflow() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.NullableInt32Prop) + }) + .ToArray(); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public void IntOverflow() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.Int32Prop) + }) + .ToArray(); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + [Test] + public void NullableInt64NoCast() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.NullableInt64Prop) + }) + .ToArray(); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public void Int64NoCast() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.Int64Prop) + }) + .ToArray(); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + private int FindAllOccurrences(string source, string substring) + { + if (source == null) + { + return 0; + } + int n = 0, count = 0; + while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) + { + n += substring.Length; + ++count; + } + return count; + } + } +} diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 3d2e9be5f47..ebdb0a19c1c 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -5,9 +5,11 @@ using System.Runtime.CompilerServices; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; +using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; using NHibernate.Param; +using NHibernate.Type; using NHibernate.Util; using Remotion.Linq.Clauses.Expressions; @@ -261,6 +263,14 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression) protected HqlTreeNode VisitNhSum(NhSumExpression expression) { + var type = expression.Type.UnwrapIfNullable(); + var nhType = TypeFactory.GetDefaultTypeFor(type); + if (nhType != null && _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("sum") + ?.ReturnType(nhType, _parameters.SessionFactory)?.ReturnedClass == type) + { + return _hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()); + } + return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); } @@ -475,8 +485,10 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - if ((expression.Operand.Type.IsPrimitive || expression.Operand.Type == typeof(Decimal)) && - (expression.Type.IsPrimitive || expression.Type == typeof(Decimal))) + var operandType = expression.Operand.Type.UnwrapIfNullable(); + if ((operandType.IsPrimitive || operandType == typeof(decimal)) && + (expression.Type.IsPrimitive || expression.Type == typeof(decimal)) && + expression.Type != operandType) { return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type); } From a3699a13a833dee5513fe7cb433a39694200e271 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 2 Mar 2019 18:53:10 +0100 Subject: [PATCH 02/29] Fix SQLite tests --- src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs | 6 ++++++ .../Linq/Visitors/HqlGeneratorExpressionVisitor.cs | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs index b24f3186d2f..70aed66bcab 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using NHibernate.Cfg.MappingSchema; +using NHibernate.Dialect; using NHibernate.Mapping.ByCode; using NUnit.Framework; @@ -33,6 +34,11 @@ protected override HbmMapping GetMappings() return mapper.CompileMappingForAllExplicitlyAddedEntities(); } + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return !(dialect is SQLiteDialect); + } + protected override void OnSetUp() { using (var session = OpenSession()) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index ebdb0a19c1c..8f9ea0a34a8 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -268,7 +268,9 @@ protected HqlTreeNode VisitNhSum(NhSumExpression expression) if (nhType != null && _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("sum") ?.ReturnType(nhType, _parameters.SessionFactory)?.ReturnedClass == type) { - return _hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()); + return _hqlTreeBuilder.TransparentCast( + _hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), + expression.Type); } return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); From 5050805487ef4c028066b4bd16f92751b2549a85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Sun, 10 Mar 2019 19:41:24 +0100 Subject: [PATCH 03/29] Adjust tests and regen async To be squashed --- .../Async/NHSpecificTest/GH2029/Fixture.cs | 185 ++++++++++++++++++ .../NHSpecificTest/GH2029/Fixture.cs | 18 +- 2 files changed, 194 insertions(+), 9 deletions(-) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs new file mode 100644 index 00000000000..c156cc700eb --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs @@ -0,0 +1,185 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Dialect; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH2029 +{ + using System.Threading.Tasks; + + [TestFixture] + public class FixtureAsync : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.Native)); + rc.Property(x => x.NullableInt32Prop); + rc.Property(x => x.Int32Prop); + rc.Property(x => x.NullableInt64Prop); + rc.Property(x => x.Int64Prop); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return !(dialect is SQLiteDialect); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = null, + Int64Prop = int.MaxValue, + NullableInt64Prop = null + }); + + tx.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.CreateQuery("delete from TestClass").ExecuteUpdate(); + + tx.Commit(); + } + } + + [Test] + public async Task NullableIntOverflowAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.NullableInt32Prop) + }) + .ToListAsync()); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public async Task IntOverflowAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.Int32Prop) + }) + .ToListAsync()); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + [Test] + public async Task NullableInt64NoCastAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.NullableInt64Prop) + }) + .ToListAsync()); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public async Task Int64NoCastAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.Int64Prop) + }) + .ToListAsync()); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + private int FindAllOccurrences(string source, string substring) + { + if (source == null) + { + return 0; + } + int n = 0, count = 0; + while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) + { + n += substring.Length; + count++; + } + return count; + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs index 70aed66bcab..7565e1889f9 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs @@ -94,10 +94,10 @@ public void NullableIntOverflow() { s = g.Sum(i => (long) i.NullableInt32Prop) }) - .ToArray(); + .ToList(); Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); - Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); } } @@ -115,10 +115,10 @@ public void IntOverflow() { s = g.Sum(i => (long) i.Int32Prop) }) - .ToArray(); + .ToList(); Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); - Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); } } @@ -135,10 +135,10 @@ public void NullableInt64NoCast() .Select(g => new { s = g.Sum(i => i.NullableInt64Prop) }) - .ToArray(); + .ToList(); Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); - Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); } } @@ -155,10 +155,10 @@ public void Int64NoCast() .Select(g => new { s = g.Sum(i => i.Int64Prop) }) - .ToArray(); + .ToList(); Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); - Assert.That(groups, Has.Length.EqualTo(1)); + Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); } } @@ -173,7 +173,7 @@ private int FindAllOccurrences(string source, string substring) while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) { n += substring.Length; - ++count; + count++; } return count; } From 8fb195612a9ad0bb1bde4c9754c6afcb6fe943aa Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 11 Mar 2019 20:25:08 +0100 Subject: [PATCH 04/29] Extend the logic to be used for other aggregate functions --- .../Visitors/HqlGeneratorExpressionVisitor.cs | 127 ++++++++++++++---- 1 file changed, 104 insertions(+), 23 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 8f9ea0a34a8..99c9229aa16 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Data; using System.Dynamic; using System.Linq; using System.Linq.Expressions; @@ -240,10 +241,13 @@ constant.Value is CallSite site && protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) { var hqlExpression = VisitExpression(expression.Expression).AsExpression(); - if (expression.Type != expression.Expression.Type) - hqlExpression = _hqlTreeBuilder.Cast(hqlExpression, expression.Type); + hqlExpression = IsCastRequired(expression.Expression, expression.Type) + ? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type) + : _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type); - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type); + return IsCastRequired(expression.Type, "avg") + ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type) + : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Average(hqlExpression), expression.Type); } protected HqlTreeNode VisitNhCount(NhCountExpression expression) @@ -263,17 +267,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression) protected HqlTreeNode VisitNhSum(NhSumExpression expression) { - var type = expression.Type.UnwrapIfNullable(); - var nhType = TypeFactory.GetDefaultTypeFor(type); - if (nhType != null && _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("sum") - ?.ReturnType(nhType, _parameters.SessionFactory)?.ReturnedClass == type) - { - return _hqlTreeBuilder.TransparentCast( - _hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), - expression.Type); - } - - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); + return IsCastRequired(expression.Type, "sum") + ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type) + : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); } protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression) @@ -487,15 +483,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - var operandType = expression.Operand.Type.UnwrapIfNullable(); - if ((operandType.IsPrimitive || operandType == typeof(decimal)) && - (expression.Type.IsPrimitive || expression.Type == typeof(decimal)) && - expression.Type != operandType) - { - return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type); - } - - return VisitExpression(expression.Operand); + return IsCastRequired(expression.Operand, expression.Type) + ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + : VisitExpression(expression.Operand); } throw new NotSupportedException(expression.ToString()); @@ -596,5 +586,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression) var expressionSubTree = expression.Expressions.ToArray(exp => VisitExpression(exp)); return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree); } + + private bool IsCastRequired(Expression expression, System.Type toType) + { + return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType)); + } + + private bool IsCastRequired(IType type, IType toType) + { + // A type can be null when casting an entity into a base class, in that case we should not cast + if (type == null || toType == null || Equals(type, toType)) + { + return false; + } + + var sqlTypes = type.SqlTypes(_parameters.SessionFactory); + var toSqlTypes = toType.SqlTypes(_parameters.SessionFactory); + if (sqlTypes.Length != 1 || toSqlTypes.Length != 1) + { + return false; // Casting a multi-column type is not possible + } + + if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String) + { + return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value + } + + return sqlTypes[0].DbType != toSqlTypes[0].DbType; + } + + private bool IsCastRequired(System.Type type, string sqlFunctionName) + { + if (type == typeof(object)) + { + return false; + } + + var toType = TypeFactory.GetDefaultTypeFor(type); + if (toType == null) + { + return true; // Fallback to the old behavior + } + + var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName); + if (sqlFunction == null) + { + return true; // Fallback to the old behavior + } + + var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory); + return fnReturnType == null || IsCastRequired(fnReturnType, toType); + } + + private IType GetType(Expression expression) + { + if (!(expression is MemberExpression memberExpression)) + { + return expression.Type != typeof(object) + ? TypeFactory.GetDefaultTypeFor(expression.Type) + : null; + } + + // Try to get the mapped type for the member as it may be a non default one + var entityName = TryGetEntityName(memberExpression); + if (entityName == null) + { + return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped + } + + var persister = _parameters.SessionFactory.GetEntityPersister(entityName); + var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name); + return !index.HasValue + ? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped + : persister.EntityMetamodel.PropertyTypes[index.Value]; + } + + private string TryGetEntityName(MemberExpression memberExpression) + { + System.Type entityType; + // Try to get the actual entity type from the query source if possbile as member can be declared + // in a base type + if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) + { + entityType = querySourceReferenceExpression.Type; + } + else + { + entityType = memberExpression.Member.ReflectedType; + } + + return _parameters.SessionFactory.TryGetGuessEntityName(entityType); + } } } From 36772cf985c7924e8e027e2052e1fff1b93fa697 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 12 Mar 2019 19:52:30 +0100 Subject: [PATCH 05/29] Check the dialect cast type instead --- .../Linq/Visitors/HqlGeneratorExpressionVisitor.cs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 99c9229aa16..979d32d7fb4 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -607,12 +607,20 @@ private bool IsCastRequired(IType type, IType toType) return false; // Casting a multi-column type is not possible } + if (sqlTypes[0].DbType == toSqlTypes[0].DbType) + { + return false; + } + if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String) { return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value } - return sqlTypes[0].DbType != toSqlTypes[0].DbType; + // Some dialects can map several sql types into one, cast only if the dialect types are different + var castTypeName = _parameters.SessionFactory.Dialect.GetCastTypeName(sqlTypes[0]); + var toCastTypeName = _parameters.SessionFactory.Dialect.GetCastTypeName(toSqlTypes[0]); + return castTypeName != toCastTypeName; } private bool IsCastRequired(System.Type type, string sqlFunctionName) From 046483e0c73e06f61ab2e36979529378e0c0220a Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 14 Mar 2019 20:50:37 +0100 Subject: [PATCH 06/29] Fix oracle and mysql issues --- .../NHSpecificTest/GH2029/Fixture.cs | 10 ++++- src/NHibernate/Dialect/Dialect.cs | 38 ++++++++++++++-- .../Dialect/Function/CastFunction.cs | 5 +-- src/NHibernate/Dialect/TypeNames.cs | 43 +++++++++++++++++-- .../Visitors/HqlGeneratorExpressionVisitor.cs | 43 +++++++++++++------ 5 files changed, 112 insertions(+), 27 deletions(-) diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs index 7565e1889f9..c0ec490cf5e 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs @@ -84,6 +84,9 @@ protected override void OnTearDown() [Test] public void NullableIntOverflow() { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + using (var session = OpenSession()) using (session.BeginTransaction()) using (var sqlLog = new SqlLogSpy()) @@ -96,7 +99,7 @@ public void NullableIntOverflow() }) .ToList(); - Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); } @@ -105,6 +108,9 @@ public void NullableIntOverflow() [Test] public void IntOverflow() { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + using (var session = OpenSession()) using (session.BeginTransaction()) using (var sqlLog = new SqlLogSpy()) @@ -117,7 +123,7 @@ public void IntOverflow() }) .ToList(); - Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); } diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 7447f9494e7..e3d0b6cacd9 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -266,6 +266,18 @@ public virtual string GetLongestTypeName(DbType dbType) public virtual string GetCastTypeName(SqlType sqlType) => GetCastTypeName(sqlType, _typeNames); + /// + /// Get the name of the database type appropriate for casting operations + /// (via the CAST() SQL function) for the given typecode. + /// + /// The typecode. + /// The database type name that will be set in case it was found. + /// Whether the type name was found. + public virtual bool TryGetCastTypeName(SqlType sqlType, out string typeName) + { + return TryGetCastTypeName(sqlType, _typeNames, out typeName); + } + /// /// Get the name of the database type appropriate for casting operations /// (via the CAST() SQL function) for the given typecode. @@ -274,9 +286,27 @@ public virtual string GetCastTypeName(SqlType sqlType) => /// The source for type names. /// The database type name. protected virtual string GetCastTypeName(SqlType sqlType, TypeNames castTypeNames) + { + if (!TryGetCastTypeName(sqlType, castTypeNames, out var result)) + { + throw new ArgumentException("Dialect does not support DbType." + sqlType.DbType, nameof(sqlType)); + } + + return result; + } + + /// + /// Get the name of the database type appropriate for casting operations + /// (via the CAST() SQL function) for the given typecode. + /// + /// The typecode. + /// The source for type names. + /// The database type name that will be set in case it was found. + /// Whether the type name was found. + protected virtual bool TryGetCastTypeName(SqlType sqlType, TypeNames castTypeNames, out string typeName) { if (sqlType.LengthDefined || sqlType.PrecisionDefined || sqlType.ScaleDefined) - return castTypeNames.Get(sqlType.DbType, sqlType.Length, sqlType.Precision, sqlType.Scale); + return castTypeNames.TryGet(sqlType.DbType, sqlType.Length, sqlType.Precision, sqlType.Scale, out typeName); switch (sqlType.DbType) { case DbType.Decimal: @@ -284,18 +314,18 @@ protected virtual string GetCastTypeName(SqlType sqlType, TypeNames castTypeName case DbType.Double: // We cannot know if the user needs its digit after or before the dot, so use a configurable // default. - return castTypeNames.Get(sqlType.DbType, 0, DefaultCastPrecision, DefaultCastScale); + return castTypeNames.TryGet(sqlType.DbType, 0, DefaultCastPrecision, DefaultCastScale, out typeName); case DbType.DateTime: case DbType.DateTime2: case DbType.DateTimeOffset: case DbType.Time: case DbType.Currency: // Use default for these, dialects are supposed to map them to max capacity - return castTypeNames.Get(sqlType.DbType); + return castTypeNames.TryGet(sqlType.DbType, out typeName); default: // Other types are either length bound or not length/precision/scale bound. Otherwise they need to be // handled previously. - return castTypeNames.Get(sqlType.DbType, DefaultCastLength, 0, 0); + return castTypeNames.TryGet(sqlType.DbType, DefaultCastLength, 0, 0, out typeName); } } diff --git a/src/NHibernate/Dialect/Function/CastFunction.cs b/src/NHibernate/Dialect/Function/CastFunction.cs index 747dd90440a..e7a41a91882 100644 --- a/src/NHibernate/Dialect/Function/CastFunction.cs +++ b/src/NHibernate/Dialect/Function/CastFunction.cs @@ -50,10 +50,9 @@ public SqlString Render(IList args, ISessionFactoryImplementor factory) { throw new QueryException("invalid NHibernate type for cast(), was:" + typeName); } - sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]); - if (sqlType == null) + + if (!factory.Dialect.TryGetCastTypeName(sqlTypeCodes[0], out sqlType)) { - //TODO: never reached, since GetTypeName() actually throws an exception! sqlType = typeName; } //else diff --git a/src/NHibernate/Dialect/TypeNames.cs b/src/NHibernate/Dialect/TypeNames.cs index 7c25a211461..b2fbd4529fc 100644 --- a/src/NHibernate/Dialect/TypeNames.cs +++ b/src/NHibernate/Dialect/TypeNames.cs @@ -57,13 +57,24 @@ public class TypeNames /// the default type name associated with the specified key public string Get(DbType typecode) { - if (!defaults.TryGetValue(typecode, out var result)) + if (!TryGet(typecode, out var result)) { throw new ArgumentException("Dialect does not support DbType." + typecode, nameof(typecode)); } return result; } + /// + /// Get default type name for specified type. + /// + /// The type key. + /// The default type name that will be set in case it was found. + /// Whether the default type name was found. + public bool TryGet(DbType typecode, out string typeName) + { + return defaults.TryGetValue(typecode, out typeName); + } + /// /// Get the type name specified type and size /// @@ -76,6 +87,28 @@ public string Get(DbType typecode) /// if available, otherwise the default type name. /// public string Get(DbType typecode, int size, int precision, int scale) + { + if (!TryGet(typecode, size, precision, scale, out var result)) + { + throw new ArgumentException("Dialect does not support DbType." + typecode, nameof(typecode)); + } + + return result; + } + + /// + /// Get the type name specified type and size. + /// + /// The type key. + /// The SQL length. + /// The SQL scale. + /// The SQL precision. + /// + /// The associated name with smallest capacity >= size (or precision for decimal, or scale for date time types) + /// if available, otherwise the default type name. + /// + /// Whether the type name was found. + public bool TryGet(DbType typecode, int size, int precision, int scale, out string typeName) { weighted.TryGetValue(typecode, out var map); if (map != null && map.Count > 0) @@ -88,7 +121,8 @@ public string Get(DbType typecode, int size, int precision, int scale) { if (requiredCapacity <= entry.Key) { - return Replace(entry.Value, size, precision, scale); + typeName = Replace(entry.Value, size, precision, scale); + return true; } } if (isPrecisionType && precision != 0) @@ -102,11 +136,12 @@ public string Get(DbType typecode, int size, int precision, int scale) // But if the type is used for storing amounts, this may cause losing the ability to store cents... // So better just reduce as few as possible. var adjustedScale = Math.Min(scale, adjustedPrecision); - return Replace(maxEntry.Value, size, adjustedPrecision, adjustedScale); + typeName = Replace(maxEntry.Value, size, adjustedPrecision, adjustedScale); + return true; } } //Could not find a specific type for the capacity, using the default - return Get(typecode); + return TryGet(typecode, out typeName); } /// diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 979d32d7fb4..62cc4da556a 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -241,13 +241,13 @@ constant.Value is CallSite site && protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) { var hqlExpression = VisitExpression(expression.Expression).AsExpression(); - hqlExpression = IsCastRequired(expression.Expression, expression.Type) + hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _) ? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type) : _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type); - return IsCastRequired(expression.Type, "avg") - ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type) - : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Average(hqlExpression), expression.Type); + // In Oracle the avg function can return a number with up to 40 digits which cannot be retrieved from the data reader due to the lack of such + // numeric type in .NET. In order to avoid that we have to add a cast to trim the number so that it can be converted into a .NET numeric type. + return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type); } protected HqlTreeNode VisitNhCount(NhCountExpression expression) @@ -267,7 +267,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression) protected HqlTreeNode VisitNhSum(NhSumExpression expression) { - return IsCastRequired(expression.Type, "sum") + return IsCastRequired(expression.Type, "sum", out _) ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type) : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); } @@ -483,9 +483,12 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - return IsCastRequired(expression.Operand, expression.Type) + return IsCastRequired(expression.Operand, expression.Type, out var existType) ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) - : VisitExpression(expression.Operand); + // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader + : existType + ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + : VisitExpression(expression.Operand); } throw new NotSupportedException(expression.ToString()); @@ -587,16 +590,18 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression) return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree); } - private bool IsCastRequired(Expression expression, System.Type toType) + private bool IsCastRequired(Expression expression, System.Type toType, out bool existType) { - return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType)); + existType = false; + return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType); } - private bool IsCastRequired(IType type, IType toType) + private bool IsCastRequired(IType type, IType toType, out bool existType) { // A type can be null when casting an entity into a base class, in that case we should not cast if (type == null || toType == null || Equals(type, toType)) { + existType = false; return false; } @@ -604,9 +609,11 @@ private bool IsCastRequired(IType type, IType toType) var toSqlTypes = toType.SqlTypes(_parameters.SessionFactory); if (sqlTypes.Length != 1 || toSqlTypes.Length != 1) { + existType = false; return false; // Casting a multi-column type is not possible } + existType = true; if (sqlTypes[0].DbType == toSqlTypes[0].DbType) { return false; @@ -614,28 +621,36 @@ private bool IsCastRequired(IType type, IType toType) if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String) { + existType = false; return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value } // Some dialects can map several sql types into one, cast only if the dialect types are different - var castTypeName = _parameters.SessionFactory.Dialect.GetCastTypeName(sqlTypes[0]); - var toCastTypeName = _parameters.SessionFactory.Dialect.GetCastTypeName(toSqlTypes[0]); + if (!_parameters.SessionFactory.Dialect.TryGetCastTypeName(sqlTypes[0], out var castTypeName) || + !_parameters.SessionFactory.Dialect.TryGetCastTypeName(toSqlTypes[0], out var toCastTypeName)) + { + return false; // The dialect does not support such cast + } + return castTypeName != toCastTypeName; } - private bool IsCastRequired(System.Type type, string sqlFunctionName) + private bool IsCastRequired(System.Type type, string sqlFunctionName, out bool existType) { if (type == typeof(object)) { + existType = false; return false; } var toType = TypeFactory.GetDefaultTypeFor(type); if (toType == null) { + existType = false; return true; // Fallback to the old behavior } + existType = true; var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName); if (sqlFunction == null) { @@ -643,7 +658,7 @@ private bool IsCastRequired(System.Type type, string sqlFunctionName) } var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory); - return fnReturnType == null || IsCastRequired(fnReturnType, toType); + return fnReturnType == null || IsCastRequired(fnReturnType, toType, out existType); } private IType GetType(Expression expression) From 2a18f3786eba1f9f2741ba66c44a14fe5f07bfe6 Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 15 Mar 2019 19:16:22 +0100 Subject: [PATCH 07/29] Add missed override for mysql --- src/NHibernate/Dialect/MySQLDialect.cs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index 475c1d1e5c2..9a83de26eb3 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -497,6 +497,10 @@ protected void RegisterCastType(DbType code, int capacity, string name) public override string GetCastTypeName(SqlType sqlType) => GetCastTypeName(sqlType, castTypeNames); + /// + public override bool TryGetCastTypeName(SqlType sqlType, out string typeName) => + TryGetCastTypeName(sqlType, castTypeNames, out typeName); + public override long TimestampResolutionInTicks { get From e966220589f89df939f5dd4222bb8552b382bcb8 Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 15 Mar 2019 20:48:55 +0100 Subject: [PATCH 08/29] async regen --- .../Async/NHSpecificTest/GH2029/Fixture.cs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs index c156cc700eb..b05df684d2e 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs @@ -88,6 +88,9 @@ protected override void OnTearDown() [Test] public async Task NullableIntOverflowAsync() { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + using (var session = OpenSession()) using (session.BeginTransaction()) using (var sqlLog = new SqlLogSpy()) @@ -100,7 +103,7 @@ public async Task NullableIntOverflowAsync() }) .ToListAsync()); - Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); } @@ -109,6 +112,9 @@ public async Task NullableIntOverflowAsync() [Test] public async Task IntOverflowAsync() { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + using (var session = OpenSession()) using (session.BeginTransaction()) using (var sqlLog = new SqlLogSpy()) @@ -121,7 +127,7 @@ public async Task IntOverflowAsync() }) .ToListAsync()); - Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1)); + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); Assert.That(groups, Has.Count.EqualTo(1)); Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); } From 21e7dd4c3a6eaeb0e5452a2e49490734cf628af1 Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 10 Apr 2019 21:37:43 +0200 Subject: [PATCH 09/29] Fix casting for custom registered types --- .../NHSpecific/NullableInt32.cs | 91 ++++++++++++++++++- .../Async/Linq/SelectionTests.cs | 19 ++++ src/NHibernate.Test/Linq/SelectionTests.cs | 19 ++++ src/NHibernate/Hql/Ast/HqlTreeNode.cs | 24 +++++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 2 +- 5 files changed, 153 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs b/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs index 89d599183bf..95abf028f98 100644 --- a/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs +++ b/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs @@ -7,7 +7,7 @@ namespace NHibernate.DomainModel.NHSpecific /// A nullable type that wraps an value. /// [TypeConverter(typeof(NullableInt32Converter)), Serializable()] - public struct NullableInt32 : IFormattable, IComparable + public struct NullableInt32 : IFormattable, IComparable, IConvertible { public static readonly NullableInt32 Default = new NullableInt32(); @@ -234,5 +234,94 @@ public static NullableInt32 Parse(string s) // TODO: implement the rest of the Parse overloads found in Int32 #endregion + + #region IConvertible + + public TypeCode GetTypeCode() + { + return _value.GetTypeCode(); + } + + public bool ToBoolean(IFormatProvider provider) + { + return ((IConvertible) _value).ToBoolean(provider); + } + + public char ToChar(IFormatProvider provider) + { + return ((IConvertible) _value).ToChar(provider); + } + + public sbyte ToSByte(IFormatProvider provider) + { + return ((IConvertible) _value).ToSByte(provider); + } + + public byte ToByte(IFormatProvider provider) + { + return ((IConvertible) _value).ToByte(provider); + } + + public short ToInt16(IFormatProvider provider) + { + return ((IConvertible) _value).ToInt16(provider); + } + + public ushort ToUInt16(IFormatProvider provider) + { + return ((IConvertible) _value).ToUInt16(provider); + } + + public int ToInt32(IFormatProvider provider) + { + return ((IConvertible) _value).ToInt32(provider); + } + + public uint ToUInt32(IFormatProvider provider) + { + return ((IConvertible) _value).ToUInt32(provider); + } + + public long ToInt64(IFormatProvider provider) + { + return ((IConvertible) _value).ToInt64(provider); + } + + public ulong ToUInt64(IFormatProvider provider) + { + return ((IConvertible) _value).ToUInt64(provider); + } + + public float ToSingle(IFormatProvider provider) + { + return ((IConvertible) _value).ToSingle(provider); + } + + public double ToDouble(IFormatProvider provider) + { + return ((IConvertible) _value).ToDouble(provider); + } + + public decimal ToDecimal(IFormatProvider provider) + { + return ((IConvertible) _value).ToDecimal(provider); + } + + public DateTime ToDateTime(IFormatProvider provider) + { + return ((IConvertible) _value).ToDateTime(provider); + } + + public string ToString(IFormatProvider provider) + { + return _value.ToString(provider); + } + + public object ToType(System.Type conversionType, IFormatProvider provider) + { + return ((IConvertible) _value).ToType(conversionType, provider); + } + + #endregion } } diff --git a/src/NHibernate.Test/Async/Linq/SelectionTests.cs b/src/NHibernate.Test/Async/Linq/SelectionTests.cs index b4ac7e372e4..6984cafd4d8 100644 --- a/src/NHibernate.Test/Async/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Async/Linq/SelectionTests.cs @@ -11,7 +11,9 @@ using System; using System.Collections.Generic; using System.Linq; +using NHibernate.DomainModel.NHSpecific; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Type; using NUnit.Framework; using NHibernate.Linq; @@ -453,6 +455,23 @@ public async Task CanSelectConditionalObjectAsync() Assert.That(fatherIsKnown, Has.Exactly(1).With.Property("FatherIsKnown").True); } + [Test] + public async Task CanCastToDerivedTypeAsync() + { + var dogs = await (db.Animals + .Where(a => ((Dog) a).Pregnant) + .Select(a => new {a.SerialNumber}) + .ToListAsync()); + Assert.That(dogs, Has.Exactly(1).With.Property("SerialNumber").Not.Null); + } + + [Test] + public async Task CanCastToCustomRegisteredTypeAsync() + { + TypeFactory.RegisterType(typeof(NullableInt32), new NullableInt32Type(), Enumerable.Empty()); + Assert.That(await (db.Users.Where(o => (NullableInt32) o.Id == 1).ToListAsync()), Has.Count.EqualTo(1)); + } + public class Wrapper { public T item; diff --git a/src/NHibernate.Test/Linq/SelectionTests.cs b/src/NHibernate.Test/Linq/SelectionTests.cs index 3873558badf..5d2020f7bbe 100644 --- a/src/NHibernate.Test/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Linq/SelectionTests.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; +using NHibernate.DomainModel.NHSpecific; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Type; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -492,6 +494,23 @@ public void CanSelectConditionalObject() Assert.That(fatherIsKnown, Has.Exactly(1).With.Property("FatherIsKnown").True); } + [Test] + public void CanCastToDerivedType() + { + var dogs = db.Animals + .Where(a => ((Dog) a).Pregnant) + .Select(a => new {a.SerialNumber}) + .ToList(); + Assert.That(dogs, Has.Exactly(1).With.Property("SerialNumber").Not.Null); + } + + [Test] + public void CanCastToCustomRegisteredType() + { + TypeFactory.RegisterType(typeof(NullableInt32), new NullableInt32Type(), Enumerable.Empty()); + Assert.That(db.Users.Where(o => (NullableInt32) o.Id == 1).ToList(), Has.Count.EqualTo(1)); + } + public class Wrapper { public T item; diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 7289d5acbc5..5864a3130a7 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -257,6 +257,30 @@ internal HqlIdent(IASTFactory factory, System.Type type) throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name)); } } + + internal static bool SupportsType(System.Type type) + { + type = type.UnwrapIfNullable(); + switch (System.Type.GetTypeCode(type)) + { + case TypeCode.Boolean: + case TypeCode.Int16: + case TypeCode.Int32: + case TypeCode.Int64: + case TypeCode.Decimal: + case TypeCode.Single: + case TypeCode.DateTime: + case TypeCode.String: + case TypeCode.Double: + return true; + default: + return new[] + { + typeof(Guid), + typeof(DateTimeOffset) + }.Contains(type); + } + } } public class HqlRange : HqlStatement diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 62cc4da556a..8e82abc418c 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -486,7 +486,7 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) return IsCastRequired(expression.Operand, expression.Type, out var existType) ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader - : existType + : existType && HqlIdent.SupportsType(expression.Type) ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type) : VisitExpression(expression.Operand); } From 8a04fa49c77debd9bb4940088bff2cfd444cb9bd Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 1 Oct 2019 19:20:11 +0200 Subject: [PATCH 10/29] Update TryGetEntityName method --- .../Async/Linq/SelectionTests.cs | 4 ++ src/NHibernate.Test/Linq/SelectionTests.cs | 4 ++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 27 +++----- .../Tuple/Entity/EntityMetamodel.cs | 46 +++++++++++-- src/NHibernate/Util/ExpressionsHelper.cs | 65 ++++++++++++++++++- 5 files changed, 121 insertions(+), 25 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/SelectionTests.cs b/src/NHibernate.Test/Async/Linq/SelectionTests.cs index 6984cafd4d8..cf065e6bf5d 100644 --- a/src/NHibernate.Test/Async/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Async/Linq/SelectionTests.cs @@ -309,6 +309,10 @@ public async Task CanProjectWithCastAsync() var names5 = await (db.Users.Select(p => new { p1 = (p as IUser).Name }).ToListAsync()); Assert.AreEqual(3, names5.Count); + + var names6 = await (db.Users.Select(p => new { p1 = (long) p.Id }).ToListAsync()); + Assert.AreEqual(3, names6.Count); + // ReSharper restore RedundantCast } diff --git a/src/NHibernate.Test/Linq/SelectionTests.cs b/src/NHibernate.Test/Linq/SelectionTests.cs index 5d2020f7bbe..7aac7edc2da 100644 --- a/src/NHibernate.Test/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Linq/SelectionTests.cs @@ -348,6 +348,10 @@ public void CanProjectWithCast() var names5 = db.Users.Select(p => new { p1 = (p as IUser).Name }).ToList(); Assert.AreEqual(3, names5.Count); + + var names6 = db.Users.Select(p => new { p1 = (long) p.Id }).ToList(); + Assert.AreEqual(3, names6.Count); + // ReSharper restore RedundantCast } diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 8e82abc418c..9658c07f75c 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -671,34 +671,23 @@ private IType GetType(Expression expression) } // Try to get the mapped type for the member as it may be a non default one - var entityName = TryGetEntityName(memberExpression); + var entityName = ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, memberExpression, out var memberPath); if (entityName == null) { return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped } var persister = _parameters.SessionFactory.GetEntityPersister(entityName); - var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name); - return !index.HasValue - ? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped - : persister.EntityMetamodel.PropertyTypes[index.Value]; - } - - private string TryGetEntityName(MemberExpression memberExpression) - { - System.Type entityType; - // Try to get the actual entity type from the query source if possbile as member can be declared - // in a base type - if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) + var type = persister.EntityMetamodel.GetIdentifierPropertyType(memberPath); + if (type != null) { - entityType = querySourceReferenceExpression.Type; - } - else - { - entityType = memberExpression.Member.ReflectedType; + return type; } - return _parameters.SessionFactory.TryGetGuessEntityName(entityType); + var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath); + return !index.HasValue + ? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped + : persister.EntityMetamodel.PropertyTypes[index.Value]; } } } diff --git a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs index b12c49975c4..65c24a0052c 100644 --- a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs +++ b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs @@ -50,6 +50,7 @@ public class EntityMetamodel private readonly CascadeStyle[] cascadeStyles; private readonly Dictionary propertyIndexes = new Dictionary(); + private readonly IDictionary _identifierPropertyTypes = new Dictionary(); private readonly bool hasCollections; private readonly bool hasMutableProperties; private readonly bool hasLazyProperties; @@ -91,6 +92,7 @@ public EntityMetamodel(PersistentClass persistentClass, ISessionFactoryImplement identifierProperty = PropertyFactory.BuildIdentifierProperty(persistentClass, sessionFactory.GetIdentifierGenerator(rootName)); + MapIdentifierPropertyTypes(identifierProperty); versioned = persistentClass.IsVersioned; @@ -409,13 +411,42 @@ private bool HasPartialUpdateComponentGeneration(Mapping.Component component) private void MapPropertyToIndex(Mapping.Property prop, int i) { - propertyIndexes[prop.Name] = i; - Mapping.Component comp = prop.Value as Mapping.Component; - if (comp != null) + MapPropertyToIndex(null, prop, i); + } + + private void MapPropertyToIndex(string path, Mapping.Property prop, int i) + { + propertyIndexes[!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name] = i; + if (!(prop.Value is Mapping.Component comp)) + { + return; + } + + foreach (var subprop in comp.PropertyIterator) + { + MapPropertyToIndex(!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name, subprop, i); + } + } + + private void MapIdentifierPropertyTypes(IdentifierProperty identifier) + { + MapIdentifierPropertyTypes(identifier.Name, identifier.Type); + } + + private void MapIdentifierPropertyTypes(string path, IType propertyType) + { + if (!string.IsNullOrEmpty(path)) + { + _identifierPropertyTypes[path] = propertyType; + } + + if (propertyType is IAbstractComponentType componentType) { - foreach (Mapping.Property subprop in comp.PropertyIterator) + for (var i = 0; i < componentType.PropertyNames.Length; i++) { - propertyIndexes[prop.Name + '.' + subprop.Name] = i; + MapIdentifierPropertyTypes( + !string.IsNullOrEmpty(path) ? $"{path}.{componentType.PropertyNames[i]}" : componentType.PropertyNames[i], + componentType.Subtypes[i]); } } } @@ -534,6 +565,11 @@ public int GetPropertyIndex(string propertyName) return null; } + internal IType GetIdentifierPropertyType(string memberPath) + { + return _identifierPropertyTypes.TryGetValue(memberPath, out var propertyType) ? propertyType : null; + } + public bool HasCollections { get { return hasCollections; } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 6fa3a44615f..c996b2aa5fa 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,6 +1,10 @@ using System.Linq.Expressions; using System.Reflection; using System; +using NHibernate.Engine; +using NHibernate.Type; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; namespace NHibernate.Util { @@ -15,5 +19,64 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi } return ((MemberExpression)expression.Body).Member; } + + internal static string TryGetEntityName(ISessionFactoryImplementor sessionFactory, MemberExpression memberExpression, out string memberPath) + { + string entityName; + memberPath = memberExpression.Member.Name; + // When having components we need to go though them in order to find the entity + while (memberExpression.Expression is MemberExpression subMemberExpression) + { + // In some cases we can encounter a property representing the entity e.g. [_0].Customer.CustomerId + if (subMemberExpression.NodeType == ExpressionType.MemberAccess) + { + entityName = sessionFactory.TryGetGuessEntityName(memberExpression.Member.ReflectedType); + if (entityName != null) + { + return entityName; + } + } + + memberPath = $"{subMemberExpression.Member.Name}.{memberPath}"; // Build a path that can be used to get the property form the entity metadata + memberExpression = subMemberExpression; + } + + // Try to get the actual entity type from the query source if possbile as member can be declared + // in a base type + if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) + { + entityName = sessionFactory.TryGetGuessEntityName(querySourceReferenceExpression.Type); + if (entityName != null || + !(querySourceReferenceExpression.ReferencedQuerySource is IFromClause fromClause) || + !(fromClause.FromExpression is MemberExpression subMemberExpression)) + { + return entityName; + } + + // When the member type is not the one that is mapped (e.g. interface) we have to find the first + // mapped entity and calculate the entity name from there + entityName = TryGetEntityName(sessionFactory, subMemberExpression, out var subMemberPath); + if (entityName == null) + { + return null; + } + + var persister = sessionFactory.GetEntityPersister(entityName); + var index = persister.EntityMetamodel.GetPropertyIndexOrNull(subMemberPath); + IAssociationType associationType; + if (index.HasValue) + { + associationType = persister.PropertyTypes[index.Value] as IAssociationType; + } + else + { + associationType = persister.EntityMetamodel.GetIdentifierPropertyType(subMemberPath) as IAssociationType; + } + + return associationType?.GetAssociatedEntityName(sessionFactory); + } + + return sessionFactory.TryGetGuessEntityName(memberExpression.Member.ReflectedType); + } } -} \ No newline at end of file +} From bad5b0e492979daccbd53889013ccf549de405bb Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 1 Oct 2019 19:21:19 +0200 Subject: [PATCH 11/29] Simplify SupportsType method --- src/NHibernate/Hql/Ast/HqlTreeNode.cs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 5864a3130a7..5964b99db90 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -274,11 +274,9 @@ internal static bool SupportsType(System.Type type) case TypeCode.Double: return true; default: - return new[] - { - typeof(Guid), - typeof(DateTimeOffset) - }.Contains(type); + return + type == typeof(Guid) || + type == typeof(DateTimeOffset); } } } From 4fe3d7ff124c6acae9d9eedb3909df3709e01143 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 1 Oct 2019 19:28:11 +0200 Subject: [PATCH 12/29] Fixed CodeFactor issue --- .../Async/NHSpecificTest/GH2029/Fixture.cs | 1 - .../NHSpecificTest/GH2029/Fixture.cs | 9 --------- .../NHSpecificTest/GH2029/TestClass.cs | 17 +++++++++++++++++ 3 files changed, 17 insertions(+), 10 deletions(-) create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs index b05df684d2e..7b148a864dc 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs @@ -19,7 +19,6 @@ namespace NHibernate.Test.NHSpecificTest.GH2029 { using System.Threading.Tasks; - [TestFixture] public class FixtureAsync : TestCaseMappingByCode { diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs index c0ec490cf5e..544034db0ea 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs @@ -7,15 +7,6 @@ namespace NHibernate.Test.NHSpecificTest.GH2029 { - public class TestClass - { - public virtual int Id { get; set; } - public virtual int? NullableInt32Prop { get; set; } - public virtual int Int32Prop { get; set; } - public virtual long? NullableInt64Prop { get; set; } - public virtual long Int64Prop { get; set; } - } - [TestFixture] public class Fixture : TestCaseMappingByCode { diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs new file mode 100644 index 00000000000..c15c60dfee3 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NHibernate.Test.NHSpecificTest.GH2029 +{ + public class TestClass + { + public virtual int Id { get; set; } + public virtual int? NullableInt32Prop { get; set; } + public virtual int Int32Prop { get; set; } + public virtual long? NullableInt64Prop { get; set; } + public virtual long Int64Prop { get; set; } + } +} From 6f32861684c0b3e05c34e712c67dd9a465e32678 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 5 Oct 2019 02:14:42 +0200 Subject: [PATCH 13/29] Fix TryGetEntityName for custom entity names --- .../Linq/Functions/ListIndexerGenerator.cs | 35 +- .../Visitors/HqlGeneratorExpressionVisitor.cs | 53 +-- .../Tuple/Entity/EntityMetamodel.cs | 14 +- src/NHibernate/Util/ExpressionsHelper.cs | 350 ++++++++++++++++-- 4 files changed, 372 insertions(+), 80 deletions(-) diff --git a/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs b/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs index 6435b88a476..f03f6700230 100644 --- a/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs +++ b/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs @@ -12,20 +12,30 @@ namespace NHibernate.Linq.Functions { internal class ListIndexerGenerator : BaseHqlGeneratorForMethod,IRuntimeMethodHqlGenerator { + private static readonly HashSet _supportedMethods = new HashSet + { + ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt(null, 0)), + ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt(null, 0)) + }; + public ListIndexerGenerator() { - SupportedMethods = new[] - { - ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt(null, 0)), - ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt(null, 0)) - }; + SupportedMethods = _supportedMethods; } public bool SupportsMethod(MethodInfo method) { - return method != null && - method.Name == "get_Item" && - (method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>))); + return IsRuntimeMethodSupported(method); + } + + public static bool IsMethodSupported(MethodInfo method) + { + if (method.IsGenericMethod) + { + method = method.GetGenericMethodDefinition(); + } + + return _supportedMethods.Contains(method) || IsRuntimeMethodSupported(method); } public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) @@ -40,5 +50,12 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, return treeBuilder.Index(collection, index); } + + private static bool IsRuntimeMethodSupported(MethodInfo method) + { + return method != null && + method.Name == "get_Item" && + (method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>))); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 9658c07f75c..d09730ca757 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -240,6 +240,9 @@ constant.Value is CallSite site && protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) { + // We need to cast the argument when its type is different from Average method return type, + // otherwise the result may be incorrect. In SQL Server avg always returns int + // when the argument is int. var hqlExpression = VisitExpression(expression.Expression).AsExpression(); hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _) ? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type) @@ -267,7 +270,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression) protected HqlTreeNode VisitNhSum(NhSumExpression expression) { - return IsCastRequired(expression.Type, "sum", out _) + return IsCastRequired("sum", expression.Expression, expression.Type) ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type) : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); } @@ -593,7 +596,8 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression) private bool IsCastRequired(Expression expression, System.Type toType, out bool existType) { existType = false; - return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType); + return toType != typeof(object) && + IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType); } private bool IsCastRequired(IType type, IType toType, out bool existType) @@ -635,59 +639,38 @@ private bool IsCastRequired(IType type, IType toType, out bool existType) return castTypeName != toCastTypeName; } - private bool IsCastRequired(System.Type type, string sqlFunctionName, out bool existType) + private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType) { - if (type == typeof(object)) + var argumentType = GetType(argumentExpression); + if (argumentType == null || returnType == typeof(object)) { - existType = false; return false; } - var toType = TypeFactory.GetDefaultTypeFor(type); - if (toType == null) + var returnNhType = TypeFactory.GetDefaultTypeFor(returnType); + if (returnNhType == null) { - existType = false; return true; // Fallback to the old behavior } - existType = true; var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName); if (sqlFunction == null) { return true; // Fallback to the old behavior } - var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory); - return fnReturnType == null || IsCastRequired(fnReturnType, toType, out existType); + var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory); + return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _); } private IType GetType(Expression expression) { - if (!(expression is MemberExpression memberExpression)) - { - return expression.Type != typeof(object) - ? TypeFactory.GetDefaultTypeFor(expression.Type) - : null; - } - // Try to get the mapped type for the member as it may be a non default one - var entityName = ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, memberExpression, out var memberPath); - if (entityName == null) - { - return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped - } - - var persister = _parameters.SessionFactory.GetEntityPersister(entityName); - var type = persister.EntityMetamodel.GetIdentifierPropertyType(memberPath); - if (type != null) - { - return type; - } - - var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath); - return !index.HasValue - ? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped - : persister.EntityMetamodel.PropertyTypes[index.Value]; + ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, expression, out _, out var type); + return type ?? + (expression.Type != typeof(object) + ? TypeFactory.GetDefaultTypeFor(expression.Type) + : null); } } } diff --git a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs index 65c24a0052c..3b793d348a8 100644 --- a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs +++ b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs @@ -51,6 +51,7 @@ public class EntityMetamodel private readonly Dictionary propertyIndexes = new Dictionary(); private readonly IDictionary _identifierPropertyTypes = new Dictionary(); + private readonly IDictionary _propertyTypes = new Dictionary(); private readonly bool hasCollections; private readonly bool hasMutableProperties; private readonly bool hasLazyProperties; @@ -416,7 +417,9 @@ private void MapPropertyToIndex(Mapping.Property prop, int i) private void MapPropertyToIndex(string path, Mapping.Property prop, int i) { - propertyIndexes[!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name] = i; + var propPath = !string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name; + propertyIndexes[propPath] = i; + _propertyTypes[propPath] = prop.Type; if (!(prop.Value is Mapping.Component comp)) { return; @@ -424,7 +427,7 @@ private void MapPropertyToIndex(string path, Mapping.Property prop, int i) foreach (var subprop in comp.PropertyIterator) { - MapPropertyToIndex(!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name, subprop, i); + MapPropertyToIndex(propPath, subprop, i); } } @@ -570,6 +573,13 @@ internal IType GetIdentifierPropertyType(string memberPath) return _identifierPropertyTypes.TryGetValue(memberPath, out var propertyType) ? propertyType : null; } + internal IType GetPropertyType(string memberPath) + { + return _propertyTypes.TryGetValue(memberPath, out var propertyType) + ? propertyType + : GetIdentifierPropertyType(memberPath); + } + public bool HasCollections { get { return hasCollections; } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index c996b2aa5fa..ef5eb437a86 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,7 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; -using System; using NHibernate.Engine; +using NHibernate.Linq; +using NHibernate.Linq.Expressions; +using NHibernate.Linq.Functions; +using NHibernate.Persister.Collection; +using NHibernate.Persister.Entity; using NHibernate.Type; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -20,63 +27,338 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } - internal static string TryGetEntityName(ISessionFactoryImplementor sessionFactory, MemberExpression memberExpression, out string memberPath) + internal static string TryGetEntityName( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out string memberPath, + out IType memberType) { - string entityName; - memberPath = memberExpression.Member.Name; - // When having components we need to go though them in order to find the entity - while (memberExpression.Expression is MemberExpression subMemberExpression) + var memberPaths = TryGetAllMemberMetadata(sessionFactory, expression, out var entityName, out var convertType); + if (memberPaths == null) + { + memberPath = null; + memberType = null; + return null; + } + + entityName = GetEntityName(entityName, convertType, sessionFactory, out var persister); + if (entityName == null || memberPaths.Count == 0) // ((NotMapped)q).Prop || q + { + memberPath = null; + memberType = null; + return entityName; + } + + var member = memberPaths.Pop(); + var type = persister.EntityMetamodel.GetPropertyType(member.Path); + while (true) { - // In some cases we can encounter a property representing the entity e.g. [_0].Customer.CustomerId - if (subMemberExpression.NodeType == ExpressionType.MemberAccess) + if (type == null) // q.NotMappedProp + { + memberPath = null; + memberType = null; + return entityName; + } + + if (memberPaths.Count == 0) // q.ManyToOne || q.OneToMany || q.OneToMany[0] || q.Component || q.Prop || q.AnyType + { + memberPath = member.Path; + memberType = GetType(entityName, type, memberPath, member.HasIndexer, sessionFactory); + return entityName; + } + + if (type is IAssociationType associationType) { - entityName = sessionFactory.TryGetGuessEntityName(memberExpression.Member.ReflectedType); - if (entityName != null) + if (associationType.IsCollectionType) + { + // Check manually for entity association as GetAssociatedEntityName throws when there is none. + var queryableCollection = + (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + if (!queryableCollection.ElementType.IsEntityType) // q.OneToMany[0].CompositeElement + { + entityName = null; + persister = null; + // Ignore if the type is casted as composite elements cannot be casted to entities + type = queryableCollection.ElementType; + member = memberPaths.Pop(); + continue; + } + + // q.OneToMany[0].Prop + entityName = GetEntityName( + associationType.GetAssociatedEntityName(sessionFactory), + member.ConvertType, + sessionFactory, + out persister); + } + else if (associationType.IsAnyType) + { + // ((Address)q.AnyType).Prop || q.AnyType.Prop + // Unfortunately we cannot detect the exact entity name as cast does not provide it, + // so the only option is to guess it. + entityName = GetEntityName(member.ConvertType, sessionFactory, out persister); + } + else // q.ManyToOne.Prop { - return entityName; + entityName = GetEntityName( + associationType.GetAssociatedEntityName(sessionFactory), + member.ConvertType, + sessionFactory, + out persister); } + + if (entityName == null) // q.AnyType.Prop || ((NotMappedClass)q.ManyToOne).Prop + { + memberPath = null; + memberType = null; + return null; + } + + member = memberPaths.Pop(); + type = persister.EntityMetamodel.GetPropertyType(member.Path); } + else if (type is IAbstractComponentType componentType) + { + // q.OneToMany[0].CompositeElement.Prop + if (entityName == null) + { + var index = GetComponentPropertyIndex(componentType, member.Path); + if (!index.HasValue) + { + memberPath = null; + memberType = null; + return null; + } - memberPath = $"{subMemberExpression.Member.Name}.{memberPath}"; // Build a path that can be used to get the property form the entity metadata - memberExpression = subMemberExpression; + type = componentType.Subtypes[index.Value]; + continue; + } + + // q.Component.Prop + // Ignore if the type is casted as components cannot be casted to entities + var componentMember = memberPaths.Pop(); + member = new MemberMetadata( + $"{member.Path}.{componentMember.Path}", + componentMember.ConvertType, + componentMember.HasIndexer); + type = persister.EntityMetamodel.GetPropertyType(member.Path); + } + else + { + // q.Prop.NotMappedProp + memberPath = null; + memberType = null; + return entityName; + } } + } - // Try to get the actual entity type from the query source if possbile as member can be declared - // in a base type - if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) + private static Stack TryGetAllMemberMetadata( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out string entityName, + out System.Type convertType) + { + var memberPaths = new Stack(); + var currentExpression = expression; + convertType = null; + bool hasIndexer = false; + while (true) { - entityName = sessionFactory.TryGetGuessEntityName(querySourceReferenceExpression.Type); - if (entityName != null || - !(querySourceReferenceExpression.ReferencedQuerySource is IFromClause fromClause) || - !(fromClause.FromExpression is MemberExpression subMemberExpression)) + if (currentExpression is MemberExpression subMemberExpression) { - return entityName; + memberPaths.Push(new MemberMetadata(subMemberExpression.Member.Name, convertType, hasIndexer)); + convertType = null; + hasIndexer = false; + currentExpression = subMemberExpression.Expression; } + else if (currentExpression is QuerySourceReferenceExpression querySourceReferenceExpression) + { + if (querySourceReferenceExpression.ReferencedQuerySource is IFromClause fromClause) + { + currentExpression = fromClause.FromExpression; + } + else if (querySourceReferenceExpression.ReferencedQuerySource is JoinClause joinClause) + { + currentExpression = joinClause.InnerSequence; + } + else + { + // Unknown ReferencedQuerySource + entityName = null; + return null; + } + } + else if (currentExpression is UnaryExpression unaryExpression) // ((BaseEntity)q.Entity).Prop + { + currentExpression = unaryExpression.Operand; + convertType = unaryExpression.Type; + } + else if (currentExpression is NhNominatedExpression nominatedExpression) // ((BaseEntity)q.Entity).Prop + { + currentExpression = nominatedExpression.Expression; + } + else if (currentExpression is ConstantExpression constantExpression) + { + if (!(constantExpression.Value is IEntityNameProvider entityNameProvider)) + { + // Not a NhQueryable + entityName = null; + return null; + } - // When the member type is not the one that is mapped (e.g. interface) we have to find the first - // mapped entity and calculate the entity name from there - entityName = TryGetEntityName(sessionFactory, subMemberExpression, out var subMemberPath); - if (entityName == null) + entityName = entityNameProvider.EntityName; + break; + } + else if (currentExpression is MethodCallExpression methodCallExpression && + ListIndexerGenerator.IsMethodSupported(methodCallExpression.Method)) + { + currentExpression = methodCallExpression.Object == null + ? Enumerable.First(methodCallExpression.Arguments) // q.Children.ElementAt(0) + : methodCallExpression.Object; // q.Children[0] + hasIndexer = true; + } + else { + // Not supported expressions + entityName = null; return null; } + } + + return memberPaths; + } + + private static string GetEntityName( + string currentEntityName, + System.Type convertedType, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) + { + persister = sessionFactory.TryGetEntityPersister(currentEntityName); + if (persister == null) + { + return null; // Querying an unmapped interface e.g. s.Query().Where(a => a.Type == "A") + } + + if (convertedType == null) + { + return currentEntityName; + } + + if (persister.EntityMetamodel.HasSubclasses) + { + // When a class is casted to a subclass e.g. ((PizzaOrder)c.Order).PizzaName, we + // can only guess the entity name of it, as there can be many entity names mapped + // to the same subclass. + persister = persister.EntityMetamodel.SubclassEntityNames + .Select(sessionFactory.GetEntityPersister) + .FirstOrDefault(p => p.MappedClass == convertedType); + + return persister?.EntityName; + } + + return GetEntityName(convertedType, sessionFactory, out persister); + } + + private static string GetEntityName( + System.Type convertedType, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) + { + if (convertedType == null) + { + persister = null; + return null; + } - var persister = sessionFactory.GetEntityPersister(entityName); - var index = persister.EntityMetamodel.GetPropertyIndexOrNull(subMemberPath); - IAssociationType associationType; - if (index.HasValue) + var entityName = sessionFactory.TryGetGuessEntityName(convertedType); + if (entityName == null) + { + persister = null; + return null; + } + + persister = sessionFactory.GetEntityPersister(entityName); + return entityName; + } + + private static int? GetComponentPropertyIndex(IAbstractComponentType componentType, string name) + { + var names = componentType.PropertyNames; + for (var i = 0; i < names.Length; i++) + { + if (names[i].Equals(name)) { - associationType = persister.PropertyTypes[index.Value] as IAssociationType; + return i; } - else + } + + return null; + } + + private static IType GetType( + string entityName, + IType currentType, + string memberPath, + bool hasIndexer, + ISessionFactoryImplementor sessionFactory) + { + if (entityName == null && memberPath == null) + { + return null; + } + + if (entityName == null) + { + // q.OneToMany[0].CompositeElement.Prop + if (currentType is IAbstractComponentType componentType) { - associationType = persister.EntityMetamodel.GetIdentifierPropertyType(subMemberPath) as IAssociationType; + var index = GetComponentPropertyIndex(componentType, memberPath); + return index.HasValue + ? componentType.Subtypes[index.Value] + : null; } - return associationType?.GetAssociatedEntityName(sessionFactory); + return null; } - return sessionFactory.TryGetGuessEntityName(memberExpression.Member.ReflectedType); + if (memberPath == null) // q.NotMappedProp + { + return null; + } + + if (!hasIndexer) // q.Prop + { + return currentType; + } + + // q.OneToMany[0] + if (currentType is IAssociationType associationType) + { + var queryableCollection = + (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + return queryableCollection.ElementType; + } + + // q.Prop[0] + return null; + } + + private struct MemberMetadata + { + public MemberMetadata(string path, System.Type convertType, bool hasIndexer) + { + Path = path; + ConvertType = convertType; + HasIndexer = hasIndexer; + } + + public string Path { get; } + + public System.Type ConvertType { get; } + + public bool HasIndexer { get; } } } } From c35aff7d12ef6c83250c28714b5e773da024f16a Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Sat, 5 Oct 2019 22:04:23 +1300 Subject: [PATCH 14/29] Remove GetComponentPropertyIndex method --- src/NHibernate/Util/ExpressionsHelper.cs | 32 +++++++----------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index ef5eb437a86..cc29f31bd54 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -122,15 +122,15 @@ internal static string TryGetEntityName( // q.OneToMany[0].CompositeElement.Prop if (entityName == null) { - var index = GetComponentPropertyIndex(componentType, member.Path); - if (!index.HasValue) + var index = Array.IndexOf(componentType.PropertyNames, member.Path); + if (index < 0) { memberPath = null; memberType = null; return null; } - type = componentType.Subtypes[index.Value]; + type = componentType.Subtypes[index]; continue; } @@ -283,20 +283,6 @@ private static string GetEntityName( return entityName; } - private static int? GetComponentPropertyIndex(IAbstractComponentType componentType, string name) - { - var names = componentType.PropertyNames; - for (var i = 0; i < names.Length; i++) - { - if (names[i].Equals(name)) - { - return i; - } - } - - return null; - } - private static IType GetType( string entityName, IType currentType, @@ -314,10 +300,11 @@ private static IType GetType( // q.OneToMany[0].CompositeElement.Prop if (currentType is IAbstractComponentType componentType) { - var index = GetComponentPropertyIndex(componentType, memberPath); - return index.HasValue - ? componentType.Subtypes[index.Value] - : null; + var names = componentType.PropertyNames; + var index = Array.IndexOf(names, memberPath); + return index < 0 + ? null + : componentType.Subtypes[index]; } return null; @@ -336,8 +323,7 @@ private static IType GetType( // q.OneToMany[0] if (currentType is IAssociationType associationType) { - var queryableCollection = - (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + var queryableCollection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); return queryableCollection.ElementType; } From f5fda9fe7adf9e394a8db4a465af9269402d0df8 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Sat, 5 Oct 2019 22:05:27 +1300 Subject: [PATCH 15/29] Remove unused parameter from TryGetAllMemberMetadata method --- src/NHibernate/Util/ExpressionsHelper.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index cc29f31bd54..63d8070ff04 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -33,7 +33,7 @@ internal static string TryGetEntityName( out string memberPath, out IType memberType) { - var memberPaths = TryGetAllMemberMetadata(sessionFactory, expression, out var entityName, out var convertType); + var memberPaths = TryGetAllMemberMetadata(expression, out var entityName, out var convertType); if (memberPaths == null) { memberPath = null; @@ -154,7 +154,6 @@ internal static string TryGetEntityName( } private static Stack TryGetAllMemberMetadata( - ISessionFactoryImplementor sessionFactory, Expression expression, out string entityName, out System.Type convertType) From 27af9bc1924e5503eef09baedc60804b4c1d5298 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Sat, 5 Oct 2019 22:44:28 +1300 Subject: [PATCH 16/29] Replace TryGetAllMemberMetadata with a visitor class --- src/NHibernate/Util/ExpressionsHelper.cs | 172 +++++++++++++---------- 1 file changed, 96 insertions(+), 76 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 63d8070ff04..34c7e5ad599 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -7,6 +7,7 @@ using NHibernate.Linq; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; using NHibernate.Persister.Collection; using NHibernate.Persister.Entity; using NHibernate.Type; @@ -33,7 +34,7 @@ internal static string TryGetEntityName( out string memberPath, out IType memberType) { - var memberPaths = TryGetAllMemberMetadata(expression, out var entityName, out var convertType); + var memberPaths = MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var entityName, out var convertType); if (memberPaths == null) { memberPath = null; @@ -153,81 +154,6 @@ internal static string TryGetEntityName( } } - private static Stack TryGetAllMemberMetadata( - Expression expression, - out string entityName, - out System.Type convertType) - { - var memberPaths = new Stack(); - var currentExpression = expression; - convertType = null; - bool hasIndexer = false; - while (true) - { - if (currentExpression is MemberExpression subMemberExpression) - { - memberPaths.Push(new MemberMetadata(subMemberExpression.Member.Name, convertType, hasIndexer)); - convertType = null; - hasIndexer = false; - currentExpression = subMemberExpression.Expression; - } - else if (currentExpression is QuerySourceReferenceExpression querySourceReferenceExpression) - { - if (querySourceReferenceExpression.ReferencedQuerySource is IFromClause fromClause) - { - currentExpression = fromClause.FromExpression; - } - else if (querySourceReferenceExpression.ReferencedQuerySource is JoinClause joinClause) - { - currentExpression = joinClause.InnerSequence; - } - else - { - // Unknown ReferencedQuerySource - entityName = null; - return null; - } - } - else if (currentExpression is UnaryExpression unaryExpression) // ((BaseEntity)q.Entity).Prop - { - currentExpression = unaryExpression.Operand; - convertType = unaryExpression.Type; - } - else if (currentExpression is NhNominatedExpression nominatedExpression) // ((BaseEntity)q.Entity).Prop - { - currentExpression = nominatedExpression.Expression; - } - else if (currentExpression is ConstantExpression constantExpression) - { - if (!(constantExpression.Value is IEntityNameProvider entityNameProvider)) - { - // Not a NhQueryable - entityName = null; - return null; - } - - entityName = entityNameProvider.EntityName; - break; - } - else if (currentExpression is MethodCallExpression methodCallExpression && - ListIndexerGenerator.IsMethodSupported(methodCallExpression.Method)) - { - currentExpression = methodCallExpression.Object == null - ? Enumerable.First(methodCallExpression.Arguments) // q.Children.ElementAt(0) - : methodCallExpression.Object; // q.Children[0] - hasIndexer = true; - } - else - { - // Not supported expressions - entityName = null; - return null; - } - } - - return memberPaths; - } - private static string GetEntityName( string currentEntityName, System.Type convertedType, @@ -330,6 +256,100 @@ private static IType GetType( return null; } + private class MemberMetadataExtractor : NhExpressionVisitor + { + private readonly Stack _memberPaths = new Stack(); + private System.Type _convertType; + private bool _hasIndexer; + private string _entityName; + + public static Stack TryGetAllMemberMetadata( + Expression expression, + out string entityName, + out System.Type convertType) + { + var extractor = new MemberMetadataExtractor(); + extractor.Accept(expression); + entityName = extractor._entityName; + convertType = entityName != null ? extractor._convertType : null; + return entityName != null ? extractor._memberPaths : null; + } + + private void Accept(Expression expression) + { + base.Visit(expression); + } + + protected override Expression VisitMember(MemberExpression node) + { + _memberPaths.Push(new MemberMetadata(node.Member.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Expression); + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) + { + if (node.ReferencedQuerySource is IFromClause fromClause) + { + return base.Visit(fromClause.FromExpression); + } + + if (node.ReferencedQuerySource is JoinClause joinClause) + { + return base.Visit(joinClause.InnerSequence); + } + + // Not supported expression + _entityName = null; + return node; + } + + protected override Expression VisitUnary(UnaryExpression node) + { + _convertType = node.Type; + return base.Visit(node.Operand); + } + + protected internal override Expression VisitNhNominated(NhNominatedExpression node) + { + return base.Visit(node); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + _entityName = node.Value is IEntityNameProvider entityNameProvider + ? entityNameProvider.EntityName + : null; // Not a NhQueryable + + return node; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (ListIndexerGenerator.IsMethodSupported(node.Method)) + { + _hasIndexer = true; + return base.Visit( + node.Object == null + ? Enumerable.First(node.Arguments) // q.Children.ElementAt(0) + : node.Object // q.Children[0] + ); + } + + // Not supported expression + _entityName = null; + return node; + } + + public override Expression Visit(Expression node) + { + // Not supported expression + _entityName = null; + return node; + } + } + private struct MemberMetadata { public MemberMetadata(string path, System.Type convertType, bool hasIndexer) From b4027c690351dc5cd6ec0cb483d23a6cc3358d0c Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 8 Oct 2019 21:59:30 +0200 Subject: [PATCH 17/29] Replace TryGetEntityName with TryGetMappedType --- src/NHibernate.DomainModel/FooComponent.cs | 2 + .../Northwind/Entities/Address.cs | 4 +- .../Northwind/Entities/IEntity.cs | 15 + .../Northwind/Entities/Product.cs | 4 +- .../Northwind/Entities/User.cs | 6 +- src/NHibernate.Test/Linq/TryGetMappedTests.cs | 734 ++++++++++++++++++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 10 +- src/NHibernate/Util/ExpressionsHelper.cs | 479 +++++++++--- 8 files changed, 1120 insertions(+), 134 deletions(-) create mode 100644 src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs create mode 100644 src/NHibernate.Test/Linq/TryGetMappedTests.cs diff --git a/src/NHibernate.DomainModel/FooComponent.cs b/src/NHibernate.DomainModel/FooComponent.cs index 4bd536eed96..e1c88c7b449 100644 --- a/src/NHibernate.DomainModel/FooComponent.cs +++ b/src/NHibernate.DomainModel/FooComponent.cs @@ -92,6 +92,8 @@ public Int32 Count set { _count = value; } } + public int NotMapped { get; set; } + public DateTime[] ImportantDates { get { return _importantDates; } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Address.cs b/src/NHibernate.DomainModel/Northwind/Entities/Address.cs index d224bc50cf7..d2d56fd6823 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Address.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Address.cs @@ -61,6 +61,8 @@ public string Fax get { return _fax; } } + public int NotMapped => 1; + public static bool operator ==(Address address1, Address address2) { if (!ReferenceEquals(address1, null) && @@ -114,4 +116,4 @@ public override int GetHashCode() (_fax ?? string.Empty).GetHashCode(); } } -} \ No newline at end of file +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs b/src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs new file mode 100644 index 00000000000..52c661c9ec1 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public interface IEntity + { + TId Id { get; set; } + } + + public interface IEntity : IEntity + { + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Product.cs b/src/NHibernate.DomainModel/Northwind/Entities/Product.cs index 5af739cc0d2..8c72b3895d0 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Product.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Product.cs @@ -103,9 +103,11 @@ public virtual float ShippingWeight set { _shippingWeight = value; } } + public virtual int NotMapped => 1; + public virtual ReadOnlyCollection OrderLines { get { return new ReadOnlyCollection(_orderLines); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index a2bde32af30..c3f220ffda5 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -28,7 +28,7 @@ public interface IUser EnumStoredAsInt32 Enum2 { get; set; } } - public class User : IUser + public class User : IUser, IEntity { public virtual int Id { get; set; } @@ -50,6 +50,10 @@ public class User : IUser public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual int NotMapped { get; set; } + + public virtual Role NotMappedRole { get; set; } + public User() { } public User(string name, DateTime registeredAt) diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs new file mode 100644 index 00000000000..7eea8c8d70d --- /dev/null +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -0,0 +1,734 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.DomainModel; +using NHibernate.DomainModel.NHSpecific; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Persister.Entity; +using NHibernate.Type; +using NHibernate.Util; +using NUnit.Framework; +using IQueryable = System.Linq.IQueryable; + +namespace NHibernate.Test.Linq +{ + /// + /// Tests form ExpressionsHelper.TryGetMappedType and ExpressionsHelper.TryGetMappedNullability + /// + public class TryGetMappedTests : LinqTestCase + { + private readonly static TryGetMappedType _tryGetMappedType; + private readonly static TryGetMappedNullability _tryGetMappedNullability; + + delegate bool TryGetMappedType( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath); + + delegate bool TryGetMappedNullability( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out bool nullability); + + static TryGetMappedTests() + { + var method = typeof(ExpressionsHelper).GetMethod( + nameof(TryGetMappedType), + BindingFlags.NonPublic | BindingFlags.Static); + var sessionFactoryParam = Expression.Parameter(typeof(ISessionFactoryImplementor), "sessionFactory"); + var expressionParam = Expression.Parameter(typeof(Expression), "expression"); + var mappedTypeParam = Expression.Parameter(typeof(IType).MakeByRefType(), "mappedType"); + var entityPersisterParam = Expression.Parameter(typeof(IEntityPersister).MakeByRefType(), "entityPersister"); + var componentParam = Expression.Parameter(typeof(IAbstractComponentType).MakeByRefType(), "component"); + var memberPathParam = Expression.Parameter(typeof(string).MakeByRefType(), "memberPath"); + var methodCall = Expression.Call( + method, + sessionFactoryParam, + expressionParam, + mappedTypeParam, + entityPersisterParam, + componentParam, + memberPathParam); + _tryGetMappedType = Expression.Lambda( + methodCall, + sessionFactoryParam, + expressionParam, + mappedTypeParam, + entityPersisterParam, + componentParam, + memberPathParam).Compile(); + + method = typeof(ExpressionsHelper).GetMethod( + nameof(TryGetMappedNullability), + BindingFlags.NonPublic | BindingFlags.Static); + var nullabilityParam = Expression.Parameter(typeof(bool).MakeByRefType(), "nullability"); + methodCall = Expression.Call( + method, + sessionFactoryParam, + expressionParam, + nullabilityParam); + _tryGetMappedNullability = Expression.Lambda( + methodCall, + sessionFactoryParam, + expressionParam, + nullabilityParam).Compile(); + } + + protected override string[] Mappings + { + get + { + return + new[] + { + "ABC.hbm.xml", + "Baz.hbm.xml", + "FooBar.hbm.xml", + "Glarch.hbm.xml", + "Fee.hbm.xml", + "Qux.hbm.xml", + "Fum.hbm.xml", + "Holder.hbm.xml", + "One.hbm.xml", + "Many.hbm.xml" + }.Concat(base.Mappings).ToArray(); + } + } + + [Test] + public void SelfTest() + { + var query = db.OrderLines.Select(o => o); + AssertTrueNotNull( + query, + typeof(OrderLine).FullName, + null, + o => o is EntityType entityType && entityType.ReturnedClass == typeof(OrderLine)); + } + + [Test] + public void SelfCastNotMappedTest() + { + var query = session.Query().Select(o => (object) o); + AssertTrueNotNull( + query, + false, + typeof(A).FullName, + null, + o => o is SerializableType serializableType && serializableType.ReturnedClass == typeof(object)); + } + + [Test] + public void PropertyTest() + { + var query = db.OrderLines.Select(o => o.Quantity); + AssertTrueNotNull(query, typeof(OrderLine).FullName, "Quantity", o => o is Int32Type); + } + + [Test] + public void NotMappedPropertyTest() + { + var query = db.Users.Select(o => o.NotMapped); + AssertFalse(query, typeof(User).FullName, "NotMapped", o => o is null); + } + + [Test] + public void NestedNotMappedPropertyTest() + { + var query = db.Users.Select(o => o.Name.Length); + AssertFalse(query, false, null, null, o => o is null); + } + + [Test] + public void PropertyCastTest() + { + var query = db.OrderLines.Select(o => (long) o.Quantity); + AssertTrueNotNull(query, typeof(OrderLine).FullName, "Quantity", o => o is Int64Type); + } + + [Test] + public void PropertyIndexer() + { + var query = db.Products.Select(o => o.Name[0]); + AssertFalse(query, null, null, o => o == null); + } + + [Test] + public void EnumInt32Test() + { + var query = db.Users.Select(o => o.Enum2); + AssertTrueNotNull( + query, + typeof(User).FullName, + "Enum2", + o => o.GetType().GetGenericArguments().FirstOrDefault() == typeof(EnumStoredAsInt32)); + } + + [Test] + public void EnumInt32CastTest() + { + var query = db.Users.Select(o => (int) o.Enum2); + AssertTrueNotNull(query, typeof(User).FullName, "Enum2", o => o is Int32Type); + } + + [Test] + public void EnumAsStringTest() + { + var query = db.Users.Select(o => o.Enum1); + AssertTrue(query, typeof(User).FullName, "Enum1", o => o is EnumStoredAsStringType); + } + + [Test] + public void IdentifierTest() + { + var query = db.OrderLines.Select(o => o.Id); + AssertTrueNotNull(query, typeof(OrderLine).FullName, "Id", o => o is Int64Type); + } + + [Test] + public void CompositeIdentifierTest() + { + var query = session.Query().Select(o => o.Id.Date); + AssertTrueNotNull( + query, + typeof(Fum).FullName, + "Id.Date", + o => o is DateTimeType, + o => o?.Name == "component[String,Short,Date]"); + } + + [Test] + public void ComponentTest() + { + var query = db.Customers.Select(o => o.Address); + AssertTrue( + query, + typeof(Customer).FullName, + "Address", + o => o is ComponentType && o.Name == "component[Street,City,Region,PostalCode,Country,PhoneNumber,Fax]"); + } + + [Test] + public void ComponentPropertyTest() + { + var query = db.Customers.Select(o => o.Address.City); + AssertTrue( + query, + typeof(Customer).FullName, + "Address.City", + o => o is StringType, + o => o?.Name == "component[Street,City,Region,PostalCode,Country,PhoneNumber,Fax]"); + } + + [Test] + public void ComponentNotMappedPropertyTest() + { + var query = db.Customers.Select(o => o.Address.NotMapped); + AssertFalse( + query, + typeof(Customer).FullName, + "Address.NotMapped", + o => o == null, + o => o?.Name == "component[Street,City,Region,PostalCode,Country,PhoneNumber,Fax]"); + } + + [Test] + public void ComponentNestedNotMappedPropertyTest() + { + var query = db.Customers.Select(o => o.Address.City.Length); + AssertFalse(query, false, null, null, o => o == null); + } + + [Test] + public void NestedComponentPropertyTest() + { + var query = db.Users.Select(o => o.Component.OtherComponent.OtherProperty1); + AssertTrue( + query, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is AnsiStringType, + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void NestedComponentPropertyCastTest() + { + var query = db.Users.Select(o => (object) o.Component.OtherComponent.OtherProperty1); + AssertTrue( + query, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is SerializableType serializableType && serializableType.ReturnedClass == typeof(object), + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void ManyToOneTest() + { + var query = db.OrderLines.Select(o => o.Order); + AssertTrueNotNull(query, typeof(OrderLine).FullName, "Order", + o => o is ManyToOneType manyToOne && manyToOne.PropertyName == "Order"); + } + + [Test] + public void ManyToOnePropertyTest() + { + var query = db.OrderLines.Select(o => o.Order.Freight); + AssertTrue(query, typeof(Order).FullName, "Freight", o => o is DecimalType); + } + + [Test] + public void ManyToOneNotMappedPropertyTest() + { + var query = db.OrderLines.Select(o => o.Product.NotMapped); + AssertFalse(query, typeof(Product).FullName, "NotMapped", o => o == null); + } + + [Test] + public void NotMappedManyToOnePropertyTest() + { + var query = db.Users.Select(o => o.NotMappedRole.Name); + AssertFalse(query, false, null, null, o => o is null); + } + + [Test] + public void NestedManyToOneTest() + { + var query = db.OrderLines.Select(o => o.Order.Employee); + AssertTrue(query, false, typeof(Order).FullName, "Employee", + o => o is ManyToOneType manyToOne && manyToOne.PropertyName == "Employee"); + } + + [Test] + public void NestedManyToOnePropertyTest() + { + var query = db.OrderLines.Select(o => o.Order.Employee.BirthDate); + AssertTrue(query, typeof(Employee).FullName, "BirthDate", o => o is DateTimeType); + } + + [Test] + public void OneToManyTest() + { + var query = db.Customers.SelectMany(o => o.Orders); + AssertTrue( + query, + typeof(Customer).FullName, + "Orders", + o => o is CollectionType collectionType && collectionType.Role == $"{typeof(Customer).FullName}.Orders"); + } + + [Test] + public void OneToManyElementIndexerTest() + { + var query = session.Query().Select(o => o.StringList[0]); + AssertTrue(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); + } + + [Test] + public void OneToManyElementIndexerNotMappedPropertyTest() + { + var query = session.Query().Select(o => o.StringList[0].Length); + AssertFalse(query, false, null, null, o => o == null); + } + + [Test] + public void OneToManyCustomElementIndexerTest() + { + var query = session.Query().Select(o => o.Customs[0]); + AssertTrue( + query, + false, + typeof(Baz).FullName, + "Customs", + o => o is CompositeCustomType customType && customType.UserType is DoubleStringType); + } + + [Test] + public void OneToManyIndexerCastTest() + { + var query = session.Query().Select(o => (long) o.IntArray[0]); + AssertTrue(query, false, typeof(Baz).FullName, "IntArray", o => o is Int64Type); + } + + [Test] + public void OneToManyIndexerPropertyTest() + { + var query = session.Query().Select(o => o.Fees[0].Count); + AssertTrue(query, false, typeof(Fee).FullName, "Count", o => o is Int32Type); + } + + [Test] + public void OneToManyElementAtTest() + { + var query = session.Query().Select(o => o.StringList.ElementAt(0)); + AssertTrue(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); + } + + [Test] + public void NestedOneToManyManyToOneComponentPropertyTest() + { + var query = session.Query().SelectMany(o => o.Fees).Select(o => o.TheFee.Compon.Name); + AssertTrue( + query, + typeof(Fee).FullName, + "Compon.Name", + o => o is StringType, + o => o?.Name == "component[Name,NullString]"); + } + + [Test] + public void OneToManyCompositeElementPropertyTest() + { + var query = session.Query().Select(o => o.Components[0].Count); + AssertTrue( + query, + false, + null, + "Count", + o => o is Int32Type, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void OneToManyCompositeElementPropertyIndexerTest() + { + var query = session.Query().Select(o => o.Components[0].Name[0]); + AssertFalse(query, false, null, null, o => o == null); + } + + [Test] + public void OneToManyCompositeElementNotMappedPropertyTest() + { + var query = session.Query().Select(o => o.Components[0].NotMapped); + AssertFalse( + query, + false, + null, + "NotMapped", + o => o == null, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void OneToManyCompositeElementCastPropertyTest() + { + var query = session.Query().Select(o => (long) o.Components[0].Count); + AssertTrue( + query, + false, + null, + "Count", + o => o is Int64Type, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void OneToManyCompositeElementCollectionNotMappedPropertyTest() + { + var query = session.Query().SelectMany(o => o.Components[0].ImportantDates); + AssertFalse( + query, + false, + null, + "ImportantDates", + o => o == null, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void NestedOneToManyCompositeElementTest() + { + var query = session.Query().Select(o => o.Components[0].Subcomponent); + AssertTrue( + query, + false, + null, + "Subcomponent", + o => o is IAbstractComponentType componentType && componentType.ReturnedClass == typeof(FooComponent), + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void NestedOneToManyCompositeElementPropertyTest() + { + var query = session.Query().Select(o => o.Components[0].Subcomponent.Name); + AssertTrue(query, false, null, "Name", o => o is StringType, o => o?.Name == "component[Name,Count]"); + } + + [Test] + public void NestedOneToManyCompositeElementPropertyIndexerTest() + { + var query = session.Query().Select(o => o.Components[0].Subcomponent.Name[0]); + AssertFalse(query, false, null, null, o => o == null); + } + + [Test] + public void ManyToManyTest() + { + var query = session.Query().Select(o => o.FooArray); + AssertTrue( + query, + false, + typeof(Baz).FullName, + "FooArray", + o => o is ArrayType arrayType && arrayType.Role == $"{typeof(Baz).FullName}.FooArray"); + } + + [Test] + public void ManyToManyIndexerTest() + { + var query = session.Query().Select(o => o.FooArray[0].Null); + AssertTrue(query, false, typeof(Foo).FullName, "Null", o => o is NullableInt32Type); + } + + [Test] + public void SubclassCastTest() + { + var query = session.Query().Select(o => (B) o); + AssertTrueNotNull( + query, + typeof(A).FullName, + null, + o => o is EntityType entityType && entityType.ReturnedClass == typeof(B)); + } + + [Test] + public void NestedSubclassCastTest() + { + var query = session.Query().Select(o => (C1) ((B) o)); + AssertTrueNotNull( + query, + false, + typeof(A).FullName, + null, + o => o is EntityType entityType && entityType.ReturnedClass == typeof(C1)); + } + + [Test] + public void SubclassPropertyTest() + { + var query = session.Query().Select(o => ((C1) o).Count); + AssertTrue(query, typeof(C1).FullName, "Count", o => o is Int32Type); + } + + [Test] + public void NestedSubclassCastPropertyTest() + { + var query = session.Query().Select(o => ((C1) ((B) o)).Id); + AssertTrueNotNull(query, typeof(C1).FullName, "Id", o => o is Int64Type); + } + + [Test] + public void AnyTest() + { + var query = session.Query().Select(o => o.Object); + AssertTrue(query, typeof(Bar).FullName, "Object", o => o.IsAnyType); + } + + [Test] + public void CastAnyTest() + { + var query = session.Query().Select(o => (Foo) o.Object); + AssertTrue( + query, + typeof(Bar).FullName, + "Object", + o => o is EntityType entityType && entityType.ReturnedClass == typeof(Foo)); + } + + [Test] + public void NestedCastAnyTest() + { + var query = session.Query().Select(o => (Foo) ((Bar) o.Object).Object); + AssertTrue( + query, + false, + typeof(Bar).FullName, + "Object", + o => o is EntityType entityType && entityType.ReturnedClass == typeof(Foo)); + } + + [Test] + public void CastAnyManyToOneTest() + { + var query = session.Query().Select(o => ((Foo) o.Object).Dependent); + AssertTrueNotNull( + query, + typeof(Foo).FullName, + "Dependent", + o => o is EntityType entityType && entityType.ReturnedClass == typeof(Fee)); + } + + [Test] + public void CastAnyPropertyTest() + { + var query = session.Query().Select(o => ((Foo) o.Object).String); + AssertTrue(query, false, typeof(Foo).FullName, "String", o => o is StringType); + } + + [Test] + public void QueryUnmppedEntityTest() + { + var query = session.Query>().Select(o => o.Id); + AssertFalse(query, null, null, o => o == null); + } + + [Test] + public void NotSupportedConditionalExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.RegisteredAt : o.LastLoginDate)); + AssertFalse(query, false, null, null, o => o == null); + } + + [Test] + public void JoinTest() + { + var query = from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + into details + from d in details + select d.UnitPrice; + AssertTrueNotNull(query, typeof(OrderLine).FullName, "UnitPrice", o => o is DecimalType); + } + + [Test] + public void NotNullComponentPropertyTest() + { + var query = session.Query().SelectMany(o => o.PatientRecords.Select(r => r.Name.FirstName)); + AssertTrueNotNull( + query, + typeof(PatientRecord).FullName, + "Name.FirstName", + o => o is StringType, + o => o?.Name == "component[FirstName,LastName]"); + } + + private void AssertFalse( + IQueryable query, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, true, false, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertFalse( + IQueryable query, + bool rewriteQuery, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, rewriteQuery, false, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertTrue( + IQueryable query, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, true, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertTrue( + IQueryable query, + bool rewriteQuery, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, rewriteQuery, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertTrueNotNull( + IQueryable query, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, true, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType, false); + } + + private void AssertTrueNotNull( + IQueryable query, + bool rewriteQuery, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, rewriteQuery, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType, false); + } + + private void AssertResult( + IQueryable query, + bool rewriteQuery, + bool result, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null, + bool nullability = true) + { + expectedComponentType = expectedComponentType ?? (o => o == null); + + var expression = query.Expression; + NhRelinqQueryParser.PreTransform(expression); + var constantToParameterMap = ExpressionParameterVisitor.Visit(expression, Sfi); + var queryModel = NhRelinqQueryParser.Parse(expression); + var requiredHqlParameters = new List(); + var visitorParameters = new VisitorParameters( + Sfi, + constantToParameterMap, + requiredHqlParameters, + new QuerySourceNamer(), + expression.Type, + QueryMode.Select); + if (rewriteQuery) + { + QueryModelVisitor.GenerateHqlQuery( + queryModel, + visitorParameters, + true, + NhLinqExpressionReturnType.Scalar); + } + + var found = _tryGetMappedType( + Sfi, + queryModel.SelectClause.Selector, + out var memberType, + out var entityPersister, + out var componentType, + out var memberPath); + Assert.That(found, Is.EqualTo(result), "Expression should be supported"); + Assert.That(entityPersister?.EntityName, Is.EqualTo(expectedEntityName), "Invalid enity name"); + Assert.That(memberPath, Is.EqualTo(expectedMemberPath), "Invalid member path"); + Assert.That(() => expectedMemberType(memberType), $"Invalid member type: {memberType?.Name ?? "null"}"); + Assert.That(() => expectedComponentType(componentType), $"Invalid component type: {componentType?.Name ?? "null"}"); + + if (found) + { + Assert.That(_tryGetMappedNullability(Sfi, queryModel.SelectClause.Selector, out var isNullable), Is.True, "Expression should be supported"); + Assert.That(nullability, Is.EqualTo(isNullable), "Nullability is not correct"); + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index d09730ca757..4e6df61afba 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -666,11 +666,11 @@ private bool IsCastRequired(string sqlFunctionName, Expression argumentExpressio private IType GetType(Expression expression) { // Try to get the mapped type for the member as it may be a non default one - ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, expression, out _, out var type); - return type ?? - (expression.Type != typeof(object) - ? TypeFactory.GetDefaultTypeFor(expression.Type) - : null); + return expression.Type == typeof(object) + ? null + : (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _) + ? type + : TypeFactory.GetDefaultTypeFor(expression.Type)); } } } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 34c7e5ad599..91d1def2fce 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -28,165 +28,348 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } - internal static string TryGetEntityName( + /// + /// Try to get the mapped nullability from the given expression. + /// + /// The session factory. + /// The expression to evaluate. + /// Output parameter that represents whether the is nullable. + /// Whether the mapped nullability was found. + internal static bool TryGetMappedNullability( ISessionFactoryImplementor sessionFactory, Expression expression, - out string memberPath, - out IType memberType) + out bool nullable) + { + if (!TryGetMappedType( + sessionFactory, + expression, + out _, + out var entityPersister, + out var componentType, + out var memberPath)) + { + nullable = false; + return false; + } + + // The source entity is always not null, as it gets translated to the entity identifier + if (memberPath == null) + { + nullable = false; + return true; + } + + int index; + if (componentType != null) + { + index = Array.IndexOf( + componentType.PropertyNames, + memberPath.Substring(memberPath.LastIndexOf('.') + 1)); + nullable = componentType.PropertyNullability[index]; + return true; + } + + if (entityPersister.EntityMetamodel.GetIdentifierPropertyType(memberPath) != null) + { + nullable = false; // Identifier is always not-null + return true; + } + + index = entityPersister.EntityMetamodel.GetPropertyIndex(memberPath); + nullable = entityPersister.PropertyNullability[index]; + return true; + } + + /// + /// Try to get the mapped type from the given expression. When the type is + /// , the will be set based on the expression type + /// only when the mapping for was found, otherwise + /// will be returned. + /// + /// The session factory to retrieve types. + /// The expression to evaluate. + /// Output parameter that represents the mapped type of . + /// + /// Output parameter that represents the entity persister of the entity where is defined. + /// This parameter will not be set when represents a property in a collection composite element. + /// + /// + /// Output parameter that represents the component type where is defined. + /// This parameter will not be set when does not represent a property in a component. + /// + /// + /// Output parameter that represents the path of the mapped member, which in most cases is the member name. In case + /// when the mapped member is defined inside a component the path will be prefixed with the name of the component member and a dot. + /// (e.g. Component.Property). + /// Whether the mapped type was found. + /// + /// When the contains an expression of type , the + /// result may not be correct when casting to an entity that is mapped with multiple entity names. + /// + internal static bool TryGetMappedType( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath) { + // In order to get the correct entity name from the expression we first have to find the constant expression that contains the + // IEntityNameProvider instance, from which we can retrive the starting entity name. Once we have it, we have to traverse all + // expressions that we had to travese in order to find the IEntityNameProvider instance, but in reverse order (bottom to top) + // and keep tracking the entity name until we reach to top. + + // Try to retrive the starting entity name with all members that were traversed in that process var memberPaths = MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var entityName, out var convertType); if (memberPaths == null) { + // Failed to find the starting entity name, due to an unsupported expression or the expression didn't contain + // the IEntityNameProvider instance memberPath = null; - memberType = null; - return null; + mappedType = null; + entityPersister = null; + component = null; + return false; } - entityName = GetEntityName(entityName, convertType, sessionFactory, out var persister); - if (entityName == null || memberPaths.Count == 0) // ((NotMapped)q).Prop || q + if (!TryGetEntityPersister(entityName, null, sessionFactory, out var currentEntityPersister)) { + // Querying an unmapped type e.g. s.Query().Where(a => a.Type == "A") + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + return false; + } + + if (memberPaths.Count == 0) // The expression do not contain any member expressions + { + if (convertType != null) + { + mappedType = TryGetEntityPersister(currentEntityPersister, convertType, sessionFactory, out var convertPersister) + ? convertPersister.EntityMetamodel.EntityType // ((Subclass)q) + : TypeFactory.GetDefaultTypeFor(convertType); // ((NotMapped)q) + } + else + { + mappedType = currentEntityPersister.EntityMetamodel.EntityType; // q + } + memberPath = null; - memberType = null; - return entityName; + component = null; + entityPersister = currentEntityPersister; + return mappedType != null; } + // If there was a cast right after the constant expression that contains the IEntityNameProvider instance, we have + // to update the entity persister according to it, otherwise use the value returned by TryGetAllMemberMetadata method. + if (convertType != null) + { + if (!TryGetEntityPersister(currentEntityPersister, convertType, sessionFactory, out var convertPersister)) // ((NotMapped)q).Id + { + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + return false; + } + else + { + currentEntityPersister = convertPersister; // ((Subclass)q).Id + } + } + + // Traverse the members that were traversed by the TryGetAllMemberMetadata method in the reverse order and try to keep + // tracking the entity persister until all members are traversed. var member = memberPaths.Pop(); - var type = persister.EntityMetamodel.GetPropertyType(member.Path); + var currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); + IAbstractComponentType currentComponentType = null; while (true) { - if (type == null) // q.NotMappedProp + // When traversed to the top of the expression, return the current tracking values + if (memberPaths.Count == 0) + { + memberPath = currentEntityPersister != null || currentComponentType != null ? member.Path : null; + mappedType = GetType(currentEntityPersister, currentType, member, sessionFactory, out _); + entityPersister = currentEntityPersister; + component = currentComponentType; + return mappedType != null; + } + + if (currentType == null) // Member not mapped { memberPath = null; - memberType = null; - return entityName; + mappedType = null; + entityPersister = null; + component = null; + return false; } - if (memberPaths.Count == 0) // q.ManyToOne || q.OneToMany || q.OneToMany[0] || q.Component || q.Prop || q.AnyType + convertType = member.ConvertType; + // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. + // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. + if (!currentType.IsAnyType && currentType is IAbstractComponentType) + { + var nextMember = memberPaths.Pop(); + member = currentEntityPersister == null // Collection with composite element or element + ? nextMember + : new MemberMetadata($"{member.Path}.{nextMember.Path}", nextMember.ConvertType, nextMember.HasIndexer); + } + else { - memberPath = member.Path; - memberType = GetType(entityName, type, memberPath, member.HasIndexer, sessionFactory); - return entityName; + member = memberPaths.Pop(); } - if (type is IAssociationType associationType) + switch (currentType) { - if (associationType.IsCollectionType) - { - // Check manually for entity association as GetAssociatedEntityName throws when there is none. - var queryableCollection = - (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); - if (!queryableCollection.ElementType.IsEntityType) // q.OneToMany[0].CompositeElement - { - entityName = null; - persister = null; - // Ignore if the type is casted as composite elements cannot be casted to entities - type = queryableCollection.ElementType; - member = memberPaths.Pop(); - continue; - } - - // q.OneToMany[0].Prop - entityName = GetEntityName( - associationType.GetAssociatedEntityName(sessionFactory), - member.ConvertType, - sessionFactory, - out persister); - } - else if (associationType.IsAnyType) - { - // ((Address)q.AnyType).Prop || q.AnyType.Prop - // Unfortunately we cannot detect the exact entity name as cast does not provide it, - // so the only option is to guess it. - entityName = GetEntityName(member.ConvertType, sessionFactory, out persister); - } - else // q.ManyToOne.Prop - { - entityName = GetEntityName( - associationType.GetAssociatedEntityName(sessionFactory), - member.ConvertType, + case IAssociationType associationType: + ProcessAssociationType( + associationType, sessionFactory, - out persister); - } + member, + convertType, + out currentType, + out currentEntityPersister, + out currentComponentType); + break; + case IAbstractComponentType componentType: + currentComponentType = componentType; + ProcessComponentType(componentType, currentEntityPersister, member, out currentType); + break; + default: + // q.Prop.NotMappedProp + currentType = null; + currentEntityPersister = null; + currentComponentType = null; + break; + } + } + } - if (entityName == null) // q.AnyType.Prop || ((NotMappedClass)q.ManyToOne).Prop - { - memberPath = null; - memberType = null; - return null; - } + private static void ProcessComponentType( + IAbstractComponentType componentType, + IEntityPersister persister, + MemberMetadata member, + out IType memberType) + { + // When persister is not available (q.OneToManyCompositeElement[0].Prop), try to get the type from the component + if (persister == null) + { + var index = Array.IndexOf(componentType.PropertyNames, member.Path); + memberType = index < 0 + ? null // q.OneToManyCompositeElement[0].NotMappedProp + : componentType.Subtypes[index]; // q.OneToManyCompositeElement[0].Prop + return; + } - member = memberPaths.Pop(); - type = persister.EntityMetamodel.GetPropertyType(member.Path); - } - else if (type is IAbstractComponentType componentType) + // q.Component.Prop + memberType = persister.EntityMetamodel.GetPropertyType(member.Path); + } + + private static void ProcessAssociationType( + IAssociationType associationType, + ISessionFactoryImplementor sessionFactory, + MemberMetadata member, + System.Type convertType, + out IType memberType, + out IEntityPersister memberPersister, + out IAbstractComponentType memberComponent) + { + if (associationType.IsCollectionType) + { + // Check manually for entity association as GetAssociatedEntityName throws when there is none. + var queryableCollection = + (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + if (!queryableCollection.ElementType.IsEntityType) // q.OneToManyCompositeElement[0].Member, q.OneToManyElement[0].Member { - // q.OneToMany[0].CompositeElement.Prop - if (entityName == null) + memberPersister = null; + // Can be or + switch (queryableCollection.ElementType) { - var index = Array.IndexOf(componentType.PropertyNames, member.Path); - if (index < 0) - { - memberPath = null; + case IAbstractComponentType componentType: // q.OneToManyCompositeElement[0].Member + memberComponent = componentType; + ProcessComponentType(componentType, null, member, out memberType); + return; + default: // q.OneToManyElement[0].Member memberType = null; - return null; - } - - type = componentType.Subtypes[index]; - continue; + memberComponent = null; + return; } - - // q.Component.Prop - // Ignore if the type is casted as components cannot be casted to entities - var componentMember = memberPaths.Pop(); - member = new MemberMetadata( - $"{member.Path}.{componentMember.Path}", - componentMember.ConvertType, - componentMember.HasIndexer); - type = persister.EntityMetamodel.GetPropertyType(member.Path); - } - else - { - // q.Prop.NotMappedProp - memberPath = null; - memberType = null; - return entityName; } + + // q.OneToMany[0].Member + TryGetEntityPersister( + associationType.GetAssociatedEntityName(sessionFactory), + convertType, + sessionFactory, + out memberPersister); + } + else if (associationType.IsAnyType) + { + // ((Address)q.AnyType).Member, q.AnyType.Member + // Unfortunately we cannot detect the exact entity name as cast does not provide it, + // so the only option is to guess it. + TryGetEntityPersister(convertType, sessionFactory, out memberPersister); } + else // q.ManyToOne.Member + { + TryGetEntityPersister( + associationType.GetAssociatedEntityName(sessionFactory), + convertType, + sessionFactory, + out memberPersister); + } + + memberComponent = null; + memberType = memberPersister != null + ? memberPersister.EntityMetamodel.GetPropertyType(member.Path) + : null; // q.AnyType.Member, ((NotMappedClass)q.ManyToOne) } - private static string GetEntityName( + private static bool TryGetEntityPersister( string currentEntityName, System.Type convertedType, ISessionFactoryImplementor sessionFactory, out IEntityPersister persister) { - persister = sessionFactory.TryGetEntityPersister(currentEntityName); - if (persister == null) + var currentEntityPersister = sessionFactory.TryGetEntityPersister(currentEntityName); + if (currentEntityPersister == null) { - return null; // Querying an unmapped interface e.g. s.Query().Where(a => a.Type == "A") + persister = null; + return false; // Querying an unmapped interface e.g. s.Query().Where(a => a.Type == "A") } + return TryGetEntityPersister(currentEntityPersister, convertedType, sessionFactory, out persister); + } + + private static bool TryGetEntityPersister( + IEntityPersister currentEntityPersister, + System.Type convertedType, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) + { if (convertedType == null) { - return currentEntityName; + persister = currentEntityPersister; + return true; } - if (persister.EntityMetamodel.HasSubclasses) + if (currentEntityPersister.EntityMetamodel.HasSubclasses) { // When a class is casted to a subclass e.g. ((PizzaOrder)c.Order).PizzaName, we // can only guess the entity name of it, as there can be many entity names mapped // to the same subclass. - persister = persister.EntityMetamodel.SubclassEntityNames + persister = currentEntityPersister.EntityMetamodel.SubclassEntityNames .Select(sessionFactory.GetEntityPersister) .FirstOrDefault(p => p.MappedClass == convertedType); - return persister?.EntityName; + return persister != null; } - return GetEntityName(convertedType, sessionFactory, out persister); + return TryGetEntityPersister(convertedType, sessionFactory, out persister); } - private static string GetEntityName( + private static bool TryGetEntityPersister( System.Type convertedType, ISessionFactoryImplementor sessionFactory, out IEntityPersister persister) @@ -194,65 +377,84 @@ private static string GetEntityName( if (convertedType == null) { persister = null; - return null; + return false; } var entityName = sessionFactory.TryGetGuessEntityName(convertedType); if (entityName == null) { persister = null; - return null; + return false; } persister = sessionFactory.GetEntityPersister(entityName); - return entityName; + return true; } private static IType GetType( - string entityName, + IEntityPersister currentEntityPersister, IType currentType, - string memberPath, - bool hasIndexer, - ISessionFactoryImplementor sessionFactory) + MemberMetadata member, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) { - if (entityName == null && memberPath == null) + // Not mapped + if (currentType == null) { + persister = null; return null; } - if (entityName == null) + // Collection composite elements + if (currentEntityPersister == null) { - // q.OneToMany[0].CompositeElement.Prop - if (currentType is IAbstractComponentType componentType) + if (member.ConvertType != null) { - var names = componentType.PropertyNames; - var index = Array.IndexOf(names, memberPath); - return index < 0 - ? null - : componentType.Subtypes[index]; + return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) + ? persister.EntityMetamodel.EntityType // (Entity)q.OneToManyCompositeElement[0].Prop + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.OneToManyCompositeElement[0].Prop } - return null; + persister = null; + return currentType; } - if (memberPath == null) // q.NotMappedProp + if (!member.HasIndexer) { - return null; - } + if (member.ConvertType != null) + { + persister = TryGetEntityPersister(member.ConvertType, sessionFactory, out var newPersister) + ? newPersister + : currentEntityPersister; + return newPersister != null + ? persister.EntityMetamodel.EntityType // (Entity)q.Prop + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop + } - if (!hasIndexer) // q.Prop - { - return currentType; + persister = currentEntityPersister; + return currentType; // q.Prop } // q.OneToMany[0] if (currentType is IAssociationType associationType) { var queryableCollection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + if (member.ConvertType != null) + { + return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) + ? persister.EntityMetamodel.EntityType // (Entity)q.Prop + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop + } + + persister = queryableCollection.ElementType.IsEntityType + ? queryableCollection.ElementPersister + : null; + return queryableCollection.ElementType; } // q.Prop[0] + persister = null; return null; } @@ -263,6 +465,15 @@ private class MemberMetadataExtractor : NhExpressionVisitor private bool _hasIndexer; private string _entityName; + /// + /// Traverses the expression from top to bottom until the first containing an IEntityNameProvider instance is found. + /// + /// The expression to travese. + /// An output parameter that will be populated by the first that is found or null otherwise. + /// An output parameter that will be populated only when containing an IEntityNameProvider + /// is followed by an . + /// A stack of information about all that were traversed until the first containing an + /// IEntityNameProvider instance is found or null when it was not found or if one of the expressions is not supported. public static Stack TryGetAllMemberMetadata( Expression expression, out string entityName, @@ -307,13 +518,18 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr protected override Expression VisitUnary(UnaryExpression node) { - _convertType = node.Type; + // Store only the outermost cast, when there are multiple casts for the same member + if (_convertType == null) + { + _convertType = node.Type; + } + return base.Visit(node.Operand); } protected internal override Expression VisitNhNominated(NhNominatedExpression node) { - return base.Visit(node); + return base.Visit(node.Expression); } protected override Expression VisitConstant(ConstantExpression node) @@ -325,6 +541,17 @@ protected override Expression VisitConstant(ConstantExpression node) return node; } + protected override Expression VisitBinary(BinaryExpression node) + { + if (node.NodeType == ExpressionType.ArrayIndex) + { + _hasIndexer = true; + return base.Visit(node.Left); + } + + return base.VisitBinary(node); + } + protected override Expression VisitMethodCall(MethodCallExpression node) { if (ListIndexerGenerator.IsMethodSupported(node.Method)) From 7ee50d712247f44da6451e08feb27151fedc959a Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 8 Oct 2019 22:15:06 +0200 Subject: [PATCH 18/29] Merge two conditions to reduce complexity --- src/NHibernate/Util/ExpressionsHelper.cs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 91d1def2fce..19d1a249e3b 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -121,20 +121,12 @@ internal static bool TryGetMappedType( // Try to retrive the starting entity name with all members that were traversed in that process var memberPaths = MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var entityName, out var convertType); - if (memberPaths == null) + if (memberPaths == null || !TryGetEntityPersister(entityName, null, sessionFactory, out var currentEntityPersister)) { - // Failed to find the starting entity name, due to an unsupported expression or the expression didn't contain - // the IEntityNameProvider instance - memberPath = null; - mappedType = null; - entityPersister = null; - component = null; - return false; - } - - if (!TryGetEntityPersister(entityName, null, sessionFactory, out var currentEntityPersister)) - { - // Querying an unmapped type e.g. s.Query().Where(a => a.Type == "A") + // Failed to find the starting entity name, due to: + // - Unsupported expression + // - The expression didn't contain the IEntityNameProvider instance + // - Querying an unmapped type e.g. s.Query().Where(a => a.Type == "A") memberPath = null; mappedType = null; entityPersister = null; From ae7a80cac67a81d836c19efcef92fae6cbdfd97b Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 10 Oct 2019 18:14:28 +0200 Subject: [PATCH 19/29] Split TryGetMappedType into two methods --- src/NHibernate/Util/ExpressionsHelper.cs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 19d1a249e3b..d1088ba57bf 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -171,8 +171,28 @@ internal static bool TryGetMappedType( } } + return TraverseMembers( + sessionFactory, + memberPaths, + currentEntityPersister, + out mappedType, + out entityPersister, + out component, + out memberPath); + } + + private static bool TraverseMembers( + ISessionFactoryImplementor sessionFactory, + Stack memberPaths, + IEntityPersister currentEntityPersister, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath) + { // Traverse the members that were traversed by the TryGetAllMemberMetadata method in the reverse order and try to keep // tracking the entity persister until all members are traversed. + System.Type convertType; var member = memberPaths.Pop(); var currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); IAbstractComponentType currentComponentType = null; From 256f5377f003ca44efb5c1b63e622c87ac97a22a Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 11 Oct 2019 16:39:56 +1300 Subject: [PATCH 20/29] Avoid while(true) --- src/NHibernate/Util/ExpressionsHelper.cs | 41 +++++++++++------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index d1088ba57bf..74d64097b2f 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -192,32 +192,12 @@ private static bool TraverseMembers( { // Traverse the members that were traversed by the TryGetAllMemberMetadata method in the reverse order and try to keep // tracking the entity persister until all members are traversed. - System.Type convertType; var member = memberPaths.Pop(); var currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); IAbstractComponentType currentComponentType = null; - while (true) + while (memberPaths.Count > 0 && currentType != null) { - // When traversed to the top of the expression, return the current tracking values - if (memberPaths.Count == 0) - { - memberPath = currentEntityPersister != null || currentComponentType != null ? member.Path : null; - mappedType = GetType(currentEntityPersister, currentType, member, sessionFactory, out _); - entityPersister = currentEntityPersister; - component = currentComponentType; - return mappedType != null; - } - - if (currentType == null) // Member not mapped - { - memberPath = null; - mappedType = null; - entityPersister = null; - component = null; - return false; - } - - convertType = member.ConvertType; + var convertType = member.ConvertType; // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. if (!currentType.IsAnyType && currentType is IAbstractComponentType) @@ -256,6 +236,23 @@ private static bool TraverseMembers( break; } } + + // When traversed to the top of the expression, return the current tracking values + if (memberPaths.Count == 0) + { + memberPath = currentEntityPersister != null || currentComponentType != null ? member.Path : null; + mappedType = GetType(currentEntityPersister, currentType, member, sessionFactory, out _); + entityPersister = currentEntityPersister; + component = currentComponentType; + return mappedType != null; + } + + // Member not mapped + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + return false; } private static void ProcessComponentType( From 418a5887c9a49fe565e69f234fe8466250bcc43e Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 11 Oct 2019 16:42:38 +1300 Subject: [PATCH 21/29] Move getting next member into swith-case blocks --- src/NHibernate/Util/ExpressionsHelper.cs | 28 +++++++++++++----------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 74d64097b2f..12a085f3d00 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -198,23 +198,11 @@ private static bool TraverseMembers( while (memberPaths.Count > 0 && currentType != null) { var convertType = member.ConvertType; - // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. - // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. - if (!currentType.IsAnyType && currentType is IAbstractComponentType) - { - var nextMember = memberPaths.Pop(); - member = currentEntityPersister == null // Collection with composite element or element - ? nextMember - : new MemberMetadata($"{member.Path}.{nextMember.Path}", nextMember.ConvertType, nextMember.HasIndexer); - } - else - { - member = memberPaths.Pop(); - } switch (currentType) { case IAssociationType associationType: + member = memberPaths.Pop(); ProcessAssociationType( associationType, sessionFactory, @@ -225,10 +213,24 @@ private static bool TraverseMembers( out currentComponentType); break; case IAbstractComponentType componentType: + // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. + // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. + if (!currentType.IsAnyType) + { + var nextMember = memberPaths.Pop(); + member = currentEntityPersister == null // Collection with composite element or element + ? nextMember + : new MemberMetadata($"{member.Path}.{nextMember.Path}", nextMember.ConvertType, nextMember.HasIndexer); + } + else + { + member = memberPaths.Pop(); + } currentComponentType = componentType; ProcessComponentType(componentType, currentEntityPersister, member, out currentType); break; default: + member = memberPaths.Pop(); // q.Prop.NotMappedProp currentType = null; currentEntityPersister = null; From d2776894ec6320cedd6363979553623059e6c388 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 11 Oct 2019 17:54:58 +1300 Subject: [PATCH 22/29] Inline ProcessComponentType --- src/NHibernate/Util/ExpressionsHelper.cs | 56 +++++++++++------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 12a085f3d00..185927d2df7 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -197,12 +197,13 @@ private static bool TraverseMembers( IAbstractComponentType currentComponentType = null; while (memberPaths.Count > 0 && currentType != null) { + memberPath = member.Path; var convertType = member.ConvertType; + member = memberPaths.Pop(); switch (currentType) { case IAssociationType associationType: - member = memberPaths.Pop(); ProcessAssociationType( associationType, sessionFactory, @@ -213,24 +214,31 @@ private static bool TraverseMembers( out currentComponentType); break; case IAbstractComponentType componentType: - // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. - // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. - if (!currentType.IsAnyType) + currentComponentType = componentType; + if (currentEntityPersister == null) { - var nextMember = memberPaths.Pop(); - member = currentEntityPersister == null // Collection with composite element or element - ? nextMember - : new MemberMetadata($"{member.Path}.{nextMember.Path}", nextMember.ConvertType, nextMember.HasIndexer); + // When persister is not available (q.OneToManyCompositeElement[0].Prop), try to get the type from the component + currentType = TryGetComponentPropertyType(componentType, member.Path); } else { - member = memberPaths.Pop(); + if (!currentType.IsAnyType) + { + // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. + // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. + // q.Component.Prop + member = new MemberMetadata( + $"{memberPath}.{member.Path}", + member.ConvertType, + member.HasIndexer); + } + + // q.Component.Prop + currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); } - currentComponentType = componentType; - ProcessComponentType(componentType, currentEntityPersister, member, out currentType); + break; default: - member = memberPaths.Pop(); // q.Prop.NotMappedProp currentType = null; currentEntityPersister = null; @@ -257,24 +265,12 @@ private static bool TraverseMembers( return false; } - private static void ProcessComponentType( - IAbstractComponentType componentType, - IEntityPersister persister, - MemberMetadata member, - out IType memberType) + private static IType TryGetComponentPropertyType(IAbstractComponentType componentType, string memberPath) { - // When persister is not available (q.OneToManyCompositeElement[0].Prop), try to get the type from the component - if (persister == null) - { - var index = Array.IndexOf(componentType.PropertyNames, member.Path); - memberType = index < 0 - ? null // q.OneToManyCompositeElement[0].NotMappedProp - : componentType.Subtypes[index]; // q.OneToManyCompositeElement[0].Prop - return; - } - - // q.Component.Prop - memberType = persister.EntityMetamodel.GetPropertyType(member.Path); + var index = Array.IndexOf(componentType.PropertyNames, memberPath); + return index < 0 + ? null // q.OneToManyCompositeElement[0].NotMappedProp + : componentType.Subtypes[index]; // q.OneToManyCompositeElement[0].Prop } private static void ProcessAssociationType( @@ -299,7 +295,7 @@ private static void ProcessAssociationType( { case IAbstractComponentType componentType: // q.OneToManyCompositeElement[0].Member memberComponent = componentType; - ProcessComponentType(componentType, null, member, out memberType); + memberType = TryGetComponentPropertyType(componentType, member.Path); return; default: // q.OneToManyElement[0].Member memberType = null; From a49ddefa2bdc27531a47dc63b71a0b36d4efbdcf Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 11 Oct 2019 23:57:34 +1300 Subject: [PATCH 23/29] Fix typos --- src/NHibernate/Util/ExpressionsHelper.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 185927d2df7..b83a8c2e23f 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -115,11 +115,11 @@ internal static bool TryGetMappedType( out string memberPath) { // In order to get the correct entity name from the expression we first have to find the constant expression that contains the - // IEntityNameProvider instance, from which we can retrive the starting entity name. Once we have it, we have to traverse all - // expressions that we had to travese in order to find the IEntityNameProvider instance, but in reverse order (bottom to top) + // IEntityNameProvider instance, from which we can retrieve the starting entity name. Once we have it, we have to traverse all + // expressions that we had to traverse in order to find the IEntityNameProvider instance, but in reverse order (bottom to top) // and keep tracking the entity name until we reach to top. - // Try to retrive the starting entity name with all members that were traversed in that process + // Try to retrieve the starting entity name with all members that were traversed in that process var memberPaths = MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var entityName, out var convertType); if (memberPaths == null || !TryGetEntityPersister(entityName, null, sessionFactory, out var currentEntityPersister)) { From 68f5407d846ebd804b8030502133e96d9b0d7cb3 Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 11 Oct 2019 20:13:39 +0200 Subject: [PATCH 24/29] Small corrections --- src/NHibernate.Test/Linq/TryGetMappedTests.cs | 4 +-- src/NHibernate/Util/ExpressionsHelper.cs | 25 ++++++++----------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index 7eea8c8d70d..e46c57989ab 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -23,8 +23,8 @@ namespace NHibernate.Test.Linq /// public class TryGetMappedTests : LinqTestCase { - private readonly static TryGetMappedType _tryGetMappedType; - private readonly static TryGetMappedNullability _tryGetMappedNullability; + private static readonly TryGetMappedType _tryGetMappedType; + private static readonly TryGetMappedNullability _tryGetMappedNullability; delegate bool TryGetMappedType( ISessionFactoryImplementor sessionFactory, diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index b83a8c2e23f..0c8bbc84158 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -165,10 +165,8 @@ internal static bool TryGetMappedType( component = null; return false; } - else - { - currentEntityPersister = convertPersister; // ((Subclass)q).Id - } + + currentEntityPersister = convertPersister; // ((Subclass)q).Id } return TraverseMembers( @@ -222,16 +220,13 @@ private static bool TraverseMembers( } else { - if (!currentType.IsAnyType) - { - // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. - // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. - // q.Component.Prop - member = new MemberMetadata( - $"{memberPath}.{member.Path}", - member.ConvertType, - member.HasIndexer); - } + // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. + // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. + // q.Component.Prop + member = new MemberMetadata( + $"{memberPath}.{member.Path}", + member.ConvertType, + member.HasIndexer); // q.Component.Prop currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); @@ -475,7 +470,7 @@ private class MemberMetadataExtractor : NhExpressionVisitor /// /// Traverses the expression from top to bottom until the first containing an IEntityNameProvider instance is found. /// - /// The expression to travese. + /// The expression to traverse. /// An output parameter that will be populated by the first that is found or null otherwise. /// An output parameter that will be populated only when containing an IEntityNameProvider /// is followed by an . From 315b5734d579af5343313d21eeddeb2209460c33 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 13 Oct 2019 13:36:37 +0200 Subject: [PATCH 25/29] Remove unneeded GetType parameter --- src/NHibernate/Util/ExpressionsHelper.cs | 61 ++++++++---------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 0c8bbc84158..0092c8ea8bf 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -246,7 +246,7 @@ private static bool TraverseMembers( if (memberPaths.Count == 0) { memberPath = currentEntityPersister != null || currentComponentType != null ? member.Path : null; - mappedType = GetType(currentEntityPersister, currentType, member, sessionFactory, out _); + mappedType = GetType(currentEntityPersister, currentType, member, sessionFactory); entityPersister = currentEntityPersister; component = currentComponentType; return mappedType != null; @@ -397,67 +397,44 @@ private static IType GetType( IEntityPersister currentEntityPersister, IType currentType, MemberMetadata member, - ISessionFactoryImplementor sessionFactory, - out IEntityPersister persister) + ISessionFactoryImplementor sessionFactory) { // Not mapped if (currentType == null) { - persister = null; return null; } - // Collection composite elements - if (currentEntityPersister == null) + IEntityPersister persister; + if (!member.HasIndexer || currentEntityPersister == null) { - if (member.ConvertType != null) + if (member.ConvertType == null) { - return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) - ? persister.EntityMetamodel.EntityType // (Entity)q.OneToManyCompositeElement[0].Prop - : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.OneToManyCompositeElement[0].Prop + return currentType; // q.Prop, q.OneToManyCompositeElement[0].Prop } - persister = null; - return currentType; + return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) + ? persister.EntityMetamodel.EntityType // (Entity)q.Prop, (Entity)q.OneToManyCompositeElement[0].Prop + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop, (long)q.OneToManyCompositeElement[0].Prop } - if (!member.HasIndexer) + + if (!(currentType is IAssociationType associationType)) { - if (member.ConvertType != null) - { - persister = TryGetEntityPersister(member.ConvertType, sessionFactory, out var newPersister) - ? newPersister - : currentEntityPersister; - return newPersister != null - ? persister.EntityMetamodel.EntityType // (Entity)q.Prop - : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop - } - - persister = currentEntityPersister; - return currentType; // q.Prop + // q.Prop[0] + return null; } - // q.OneToMany[0] - if (currentType is IAssociationType associationType) + var queryableCollection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + if (member.ConvertType == null) { - var queryableCollection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); - if (member.ConvertType != null) - { - return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) - ? persister.EntityMetamodel.EntityType // (Entity)q.Prop - : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop - } - - persister = queryableCollection.ElementType.IsEntityType - ? queryableCollection.ElementPersister - : null; - + // q.OneToMany[0] return queryableCollection.ElementType; } - // q.Prop[0] - persister = null; - return null; + return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) + ? persister.EntityMetamodel.EntityType // (Entity)q.OneToMany[0] + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.OneToMany[0] } private class MemberMetadataExtractor : NhExpressionVisitor From b691d1d726ab6ad0db6185a7f069dad6217eeb7f Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 13 Oct 2019 13:39:14 +0200 Subject: [PATCH 26/29] Fix typo --- src/NHibernate/Util/ExpressionsHelper.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 0092c8ea8bf..175a36937ca 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -418,7 +418,6 @@ private static IType GetType( : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop, (long)q.OneToManyCompositeElement[0].Prop } - if (!(currentType is IAssociationType associationType)) { // q.Prop[0] From 9ea2d639006b9e80ecbb0c700bca172983435122 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 28 Oct 2019 22:15:25 +0100 Subject: [PATCH 27/29] Add support for polymorphic queries, coalesce and conditional expressions --- src/NHibernate.Test/Linq/TryGetMappedTests.cs | 88 ++++++- src/NHibernate/Util/ExpressionsHelper.cs | 248 +++++++++++++++--- 2 files changed, 302 insertions(+), 34 deletions(-) diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index e46c57989ab..a60d96538b9 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -580,14 +580,82 @@ public void CastAnyPropertyTest() public void QueryUnmppedEntityTest() { var query = session.Query>().Select(o => o.Id); - AssertFalse(query, null, null, o => o == null); + AssertTrueNotNull(query, typeof(User).FullName, "Id", o => o is Int32Type); } [Test] - public void NotSupportedConditionalExpressionTest() + public void ConditionalExpressionTest() { var query = db.Users.Select(o => (o.Name == "Test" ? o.RegisteredAt : o.LastLoginDate)); - AssertFalse(query, false, null, null, o => o == null); + AssertTrue(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); + } + + [Test] + public void ConditionalIfFalseExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? DateTime.Today : o.LastLoginDate)); + AssertTrue(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); + } + + [Test] + public void ConditionalMemberExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.NotMappedRole : o.Role).IsActive); + AssertTrue(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + } + + [Test] + public void ConditionalNestedExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.Component.OtherComponent.OtherProperty1 : o.Component.Property1)); + AssertTrue( + query, + false, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is AnsiStringType, + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void CoalesceExpressionTest() + { + var query = db.Users.Select(o => o.LastLoginDate ?? o.RegisteredAt); + AssertTrue(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); + } + + [Test] + public void CoalesceRightExpressionTest() + { + var query = db.Users.Select(o => ((DateTime?) DateTime.Now) ?? o.RegisteredAt); + AssertTrue(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); + } + + [Test] + public void CoalesceMemberExpressionTest() + { + var query = db.Users.Select(o => (o.NotMappedRole ?? o.Role).IsActive); + AssertTrue(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + } + + [Test] + public void CoalesceNestedExpressionTest() + { + var query = db.Users.Select(o => o.Component.OtherComponent.OtherProperty1 ?? o.Component.Property1); + AssertTrue( + query, + false, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is AnsiStringType, + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void CoalesceConditionalMemberExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.NotMappedRole : (o.NotMappedRole ?? new Role() ?? o.Role)).IsActive); + AssertTrue(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); } [Test] @@ -615,6 +683,20 @@ public void NotNullComponentPropertyTest() o => o?.Name == "component[FirstName,LastName]"); } + [Test] + public void NotRelatedTypeTest() + { + var query = session.Query().Select(o => o.CanReduce); + AssertFalse(query, null, null, o => o == null); + } + + [Test] + public void NotNhQueryableTest() + { + var query = new List().AsQueryable().Select(o => o.Name); + AssertFalse(query, false, null, null, o => o == null); + } + private void AssertFalse( IQueryable query, string expectedEntityName, diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 175a36937ca..86fb8d04bd3 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -105,6 +105,11 @@ internal static bool TryGetMappedNullability( /// /// When the contains an expression of type , the /// result may not be correct when casting to an entity that is mapped with multiple entity names. + /// When the is polymorphic, the first implementor will be returned. + /// When the contains a , the first found entity name + /// will be returned from or . + /// When the contains a expression, the first found entity name + /// will be returned from or . /// internal static bool TryGetMappedType( ISessionFactoryImplementor sessionFactory, @@ -119,14 +124,50 @@ internal static bool TryGetMappedType( // expressions that we had to traverse in order to find the IEntityNameProvider instance, but in reverse order (bottom to top) // and keep tracking the entity name until we reach to top. - // Try to retrieve the starting entity name with all members that were traversed in that process - var memberPaths = MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var entityName, out var convertType); - if (memberPaths == null || !TryGetEntityPersister(entityName, null, sessionFactory, out var currentEntityPersister)) + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + // Try to retrieve the starting entity name with all members that were traversed in that process. + if (!MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var metadataResults)) { // Failed to find the starting entity name, due to: // - Unsupported expression // - The expression didn't contain the IEntityNameProvider instance - // - Querying an unmapped type e.g. s.Query().Where(a => a.Type == "A") + return false; + } + + // Due to coalesce and conditional expressions we can have multiple paths to traverse, in that case find the first path + // for which we are able to determine the mapped type. + foreach (var metadataResult in metadataResults) + { + if (ProcessMembersMetadataResult( + metadataResult, + sessionFactory, + out mappedType, + out entityPersister, + out component, + out memberPath)) + { + return true; + } + } + + return false; + } + + private static bool ProcessMembersMetadataResult( + MemberMetadataResult metadataResult, + ISessionFactoryImplementor sessionFactory, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath) + { + if (!TryGetEntityPersister(metadataResult.EntityName, null, sessionFactory, out var currentEntityPersister)) + { + // Failed to find the starting entity name, due to: + // - Querying a type that is not related to any entity e.g. s.Query().Where(a => a.Type == "A") memberPath = null; mappedType = null; entityPersister = null; @@ -134,13 +175,17 @@ internal static bool TryGetMappedType( return false; } - if (memberPaths.Count == 0) // The expression do not contain any member expressions + if (metadataResult.MemberPaths.Count == 0) // The expression do not contain any member expressions { - if (convertType != null) + if (metadataResult.ConvertType != null) { - mappedType = TryGetEntityPersister(currentEntityPersister, convertType, sessionFactory, out var convertPersister) + mappedType = TryGetEntityPersister( + currentEntityPersister, + metadataResult.ConvertType, + sessionFactory, + out var convertPersister) ? convertPersister.EntityMetamodel.EntityType // ((Subclass)q) - : TypeFactory.GetDefaultTypeFor(convertType); // ((NotMapped)q) + : TypeFactory.GetDefaultTypeFor(metadataResult.ConvertType); // ((NotMapped)q) } else { @@ -155,9 +200,13 @@ internal static bool TryGetMappedType( // If there was a cast right after the constant expression that contains the IEntityNameProvider instance, we have // to update the entity persister according to it, otherwise use the value returned by TryGetAllMemberMetadata method. - if (convertType != null) + if (metadataResult.ConvertType != null) { - if (!TryGetEntityPersister(currentEntityPersister, convertType, sessionFactory, out var convertPersister)) // ((NotMapped)q).Id + if (!TryGetEntityPersister( + currentEntityPersister, + metadataResult.ConvertType, + sessionFactory, + out var convertPersister)) // ((NotMapped)q).Id { memberPath = null; mappedType = null; @@ -171,7 +220,7 @@ internal static bool TryGetMappedType( return TraverseMembers( sessionFactory, - memberPaths, + metadataResult.MemberPaths, currentEntityPersister, out mappedType, out entityPersister, @@ -337,8 +386,19 @@ private static bool TryGetEntityPersister( var currentEntityPersister = sessionFactory.TryGetEntityPersister(currentEntityName); if (currentEntityPersister == null) { - persister = null; - return false; // Querying an unmapped interface e.g. s.Query().Where(a => a.Type == "A") + // When dealing with a polymorphic query it is not important which entity name we pick + // as they all need to have the same mapped types for members of the type that is queried. + // If one of the entites has a different type mapped (e.g. enum mapped as string instead of numeric), + // the query will fail to execute as currently the ParameterMetadata is bound to IQueryPlan and not to IQueryTranslator + // (e.g. s.Query().Where(a => a.MyEnum == MyEnum.Option)). + currentEntityName = sessionFactory.GetImplementors(currentEntityName).FirstOrDefault(); + if (currentEntityName == null) + { + persister = null; + return false; + } + + currentEntityPersister = sessionFactory.GetEntityPersister(currentEntityName); } return TryGetEntityPersister(currentEntityPersister, convertedType, sessionFactory, out persister); @@ -438,30 +498,60 @@ private static IType GetType( private class MemberMetadataExtractor : NhExpressionVisitor { - private readonly Stack _memberPaths = new Stack(); + private readonly List _childrenResults = new List(); + private readonly Stack _memberPaths; private System.Type _convertType; private bool _hasIndexer; private string _entityName; + private MemberMetadataExtractor(Stack memberPaths, System.Type convertType, bool hasIndexer) + { + _memberPaths = memberPaths; + _convertType = convertType; + _hasIndexer = hasIndexer; + } + /// - /// Traverses the expression from top to bottom until the first containing an IEntityNameProvider instance is found. + /// Traverses the expression from top to bottom until the first containing an IEntityNameProvider + /// instance is found. /// /// The expression to traverse. - /// An output parameter that will be populated by the first that is found or null otherwise. - /// An output parameter that will be populated only when containing an IEntityNameProvider - /// is followed by an . - /// A stack of information about all that were traversed until the first containing an - /// IEntityNameProvider instance is found or null when it was not found or if one of the expressions is not supported. - public static Stack TryGetAllMemberMetadata( + /// Output parameter that represents a collection, where each item contains information about all + /// that were traversed until the first containing an + /// instance is found. The number of items depends on how many different paths exist + /// in the that contains a instance. When + /// is not found or one of the expressions is not supported the parameter will be set to . + /// Whether was populated. + public static bool TryGetAllMemberMetadata(Expression expression, out List results) + { + if (TryGetAllMemberMetadata(expression, new Stack(), null, false, out var result)) + { + results = result.GetAllResults().ToList(); + return true; + } + + results = null; + return false; + } + + private static bool TryGetAllMemberMetadata( Expression expression, - out string entityName, - out System.Type convertType) + Stack memberPaths, + System.Type convertType, + bool hasIndexer, + out MemberMetadataResult results) { - var extractor = new MemberMetadataExtractor(); + var extractor = new MemberMetadataExtractor(memberPaths, convertType, hasIndexer); extractor.Accept(expression); - entityName = extractor._entityName; - convertType = entityName != null ? extractor._convertType : null; - return entityName != null ? extractor._memberPaths : null; + results = extractor._entityName != null || extractor._childrenResults.Count > 0 + ? new MemberMetadataResult( + extractor._childrenResults, + extractor._memberPaths, + extractor._entityName, + extractor._convertType) + : null; + + return results != null; } private void Accept(Expression expression) @@ -527,7 +617,23 @@ protected override Expression VisitBinary(BinaryExpression node) return base.Visit(node.Left); } - return base.VisitBinary(node); + if (node.NodeType == ExpressionType.Coalesce && + (TryGetMembersMetadata(node.Left) | TryGetMembersMetadata(node.Right))) + { + return node; + } + + return Visit(node); + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + if (TryGetMembersMetadata(node.IfTrue) | TryGetMembersMetadata(node.IfFalse)) + { + return node; + } + + return Visit(node); } protected override Expression VisitMethodCall(MethodCallExpression node) @@ -542,9 +648,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node) ); } - // Not supported expression - _entityName = null; - return node; + return Visit(node); } public override Expression Visit(Expression node) @@ -553,6 +657,25 @@ public override Expression Visit(Expression node) _entityName = null; return node; } + + private bool TryGetMembersMetadata(Expression expression) + { + if (TryGetAllMemberMetadata(expression, Clone(_memberPaths), _convertType, _hasIndexer, out var result)) + { + _childrenResults.Add(result); + return true; + } + + return false; + } + + private static Stack Clone(Stack original) + { + var arr = new T[original.Count]; + original.CopyTo(arr, 0); + Array.Reverse(arr); + return new Stack(arr); + } } private struct MemberMetadata @@ -570,5 +693,68 @@ public MemberMetadata(string path, System.Type convertType, bool hasIndexer) public bool HasIndexer { get; } } + + private class MemberMetadataResult + { + public MemberMetadataResult( + List childrenResults, + Stack memberPaths, + string entityName, + System.Type convertType) + { + ChildrenResults = childrenResults; + MemberPaths = memberPaths; + EntityName = entityName; + ConvertType = convertType; + } + + /// + /// Metadata about all that were traversed. + /// + public Stack MemberPaths { get; } + + /// + /// type that was used on a containing + /// an . + /// + public System.Type ConvertType { get; } + + /// + /// The entity name from . + /// + public string EntityName { get; } + + /// + /// Direct children of the current metadata result. + /// + public List ChildrenResults { get; } + + /// + /// Gets all leaf (bottom) children that have the entity name set. + /// + /// + public IEnumerable GetAllResults() + { + return GetAllResults(this); + } + + private static IEnumerable GetAllResults(MemberMetadataResult result) + { + if (result.ChildrenResults.Count == 0) + { + yield return result; + } + else + { + foreach (var childResult in result.ChildrenResults) + { + foreach (var childChildrenResult in GetAllResults(childResult)) + { + yield return childChildrenResult; + } + } + } + } + } } } From 17422d7b58064b3e70f5d39c1a4359adf2d5e870 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 12 Nov 2019 22:25:04 +0100 Subject: [PATCH 28/29] Revert cast function change --- src/NHibernate/Dialect/Function/CastFunction.cs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/NHibernate/Dialect/Function/CastFunction.cs b/src/NHibernate/Dialect/Function/CastFunction.cs index e7a41a91882..8580da1997f 100644 --- a/src/NHibernate/Dialect/Function/CastFunction.cs +++ b/src/NHibernate/Dialect/Function/CastFunction.cs @@ -51,11 +51,7 @@ public SqlString Render(IList args, ISessionFactoryImplementor factory) throw new QueryException("invalid NHibernate type for cast(), was:" + typeName); } - if (!factory.Dialect.TryGetCastTypeName(sqlTypeCodes[0], out sqlType)) - { - sqlType = typeName; - } - //else + sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]); //{ // //trim off the length/precision/scale // int loc = sqlType.IndexOf('('); From 660f15da8fca82cc5c93bcf42cde092b2a2ed1da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Sun, 16 Feb 2020 20:18:23 +0100 Subject: [PATCH 29/29] Adjust test namings --- src/NHibernate.Test/Linq/TryGetMappedTests.cs | 154 +++++++++--------- 1 file changed, 77 insertions(+), 77 deletions(-) diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index a60d96538b9..11724e1ac9b 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -26,7 +26,7 @@ public class TryGetMappedTests : LinqTestCase private static readonly TryGetMappedType _tryGetMappedType; private static readonly TryGetMappedNullability _tryGetMappedNullability; - delegate bool TryGetMappedType( + private delegate bool TryGetMappedType( ISessionFactoryImplementor sessionFactory, Expression expression, out IType mappedType, @@ -34,7 +34,7 @@ delegate bool TryGetMappedType( out IAbstractComponentType component, out string memberPath); - delegate bool TryGetMappedNullability( + private delegate bool TryGetMappedNullability( ISessionFactoryImplementor sessionFactory, Expression expression, out bool nullability); @@ -108,7 +108,7 @@ protected override string[] Mappings public void SelfTest() { var query = db.OrderLines.Select(o => o); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, typeof(OrderLine).FullName, null, @@ -119,7 +119,7 @@ public void SelfTest() public void SelfCastNotMappedTest() { var query = session.Query().Select(o => (object) o); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, false, typeof(A).FullName, @@ -131,42 +131,42 @@ public void SelfCastNotMappedTest() public void PropertyTest() { var query = db.OrderLines.Select(o => o.Quantity); - AssertTrueNotNull(query, typeof(OrderLine).FullName, "Quantity", o => o is Int32Type); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Quantity", o => o is Int32Type); } [Test] public void NotMappedPropertyTest() { var query = db.Users.Select(o => o.NotMapped); - AssertFalse(query, typeof(User).FullName, "NotMapped", o => o is null); + AssertUnsupported(query, typeof(User).FullName, "NotMapped", o => o is null); } [Test] public void NestedNotMappedPropertyTest() { var query = db.Users.Select(o => o.Name.Length); - AssertFalse(query, false, null, null, o => o is null); + AssertUnsupported(query, false, null, null, o => o is null); } [Test] public void PropertyCastTest() { var query = db.OrderLines.Select(o => (long) o.Quantity); - AssertTrueNotNull(query, typeof(OrderLine).FullName, "Quantity", o => o is Int64Type); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Quantity", o => o is Int64Type); } [Test] public void PropertyIndexer() { var query = db.Products.Select(o => o.Name[0]); - AssertFalse(query, null, null, o => o == null); + AssertUnsupported(query, null, null, o => o == null); } [Test] public void EnumInt32Test() { var query = db.Users.Select(o => o.Enum2); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, typeof(User).FullName, "Enum2", @@ -177,28 +177,28 @@ public void EnumInt32Test() public void EnumInt32CastTest() { var query = db.Users.Select(o => (int) o.Enum2); - AssertTrueNotNull(query, typeof(User).FullName, "Enum2", o => o is Int32Type); + AssertSupportedAndResultNotNullable(query, typeof(User).FullName, "Enum2", o => o is Int32Type); } [Test] public void EnumAsStringTest() { var query = db.Users.Select(o => o.Enum1); - AssertTrue(query, typeof(User).FullName, "Enum1", o => o is EnumStoredAsStringType); + AssertSupported(query, typeof(User).FullName, "Enum1", o => o is EnumStoredAsStringType); } [Test] public void IdentifierTest() { var query = db.OrderLines.Select(o => o.Id); - AssertTrueNotNull(query, typeof(OrderLine).FullName, "Id", o => o is Int64Type); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Id", o => o is Int64Type); } [Test] public void CompositeIdentifierTest() { var query = session.Query().Select(o => o.Id.Date); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, typeof(Fum).FullName, "Id.Date", @@ -210,7 +210,7 @@ public void CompositeIdentifierTest() public void ComponentTest() { var query = db.Customers.Select(o => o.Address); - AssertTrue( + AssertSupported( query, typeof(Customer).FullName, "Address", @@ -221,7 +221,7 @@ public void ComponentTest() public void ComponentPropertyTest() { var query = db.Customers.Select(o => o.Address.City); - AssertTrue( + AssertSupported( query, typeof(Customer).FullName, "Address.City", @@ -233,7 +233,7 @@ public void ComponentPropertyTest() public void ComponentNotMappedPropertyTest() { var query = db.Customers.Select(o => o.Address.NotMapped); - AssertFalse( + AssertUnsupported( query, typeof(Customer).FullName, "Address.NotMapped", @@ -245,14 +245,14 @@ public void ComponentNotMappedPropertyTest() public void ComponentNestedNotMappedPropertyTest() { var query = db.Customers.Select(o => o.Address.City.Length); - AssertFalse(query, false, null, null, o => o == null); + AssertUnsupported(query, false, null, null, o => o == null); } [Test] public void NestedComponentPropertyTest() { var query = db.Users.Select(o => o.Component.OtherComponent.OtherProperty1); - AssertTrue( + AssertSupported( query, typeof(User).FullName, "Component.OtherComponent.OtherProperty1", @@ -264,7 +264,7 @@ public void NestedComponentPropertyTest() public void NestedComponentPropertyCastTest() { var query = db.Users.Select(o => (object) o.Component.OtherComponent.OtherProperty1); - AssertTrue( + AssertSupported( query, typeof(User).FullName, "Component.OtherComponent.OtherProperty1", @@ -276,7 +276,7 @@ public void NestedComponentPropertyCastTest() public void ManyToOneTest() { var query = db.OrderLines.Select(o => o.Order); - AssertTrueNotNull(query, typeof(OrderLine).FullName, "Order", + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Order", o => o is ManyToOneType manyToOne && manyToOne.PropertyName == "Order"); } @@ -284,28 +284,28 @@ public void ManyToOneTest() public void ManyToOnePropertyTest() { var query = db.OrderLines.Select(o => o.Order.Freight); - AssertTrue(query, typeof(Order).FullName, "Freight", o => o is DecimalType); + AssertSupported(query, typeof(Order).FullName, "Freight", o => o is DecimalType); } [Test] public void ManyToOneNotMappedPropertyTest() { var query = db.OrderLines.Select(o => o.Product.NotMapped); - AssertFalse(query, typeof(Product).FullName, "NotMapped", o => o == null); + AssertUnsupported(query, typeof(Product).FullName, "NotMapped", o => o == null); } [Test] public void NotMappedManyToOnePropertyTest() { var query = db.Users.Select(o => o.NotMappedRole.Name); - AssertFalse(query, false, null, null, o => o is null); + AssertUnsupported(query, false, null, null, o => o is null); } [Test] public void NestedManyToOneTest() { var query = db.OrderLines.Select(o => o.Order.Employee); - AssertTrue(query, false, typeof(Order).FullName, "Employee", + AssertSupported(query, false, typeof(Order).FullName, "Employee", o => o is ManyToOneType manyToOne && manyToOne.PropertyName == "Employee"); } @@ -313,14 +313,14 @@ public void NestedManyToOneTest() public void NestedManyToOnePropertyTest() { var query = db.OrderLines.Select(o => o.Order.Employee.BirthDate); - AssertTrue(query, typeof(Employee).FullName, "BirthDate", o => o is DateTimeType); + AssertSupported(query, typeof(Employee).FullName, "BirthDate", o => o is DateTimeType); } [Test] public void OneToManyTest() { var query = db.Customers.SelectMany(o => o.Orders); - AssertTrue( + AssertSupported( query, typeof(Customer).FullName, "Orders", @@ -331,21 +331,21 @@ public void OneToManyTest() public void OneToManyElementIndexerTest() { var query = session.Query().Select(o => o.StringList[0]); - AssertTrue(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); + AssertSupported(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); } [Test] public void OneToManyElementIndexerNotMappedPropertyTest() { var query = session.Query().Select(o => o.StringList[0].Length); - AssertFalse(query, false, null, null, o => o == null); + AssertUnsupported(query, false, null, null, o => o == null); } [Test] public void OneToManyCustomElementIndexerTest() { var query = session.Query().Select(o => o.Customs[0]); - AssertTrue( + AssertSupported( query, false, typeof(Baz).FullName, @@ -357,28 +357,28 @@ public void OneToManyCustomElementIndexerTest() public void OneToManyIndexerCastTest() { var query = session.Query().Select(o => (long) o.IntArray[0]); - AssertTrue(query, false, typeof(Baz).FullName, "IntArray", o => o is Int64Type); + AssertSupported(query, false, typeof(Baz).FullName, "IntArray", o => o is Int64Type); } [Test] public void OneToManyIndexerPropertyTest() { var query = session.Query().Select(o => o.Fees[0].Count); - AssertTrue(query, false, typeof(Fee).FullName, "Count", o => o is Int32Type); + AssertSupported(query, false, typeof(Fee).FullName, "Count", o => o is Int32Type); } [Test] public void OneToManyElementAtTest() { var query = session.Query().Select(o => o.StringList.ElementAt(0)); - AssertTrue(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); + AssertSupported(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); } [Test] public void NestedOneToManyManyToOneComponentPropertyTest() { var query = session.Query().SelectMany(o => o.Fees).Select(o => o.TheFee.Compon.Name); - AssertTrue( + AssertSupported( query, typeof(Fee).FullName, "Compon.Name", @@ -390,7 +390,7 @@ public void NestedOneToManyManyToOneComponentPropertyTest() public void OneToManyCompositeElementPropertyTest() { var query = session.Query().Select(o => o.Components[0].Count); - AssertTrue( + AssertSupported( query, false, null, @@ -403,14 +403,14 @@ public void OneToManyCompositeElementPropertyTest() public void OneToManyCompositeElementPropertyIndexerTest() { var query = session.Query().Select(o => o.Components[0].Name[0]); - AssertFalse(query, false, null, null, o => o == null); + AssertUnsupported(query, false, null, null, o => o == null); } [Test] public void OneToManyCompositeElementNotMappedPropertyTest() { var query = session.Query().Select(o => o.Components[0].NotMapped); - AssertFalse( + AssertUnsupported( query, false, null, @@ -423,7 +423,7 @@ public void OneToManyCompositeElementNotMappedPropertyTest() public void OneToManyCompositeElementCastPropertyTest() { var query = session.Query().Select(o => (long) o.Components[0].Count); - AssertTrue( + AssertSupported( query, false, null, @@ -436,7 +436,7 @@ public void OneToManyCompositeElementCastPropertyTest() public void OneToManyCompositeElementCollectionNotMappedPropertyTest() { var query = session.Query().SelectMany(o => o.Components[0].ImportantDates); - AssertFalse( + AssertUnsupported( query, false, null, @@ -449,7 +449,7 @@ public void OneToManyCompositeElementCollectionNotMappedPropertyTest() public void NestedOneToManyCompositeElementTest() { var query = session.Query().Select(o => o.Components[0].Subcomponent); - AssertTrue( + AssertSupported( query, false, null, @@ -462,21 +462,21 @@ public void NestedOneToManyCompositeElementTest() public void NestedOneToManyCompositeElementPropertyTest() { var query = session.Query().Select(o => o.Components[0].Subcomponent.Name); - AssertTrue(query, false, null, "Name", o => o is StringType, o => o?.Name == "component[Name,Count]"); + AssertSupported(query, false, null, "Name", o => o is StringType, o => o?.Name == "component[Name,Count]"); } [Test] public void NestedOneToManyCompositeElementPropertyIndexerTest() { var query = session.Query().Select(o => o.Components[0].Subcomponent.Name[0]); - AssertFalse(query, false, null, null, o => o == null); + AssertUnsupported(query, false, null, null, o => o == null); } [Test] public void ManyToManyTest() { var query = session.Query().Select(o => o.FooArray); - AssertTrue( + AssertSupported( query, false, typeof(Baz).FullName, @@ -488,14 +488,14 @@ public void ManyToManyTest() public void ManyToManyIndexerTest() { var query = session.Query().Select(o => o.FooArray[0].Null); - AssertTrue(query, false, typeof(Foo).FullName, "Null", o => o is NullableInt32Type); + AssertSupported(query, false, typeof(Foo).FullName, "Null", o => o is NullableInt32Type); } [Test] public void SubclassCastTest() { var query = session.Query().Select(o => (B) o); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, typeof(A).FullName, null, @@ -506,7 +506,7 @@ public void SubclassCastTest() public void NestedSubclassCastTest() { var query = session.Query().Select(o => (C1) ((B) o)); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, false, typeof(A).FullName, @@ -518,28 +518,28 @@ public void NestedSubclassCastTest() public void SubclassPropertyTest() { var query = session.Query().Select(o => ((C1) o).Count); - AssertTrue(query, typeof(C1).FullName, "Count", o => o is Int32Type); + AssertSupported(query, typeof(C1).FullName, "Count", o => o is Int32Type); } [Test] public void NestedSubclassCastPropertyTest() { var query = session.Query().Select(o => ((C1) ((B) o)).Id); - AssertTrueNotNull(query, typeof(C1).FullName, "Id", o => o is Int64Type); + AssertSupportedAndResultNotNullable(query, typeof(C1).FullName, "Id", o => o is Int64Type); } [Test] public void AnyTest() { var query = session.Query().Select(o => o.Object); - AssertTrue(query, typeof(Bar).FullName, "Object", o => o.IsAnyType); + AssertSupported(query, typeof(Bar).FullName, "Object", o => o.IsAnyType); } [Test] public void CastAnyTest() { var query = session.Query().Select(o => (Foo) o.Object); - AssertTrue( + AssertSupported( query, typeof(Bar).FullName, "Object", @@ -550,7 +550,7 @@ public void CastAnyTest() public void NestedCastAnyTest() { var query = session.Query().Select(o => (Foo) ((Bar) o.Object).Object); - AssertTrue( + AssertSupported( query, false, typeof(Bar).FullName, @@ -562,7 +562,7 @@ public void NestedCastAnyTest() public void CastAnyManyToOneTest() { var query = session.Query().Select(o => ((Foo) o.Object).Dependent); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, typeof(Foo).FullName, "Dependent", @@ -573,42 +573,42 @@ public void CastAnyManyToOneTest() public void CastAnyPropertyTest() { var query = session.Query().Select(o => ((Foo) o.Object).String); - AssertTrue(query, false, typeof(Foo).FullName, "String", o => o is StringType); + AssertSupported(query, false, typeof(Foo).FullName, "String", o => o is StringType); } [Test] - public void QueryUnmppedEntityTest() + public void QueryUnmappedEntityTest() { var query = session.Query>().Select(o => o.Id); - AssertTrueNotNull(query, typeof(User).FullName, "Id", o => o is Int32Type); + AssertSupportedAndResultNotNullable(query, typeof(User).FullName, "Id", o => o is Int32Type); } [Test] public void ConditionalExpressionTest() { var query = db.Users.Select(o => (o.Name == "Test" ? o.RegisteredAt : o.LastLoginDate)); - AssertTrue(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); + AssertSupported(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); } [Test] public void ConditionalIfFalseExpressionTest() { var query = db.Users.Select(o => (o.Name == "Test" ? DateTime.Today : o.LastLoginDate)); - AssertTrue(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); + AssertSupported(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); } [Test] public void ConditionalMemberExpressionTest() { var query = db.Users.Select(o => (o.Name == "Test" ? o.NotMappedRole : o.Role).IsActive); - AssertTrue(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + AssertSupported(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); } [Test] public void ConditionalNestedExpressionTest() { var query = db.Users.Select(o => (o.Name == "Test" ? o.Component.OtherComponent.OtherProperty1 : o.Component.Property1)); - AssertTrue( + AssertSupported( query, false, typeof(User).FullName, @@ -621,28 +621,28 @@ public void ConditionalNestedExpressionTest() public void CoalesceExpressionTest() { var query = db.Users.Select(o => o.LastLoginDate ?? o.RegisteredAt); - AssertTrue(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); + AssertSupported(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); } [Test] public void CoalesceRightExpressionTest() { var query = db.Users.Select(o => ((DateTime?) DateTime.Now) ?? o.RegisteredAt); - AssertTrue(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); + AssertSupported(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); } [Test] public void CoalesceMemberExpressionTest() { var query = db.Users.Select(o => (o.NotMappedRole ?? o.Role).IsActive); - AssertTrue(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + AssertSupported(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); } [Test] public void CoalesceNestedExpressionTest() { var query = db.Users.Select(o => o.Component.OtherComponent.OtherProperty1 ?? o.Component.Property1); - AssertTrue( + AssertSupported( query, false, typeof(User).FullName, @@ -655,7 +655,7 @@ public void CoalesceNestedExpressionTest() public void CoalesceConditionalMemberExpressionTest() { var query = db.Users.Select(o => (o.Name == "Test" ? o.NotMappedRole : (o.NotMappedRole ?? new Role() ?? o.Role)).IsActive); - AssertTrue(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + AssertSupported(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); } [Test] @@ -668,14 +668,14 @@ join d in db.OrderLines into details from d in details select d.UnitPrice; - AssertTrueNotNull(query, typeof(OrderLine).FullName, "UnitPrice", o => o is DecimalType); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "UnitPrice", o => o is DecimalType); } [Test] public void NotNullComponentPropertyTest() { var query = session.Query().SelectMany(o => o.PatientRecords.Select(r => r.Name.FirstName)); - AssertTrueNotNull( + AssertSupportedAndResultNotNullable( query, typeof(PatientRecord).FullName, "Name.FirstName", @@ -687,17 +687,17 @@ public void NotNullComponentPropertyTest() public void NotRelatedTypeTest() { var query = session.Query().Select(o => o.CanReduce); - AssertFalse(query, null, null, o => o == null); + AssertUnsupported(query, null, null, o => o == null); } [Test] public void NotNhQueryableTest() { var query = new List().AsQueryable().Select(o => o.Name); - AssertFalse(query, false, null, null, o => o == null); + AssertUnsupported(query, false, null, null, o => o == null); } - private void AssertFalse( + private void AssertUnsupported( IQueryable query, string expectedEntityName, string expectedMemberPath, @@ -707,7 +707,7 @@ private void AssertFalse( AssertResult(query, true, false, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); } - private void AssertFalse( + private void AssertUnsupported( IQueryable query, bool rewriteQuery, string expectedEntityName, @@ -718,7 +718,7 @@ private void AssertFalse( AssertResult(query, rewriteQuery, false, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); } - private void AssertTrue( + private void AssertSupported( IQueryable query, string expectedEntityName, string expectedMemberPath, @@ -728,7 +728,7 @@ private void AssertTrue( AssertResult(query, true, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); } - private void AssertTrue( + private void AssertSupported( IQueryable query, bool rewriteQuery, string expectedEntityName, @@ -739,7 +739,7 @@ private void AssertTrue( AssertResult(query, rewriteQuery, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); } - private void AssertTrueNotNull( + private void AssertSupportedAndResultNotNullable( IQueryable query, string expectedEntityName, string expectedMemberPath, @@ -749,7 +749,7 @@ private void AssertTrueNotNull( AssertResult(query, true, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType, false); } - private void AssertTrueNotNull( + private void AssertSupportedAndResultNotNullable( IQueryable query, bool rewriteQuery, string expectedEntityName, @@ -763,7 +763,7 @@ private void AssertTrueNotNull( private void AssertResult( IQueryable query, bool rewriteQuery, - bool result, + bool supported, string expectedEntityName, string expectedMemberPath, Predicate expectedMemberType, @@ -800,8 +800,8 @@ private void AssertResult( out var entityPersister, out var componentType, out var memberPath); - Assert.That(found, Is.EqualTo(result), "Expression should be supported"); - Assert.That(entityPersister?.EntityName, Is.EqualTo(expectedEntityName), "Invalid enity name"); + Assert.That(found, Is.EqualTo(supported), $"Expression should be {(supported ? "supported" : "unsupported")}"); + Assert.That(entityPersister?.EntityName, Is.EqualTo(expectedEntityName), "Invalid entity name"); Assert.That(memberPath, Is.EqualTo(expectedMemberPath), "Invalid member path"); Assert.That(() => expectedMemberType(memberType), $"Invalid member type: {memberType?.Name ?? "null"}"); Assert.That(() => expectedComponentType(componentType), $"Invalid component type: {componentType?.Name ?? "null"}");