From 591f30ed3838ad4fb56faa742642e8f38101fdaa Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 17 Aug 2020 22:49:26 +0200 Subject: [PATCH 01/21] Avoid unnecessary casting for Linq provider --- .../Northwind/Entities/User.cs | 4 +++ .../Northwind/Mappings/User.hbm.xml | 3 +++ .../Async/Linq/NullComparisonTests.cs | 26 +++++++++++++++++++ .../Async/Linq/ParameterTests.cs | 9 +++++++ .../Linq/NullComparisonTests.cs | 26 +++++++++++++++++++ src/NHibernate.Test/Linq/ParameterTests.cs | 9 +++++++ .../RemoveRedundantCast.cs | 10 ++++++- .../Visitors/HqlGeneratorExpressionVisitor.cs | 15 ++++++++++- .../NhPartialEvaluatingExpressionVisitor.cs | 10 ++++++- 9 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index 14096dac912..7b2993236fb 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -50,6 +50,10 @@ public class User : IUser, IEntity public virtual User NotMappedUser => this; + public virtual short Short { get; set; } + + public virtual short? NullableShort { get; set; } + public virtual EnumStoredAsString Enum1 { get; set; } public virtual EnumStoredAsString? NullableEnum1 { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index f249de9574e..1b6548bde0f 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -20,6 +20,9 @@ + + + diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 6a5c75091c7..d573c87569e 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -472,6 +472,23 @@ public async Task NullEqualityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + + short value = 3; + await (ExpectAsync(session.Query().Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(session.Query().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); + + await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); } [Test] @@ -560,6 +577,15 @@ public async Task NullInequalityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase)); + + short value = 3; + await (ExpectAsync(session.Query().Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(session.Query().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); } [Test] diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index ad1c885dc4f..b7f021b6dc0 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -125,6 +125,15 @@ public async Task UsingValueTypeParameterTwiceAsync() 1)); } + [Test] + public async Task UsingValueTypeParameterTwiceOnNullablePropertyAsync() + { + short value = 1; + await (AssertTotalParametersAsync( + db.Users.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + 1)); + } + [Test] public async Task UsingParameterInEvaluatableExpressionAsync() { diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index 0ed569813f3..3ec9cf74019 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -460,6 +460,23 @@ public void NullEquality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase); + + short value = 3; + Expect(session.Query().Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(session.Query().Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(session.Query().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); + Expect(session.Query().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); + + Expect(session.Query().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); + Expect(session.Query().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); + Expect(session.Query().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); + Expect(session.Query().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); } [Test] @@ -548,6 +565,15 @@ public void NullInequality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase); + + short value = 3; + Expect(session.Query().Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(session.Query().Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(session.Query().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); } [Test] diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 97da1e0a079..8a831ff28fc 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -113,6 +113,15 @@ public void UsingValueTypeParameterTwice() 1); } + [Test] + public void UsingValueTypeParameterTwiceOnNullableProperty() + { + short value = 1; + AssertTotalParameters( + db.Users.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + 1); + } + [Test] public void UsingParameterInEvaluatableExpression() { diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs index 538e46cb828..8feac321a1c 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs @@ -1,4 +1,5 @@ using System.Linq.Expressions; +using NHibernate.Util; using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers @@ -26,6 +27,13 @@ public Expression Transform(UnaryExpression expression) return expression.Operand; } + // Reduce double casting (e.g. (long?)(long)3 => (long?)3) + if (expression.Operand.NodeType == ExpressionType.Convert && + expression.Type.UnwrapIfNullable() == expression.Operand.Type) + { + return Expression.Convert(((UnaryExpression) expression.Operand).Operand, expression.Type); + } + return expression; } @@ -34,4 +42,4 @@ public ExpressionType[] SupportedExpressionTypes get { return _supportedExpressionTypes; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 34315c240d2..f76fecef9c3 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Data; using System.Dynamic; using System.Linq; @@ -23,6 +24,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; private readonly NullableExpressionDetector _nullableExpressionDetector; + private readonly HashSet _notCastableExpressions = new HashSet(); public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { @@ -308,6 +310,17 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { + // In .NET some numeric types do not have thier own operators (e.g. short == short is converted to (int) short == (int) short), + // in such case we dont want to add a sql cast + if (expression.Left.NodeType == ExpressionType.Convert && + expression.Right.NodeType == ExpressionType.Convert && + ((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable() == + ((UnaryExpression) expression.Right).Operand.Type.UnwrapIfNullable()) + { + _notCastableExpressions.Add(expression.Left); + _notCastableExpressions.Add(expression.Right); + } + if (expression.NodeType == ExpressionType.Equal) { return TranslateEqualityComparison(expression); @@ -496,7 +509,7 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - return IsCastRequired(expression.Operand, expression.Type, out var existType) + return IsCastRequired(expression.Operand, expression.Type, out var existType) && !_notCastableExpressions.Contains(expression) ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader : existType && HqlIdent.SupportsType(expression.Type) diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index 9ee40092e6b..b198da9002f 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -81,7 +81,15 @@ public override Expression Visit(Expression expression) #region NH additions // Variables should be evaluated only when they are part of an evaluatable expression (e.g. o => string.Format("...", variable)) expression is UnaryExpression unaryExpression && - ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _)) + ( + ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _) || + // Check whether the variable is casted due to comparison with a nullable expression + // (e.g. o.NullableShort == shortVariable) + unaryExpression.Operand is UnaryExpression subUnaryExpression && + unaryExpression.Type.UnwrapIfNullable() == subUnaryExpression.Type && + ExpressionsHelper.IsVariable(subUnaryExpression.Operand, out _, out _) + ) + ) #endregion return base.Visit(expression); From 05dd4d6a73946ad482c7151bc0a4af89a9060cb0 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 18 Aug 2020 23:34:32 +0200 Subject: [PATCH 02/21] Fix tests --- .../Async/Linq/NullComparisonTests.cs | 51 ++++++++++++++++--- .../Linq/NullComparisonTests.cs | 51 ++++++++++++++++--- .../Visitors/HqlGeneratorExpressionVisitor.cs | 16 +++--- 3 files changed, 98 insertions(+), 20 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index d573c87569e..56cdb135022 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -12,6 +12,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using NHibernate.Dialect; using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; @@ -482,13 +483,17 @@ public async Task NullEqualityAsync() await (ExpectAsync(session.Query().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(session.Query().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); - - await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"))); + var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast != intCast; + await (ExpectAsync(session.Query().Where(o => o.NullableShort == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + + await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort.Value), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => o.Short == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => 3 == o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); } [Test] @@ -586,6 +591,38 @@ public async Task NullInequalityAsync() await (ExpectAsync(session.Query().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(session.Query().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(session.Query().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast != intCast; + await (ExpectAsync(session.Query().Where(o => o.NullableShort != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => 3 != o.NullableShort), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); + + await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => 3 != o.NullableShort.Value), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => o.Short != 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(session.Query().Where(o => 3 != o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + } + + private IResolveConstraint WithIsNullAndWithoutCast() + { + return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithIsNullAndWithCast() + { + return Does.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithoutIsNullAndWithoutCast() + { + return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithoutIsNullAndWithCast() + { + return Does.Not.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; } [Test] diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index 3ec9cf74019..ad3b9a74bc7 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using NHibernate.Dialect; using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; @@ -470,13 +471,17 @@ public void NullEquality() Expect(session.Query().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(session.Query().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); - Expect(session.Query().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); - - Expect(session.Query().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); - Expect(session.Query().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); - Expect(session.Query().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); - Expect(session.Query().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")); + var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast != intCast; + Expect(session.Query().Where(o => o.NullableShort == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(session.Query().Where(o => 3 == o.NullableShort), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + + Expect(session.Query().Where(o => o.NullableShort.Value == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(session.Query().Where(o => 3 == o.NullableShort.Value), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(session.Query().Where(o => o.Short == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(session.Query().Where(o => 3 == o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); } [Test] @@ -574,6 +579,38 @@ public void NullInequality() Expect(session.Query().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(session.Query().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(session.Query().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast != intCast; + Expect(session.Query().Where(o => o.NullableShort != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); + Expect(session.Query().Where(o => 3 != o.NullableShort), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); + + Expect(session.Query().Where(o => o.NullableShort.Value != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); + Expect(session.Query().Where(o => 3 != o.NullableShort.Value), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); + Expect(session.Query().Where(o => o.Short != 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(session.Query().Where(o => 3 != o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + } + + private IResolveConstraint WithIsNullAndWithoutCast() + { + return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithIsNullAndWithCast() + { + return Does.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithoutIsNullAndWithoutCast() + { + return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithoutIsNullAndWithCast() + { + return Does.Not.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; } [Test] diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index f76fecef9c3..b58ca33a305 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -24,7 +24,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; private readonly NullableExpressionDetector _nullableExpressionDetector; - private readonly HashSet _notCastableExpressions = new HashSet(); + private readonly Dictionary _notCastableExpressions = new Dictionary(); public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { @@ -317,8 +317,9 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) ((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable() == ((UnaryExpression) expression.Right).Operand.Type.UnwrapIfNullable()) { - _notCastableExpressions.Add(expression.Left); - _notCastableExpressions.Add(expression.Right); + var type = ((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable(); + _notCastableExpressions.Add(expression.Left, type); + _notCastableExpressions.Add(expression.Right, type); } if (expression.NodeType == ExpressionType.Equal) @@ -509,11 +510,14 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - return IsCastRequired(expression.Operand, expression.Type, out var existType) && !_notCastableExpressions.Contains(expression) + var notCastable = _notCastableExpressions.TryGetValue(expression, out var castType); + castType = castType ?? expression.Type; + + return IsCastRequired(expression.Operand, castType, out var existType) && !notCastable ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader - : existType && HqlIdent.SupportsType(expression.Type) - ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + : existType && HqlIdent.SupportsType(castType) + ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), castType) : VisitExpression(expression.Operand); } From 085a1705ba1a57c74680191d33e1784f50e688b0 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 22 Aug 2020 01:44:41 +0200 Subject: [PATCH 03/21] Avoid additional casting and add more test cases --- .../Northwind/Entities/Northwind.cs | 5 + .../Northwind/Entities/NumericEntity.cs | 29 ++ .../Northwind/Entities/User.cs | 4 - .../Northwind/Mappings/NumericEntity.hbm.xml | 20 + .../Northwind/Mappings/User.hbm.xml | 3 - .../Async/Linq/NullComparisonTests.cs | 78 ++-- .../Async/Linq/ParameterTests.cs | 383 +++++++++++++++++- ...Sql2008DialectLinqReadonlyCreateScript.sql | Bin 1868954 -> 1870228 bytes ...MsSql2008DialectLinqReadonlyDropScript.sql | Bin 5406 -> 5474 bytes ...Sql2012DialectLinqReadonlyCreateScript.sql | 19 + ...MsSql2012DialectLinqReadonlyDropScript.sql | 1 + ...reSQL83DialectLinqReadonlyCreateScript.sql | Bin 1438124 -> 1439046 bytes ...tgreSQL83DialectLinqReadonlyDropScript.sql | Bin 1240 -> 1294 bytes .../Linq/LinqReadonlyTestsContext.cs | 3 +- src/NHibernate.Test/Linq/LinqTestCase.cs | 3 +- .../Linq/NullComparisonTests.cs | 78 ++-- src/NHibernate.Test/Linq/ParameterTests.cs | 383 +++++++++++++++++- .../Linq/ParameterTypeLocatorTests.cs | 78 +++- .../Visitors/ExpressionParameterVisitor.cs | 2 +- .../Visitors/HqlGeneratorExpressionVisitor.cs | 34 +- .../Linq/Visitors/ParameterTypeLocator.cs | 116 ++++-- src/NHibernate/Type/SingleType.cs | 2 +- src/NHibernate/Util/ExpressionsHelper.cs | 15 + 23 files changed, 1114 insertions(+), 142 deletions(-) create mode 100644 src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs create mode 100644 src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index 4551ce0e9d8..670b94b5bc5 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -69,6 +69,11 @@ public IQueryable Users get { return _session.Query(); } } + public IQueryable NumericEntities + { + get { return _session.Query(); } + } + public IQueryable DynamicUsers { get { return _session.Query(); } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs b/src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs new file mode 100644 index 00000000000..f38641303b1 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs @@ -0,0 +1,29 @@ +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class NumericEntity + { + public virtual short Short { get; set; } + + public virtual short? NullableShort { get; set; } + + public virtual int Integer { get; set; } + + public virtual int? NullableInteger { get; set; } + + public virtual long Long { get; set; } + + public virtual long? NullableLong { get; set; } + + public virtual decimal Decimal { get; set; } + + public virtual decimal? NullableDecimal { get; set; } + + public virtual float Single { get; set; } + + public virtual float? NullableSingle { get; set; } + + public virtual double Double { get; set; } + + public virtual double? NullableDouble { get; set; } + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index 7b2993236fb..14096dac912 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -50,10 +50,6 @@ public class User : IUser, IEntity public virtual User NotMappedUser => this; - public virtual short Short { get; set; } - - public virtual short? NullableShort { get; set; } - public virtual EnumStoredAsString Enum1 { get; set; } public virtual EnumStoredAsString? NullableEnum1 { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml new file mode 100644 index 00000000000..2eb696e1acd --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index 1b6548bde0f..f249de9574e 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -20,9 +20,6 @@ - - - diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 56cdb135022..728d76649bf 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -474,26 +474,27 @@ public async Task NullEqualityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + short value = 3; - await (ExpectAsync(session.Query().Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - - await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - - var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && - shortCast != intCast; - await (ExpectAsync(session.Query().Where(o => o.NullableShort == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); - - await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => 3 == o.NullableShort.Value), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => o.Short == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => 3 == o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.NullableShort), WithoutIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value == 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.NullableShort.Value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.Short), WithoutIsNullAndWithoutCast())); } [Test] @@ -583,26 +584,27 @@ public async Task NullInequalityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.Short), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + short value = 3; - await (ExpectAsync(session.Query().Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - - await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(session.Query().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - - var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && - shortCast != intCast; - await (ExpectAsync(session.Query().Where(o => o.NullableShort != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => 3 != o.NullableShort), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); - - await (ExpectAsync(session.Query().Where(o => o.NullableShort.Value != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => 3 != o.NullableShort.Value), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => o.Short != 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); - await (ExpectAsync(session.Query().Where(o => 3 != o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != 3L), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3 != o.NullableShort), WithIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value != 3L), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L != o.NullableShort.Value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L != o.Short), WithoutIsNullAndWithoutCast())); } private IResolveConstraint WithIsNullAndWithoutCast() diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index b7f021b6dc0..008daa1f49e 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -14,11 +14,13 @@ using System.Linq.Expressions; using System.Reflection; using System.Text.RegularExpressions; +using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { @@ -125,12 +127,382 @@ public async Task UsingValueTypeParameterTwiceAsync() 1)); } + [Test] + public async Task CompareIntegralParametersAndColumnsAsync() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Short < intParam || o.Short > longParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableShort <= intParam || o.NullableShort != longParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == nullShortParam || o.Short < nullIntParam || o.Short > nullLongParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == nullShortParam || o.NullableShort <= nullIntParam || o.NullableShort != nullLongParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == shortParam || o.Integer < intParam || o.Integer > longParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == shortParam || o.NullableInteger <= intParam || o.NullableInteger != longParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == nullShortParam || o.Integer < nullIntParam || o.Integer > nullLongParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == nullShortParam || o.NullableInteger <= nullIntParam || o.NullableInteger != nullLongParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == shortParam || o.Long < intParam || o.Long > longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == shortParam || o.NullableLong <= intParam || o.NullableLong != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == nullShortParam || o.Long < nullIntParam || o.Long > nullLongParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == nullShortParam || o.NullableLong <= nullIntParam || o.NullableLong != nullLongParam), "Int64"} + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + await (AssertTotalParametersAsync( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + })); + } + } + + [Test] + public async Task CompareIntegralParametersWithFloatingPointColumnsAsync() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == shortParam || o.Decimal < intParam || o.Decimal > longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == shortParam || o.NullableDecimal <= intParam || o.NullableDecimal != longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullShortParam || o.Decimal < nullIntParam || o.Decimal > nullLongParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullShortParam || o.NullableDecimal <= nullIntParam || o.NullableDecimal != nullLongParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == shortParam || o.Single < intParam || o.Single > longParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == shortParam || o.NullableSingle <= intParam || o.NullableSingle != longParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == nullShortParam || o.Single < nullIntParam || o.Single > nullLongParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullShortParam || o.NullableSingle <= nullIntParam || o.NullableSingle != nullLongParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == shortParam || o.Double < intParam || o.Double > longParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == shortParam || o.NullableDouble <= intParam || o.NullableDouble != longParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == nullShortParam || o.Double < nullIntParam || o.Double > nullLongParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullShortParam || o.NullableDouble <= nullIntParam || o.NullableDouble != nullLongParam), "Double"}, + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + await (AssertTotalParametersAsync( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParametersAndColumnsAsync() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullDecimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullDecimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single <= singleParam || o.Single >= doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableSingle != doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single <= nullSingleParam || o.Single >= nullDoubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullSingleParam || o.NullableSingle != nullDoubleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double <= singleParam || o.Double >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == singleParam || o.NullableDouble != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double <= nullSingleParam || o.Double >= nullDoubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullSingleParam || o.NullableDouble != nullDoubleParam), "Double"}, + }; + + foreach (var pair in queriables) + { + var totalParameters = pair.Value == "Decimal" ? 1 : 2; + // Parameters should be pre-evaluated + await (AssertTotalParametersAsync( + pair.Key, + totalParameters, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParametersWithIntegralColumnsAsync() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == decimalParam || o.Short != singleParam || o.Short <= doubleParam), + db.NumericEntities.Where(o => o.NullableShort <= decimalParam || o.NullableShort == singleParam || o.NullableShort >= doubleParam), + db.NumericEntities.Where(o => o.Short == nullDecimalParam || o.Short != nullSingleParam || o.Short <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableShort <= nullDecimalParam || o.NullableShort == nullSingleParam || o.NullableShort >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Integer == decimalParam || o.Integer != singleParam || o.Integer <= doubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= decimalParam || o.NullableInteger == singleParam || o.NullableInteger >= doubleParam), + db.NumericEntities.Where(o => o.Integer == nullDecimalParam || o.Integer != nullSingleParam || o.Integer <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= nullDecimalParam || o.NullableInteger == nullSingleParam || o.NullableInteger >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Long == decimalParam || o.Long != singleParam || o.Long <= doubleParam), + db.NumericEntities.Where(o => o.NullableLong <= decimalParam || o.NullableLong == singleParam || o.NullableLong >= doubleParam), + db.NumericEntities.Where(o => o.Long == nullDecimalParam || o.Long != nullSingleParam || o.Long <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableLong <= nullDecimalParam || o.NullableLong == nullSingleParam || o.NullableLong >= nullDoubleParam), + }; + + foreach (var query in queriables) + { + // Columns should be casted + await (AssertTotalParametersAsync( + query, + 3, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(3)); + Assert.That(GetTotalOccurrences(sql, $"Type: Decimal"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Single"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Double"), Is.EqualTo(1)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParameterWithIntegralAndFloatingPointColumnsAsync() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.NullableShort >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.Long >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.Integer != decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.NullableInteger == decimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableShort >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.Long >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Integer != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableInteger == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableShort >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Long >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Integer != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableInteger == doubleParam), "Double"} + }; + + foreach (var pair in queriables) + { + // Integral columns should be casted + await (AssertTotalParametersAsync( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumnsAsync() + { + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Single == singleParam || o.Double >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Double >= singleParam || o.Single == singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableDouble <= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Double != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableDouble == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Single >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Single >= doubleParam || o.Double == doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableSingle >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Single != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableSingle == doubleParam), "Double"} + }; + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Single.SqlType, out var singleCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Double.SqlType, out var doubleCast) && + singleCast == doubleCast; + + foreach (var pair in queriables) + { + // Columns should be casted for Double parameter and parameters for Single parameter + await (AssertTotalParametersAsync( + pair.Key, + 1, + sql => + { + var matches = pair.Value == "Double" + ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") + : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + })); + } + } + + [Test] + public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsync() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Double >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == shortParam || o.NullableDouble >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.Decimal != shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableSingle > shortParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == intParam || o.Double >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == intParam || o.NullableDouble >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.Decimal != intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.NullableSingle > intParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == longParam || o.Double >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == longParam || o.NullableDouble >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.Decimal != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.NullableSingle > longParam), "Int64"} + }; + + foreach (var pair in queriables) + { + // Parameters should be casted + await (AssertTotalParametersAsync( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + })); + } + } + + [Test] + public async Task UsingValueTypeParameterOfDifferentTypeAsync() + { + var value = 1; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == value), + db.NumericEntities.Where(o => o.Short != value), + db.NumericEntities.Where(o => o.Short >= value), + db.NumericEntities.Where(o => o.Short <= value), + db.NumericEntities.Where(o => o.Short > value), + db.NumericEntities.Where(o => o.Short < value), + + db.NumericEntities.Where(o => o.NullableShort == value), + db.NumericEntities.Where(o => o.NullableShort != value), + db.NumericEntities.Where(o => o.NullableShort >= value), + db.NumericEntities.Where(o => o.NullableShort <= value), + db.NumericEntities.Where(o => o.NullableShort > value), + db.NumericEntities.Where(o => o.NullableShort < value), + + db.NumericEntities.Where(o => o.NullableShort.Value == value), + db.NumericEntities.Where(o => o.NullableShort.Value != value), + db.NumericEntities.Where(o => o.NullableShort.Value >= value), + db.NumericEntities.Where(o => o.NullableShort.Value <= value), + db.NumericEntities.Where(o => o.NullableShort.Value > value), + db.NumericEntities.Where(o => o.NullableShort.Value < value) + }; + + foreach (var query in queriables) + { + await (AssertTotalParametersAsync( + query, + 1, + sql => Assert.That(sql, Does.Not.Contain("cast")))); + } + + queriables = new List> + { + db.NumericEntities.Where(o => o.Short + value > value), + db.NumericEntities.Where(o => o.Short - value > value), + db.NumericEntities.Where(o => o.Short * value > value), + + db.NumericEntities.Where(o => o.NullableShort + value > value), + db.NumericEntities.Where(o => o.NullableShort - value > value), + db.NumericEntities.Where(o => o.NullableShort * value > value), + + db.NumericEntities.Where(o => o.NullableShort.Value + value > value), + db.NumericEntities.Where(o => o.NullableShort.Value - value > value), + db.NumericEntities.Where(o => o.NullableShort.Value * value > value), + }; + + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast == intCast; + foreach (var query in queriables) + { + await (AssertTotalParametersAsync( + query, + 1, + sql => { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + })); + } + } + [Test] public async Task UsingValueTypeParameterTwiceOnNullablePropertyAsync() { short value = 1; await (AssertTotalParametersAsync( - db.Users.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + 1, sql => { + + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); + })); + } + + [Test] + public async Task UsingValueTypeParameterOnDifferentPropertiesAsync() + { + int value = 1; + await (AssertTotalParametersAsync( + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Integer == value), + 1)); + + await (AssertTotalParametersAsync( + db.NumericEntities.Where(o => o.Integer == value && o.NullableShort == value && o.NullableShort != value), 1)); } @@ -384,7 +756,12 @@ public async Task UsingTwoParametersInDMLDeleteAsync() 2)); } - private async Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, int? linqParameterNumber = null, CancellationToken cancellationToken = default(CancellationToken)) + private Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, Action sqlAction, CancellationToken cancellationToken = default(CancellationToken)) + { + return AssertTotalParametersAsync(query, parameterNumber, null, sqlAction, cancellationToken); + } + + private async Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, int? linqParameterNumber = null, Action sqlAction = null, CancellationToken cancellationToken = default(CancellationToken)) { using (var sqlSpy = new SqlLogSpy()) { @@ -403,6 +780,8 @@ public async Task UsingTwoParametersInDMLDeleteAsync() await (query.ToListAsync(cancellationToken)); + sqlAction?.Invoke(sqlSpy.GetWholeLog()); + // In case of arrays two query plans will be stored, one with an one without expended parameters Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable"); diff --git a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql index 12235095e0ddb6ef8e6d4eea097d578a206e2565..cb42443637d31502508cf8adff0f18ab015d10ea 100644 GIT binary patch delta 575 zcmaLTJx{_=6b9gP{Q#lT78FYb{D>w3Mhu$+#z7oRj4@%;IPlR}62w>_#Kl1e;^J~m zZsKA>oSZZcuKWfUSK|+Ga&S`Lf&(Sd;c41?&)a*>y}q!nxOKmx@Jm;0ZJX+JNCsI{ zq5^GDg=%=K(FyHSjC|xJ5A9KwOgchM4XZQwt|F78GL^;bl)S!Bw_mZ4$J6+yIQPP$ zp3Tk~1|5i-cX?U28|Y0twl&dCMpX9}4zD2a@XVHYo4JC{NYAz|z7ldkGB8W2+oSlF zlvO(vc7`(O)KL@{I|*6Uvk2m>&LrJoEmUGyky@0e|H4@`$JKn={d?W~*i7+TGtJ-3T|Qq(ijO8MKfd}ijGq?0=Pk9_QW*@uAPm7U oxWEk~-~lg4FbY19K><7OhX9PhI0Rt=GzftXlMsGtsgcj>57|DL!2kdN delta 140 zcmWm6y$ymu0D#dTsDNJx{_=6oy~?76nvvz{F5%!eAUU(fGBpn3x!Y35&$=(G(OZwq;;qFsVP__GdV_ zIXSTSN8H?8nS6UIp~XN#Zk~Red(L^^bMNa{{_{KkcI%4X!Dgy~79OF&L=9!W8?gCh zqb<7CWTy&)J2n(p%qpVJRg*aipN8D=I(Tt~^#-FBqb73)c_S2XhbA7V*=B81I`A}1 z`7h%)*h36aZu_e*RYZw`8l^@X3|FeSkhV4GB65;C9yTA&9@gQ|0^8t)(Z3)bP3m}$a1FE-~CWcbPL<3lPXpEJ;X=hNc8qI&yM@zxDhf%rpXMM iC39q+ERaPKB}-(P#7LYZ$O>5{Ya~h5-yL@&Q}_k4JgTe! delta 102 zcmWN=xe-866hP6j?>qS$fj~n)p&Ee#2wo8anN*|^nJU~fbDqx}uQ{iEyZdnwT}3V0 cbm-Ef&wwE##!Q$pW6pvlE7ok-UZUMF|Je#8EdT%j diff --git a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyDropScript.sql b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyDropScript.sql index 27ff0b0e3940a249b7b1f6d66e78a4123a2f98e6..a9b1f2642ee5d6bc1eeb736160b0b1619989de24 100644 GIT binary patch delta 40 tcmcb?*~hhE2a9ALLn%WpLn=cNLncEqkj`T$0kTRMDkn~qoxFoZ0RZZx3q1e; delta 11 ScmeC Mappings "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", - "Northwind.Mappings.Patient.hbm.xml" + "Northwind.Mappings.Patient.hbm.xml", + "Northwind.Mappings.NumericEntity.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index daf14b9cd18..ecd9c7690e3 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -35,7 +35,8 @@ protected override string[] Mappings "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", "Northwind.Mappings.Patient.hbm.xml", - "Northwind.Mappings.DynamicUser.hbm.xml" + "Northwind.Mappings.DynamicUser.hbm.xml", + "Northwind.Mappings.NumericEntity.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index ad3b9a74bc7..4b54bab5557 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -462,26 +462,27 @@ public void NullEquality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase); + Expect(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.Short == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.Short == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + short value = 3; - Expect(session.Query().Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - - Expect(session.Query().Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - - var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && - shortCast != intCast; - Expect(session.Query().Where(o => o.NullableShort == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); - Expect(session.Query().Where(o => 3 == o.NullableShort), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); - - Expect(session.Query().Where(o => o.NullableShort.Value == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); - Expect(session.Query().Where(o => 3 == o.NullableShort.Value), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); - Expect(session.Query().Where(o => o.Short == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); - Expect(session.Query().Where(o => 3 == o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(db.NumericEntities.Where(o => o.NullableShort == 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L == o.NullableShort), WithoutIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value == 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L == o.NullableShort.Value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L == o.Short), WithoutIsNullAndWithoutCast()); } [Test] @@ -571,26 +572,27 @@ public void NullInequality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase); + Expect(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.Short != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort != o.Short), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.Short != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + short value = 3; - Expect(session.Query().Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - - Expect(session.Query().Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(session.Query().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - - var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && - Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && - shortCast != intCast; - Expect(session.Query().Where(o => o.NullableShort != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); - Expect(session.Query().Where(o => 3 != o.NullableShort), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); - - Expect(session.Query().Where(o => o.NullableShort.Value != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); - Expect(session.Query().Where(o => 3 != o.NullableShort.Value), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()); - Expect(session.Query().Where(o => o.Short != 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); - Expect(session.Query().Where(o => 3 != o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + + Expect(db.NumericEntities.Where(o => o.NullableShort != 3L), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3 != o.NullableShort), WithIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value != 3L), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L != o.NullableShort.Value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L != o.Short), WithoutIsNullAndWithoutCast()); } private IResolveConstraint WithIsNullAndWithoutCast() diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 8a831ff28fc..c0f80819fec 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -4,11 +4,13 @@ using System.Linq.Expressions; using System.Reflection; using System.Text.RegularExpressions; +using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { @@ -113,12 +115,382 @@ public void UsingValueTypeParameterTwice() 1); } + [Test] + public void CompareIntegralParametersAndColumns() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Short < intParam || o.Short > longParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableShort <= intParam || o.NullableShort != longParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == nullShortParam || o.Short < nullIntParam || o.Short > nullLongParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == nullShortParam || o.NullableShort <= nullIntParam || o.NullableShort != nullLongParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == shortParam || o.Integer < intParam || o.Integer > longParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == shortParam || o.NullableInteger <= intParam || o.NullableInteger != longParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == nullShortParam || o.Integer < nullIntParam || o.Integer > nullLongParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == nullShortParam || o.NullableInteger <= nullIntParam || o.NullableInteger != nullLongParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == shortParam || o.Long < intParam || o.Long > longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == shortParam || o.NullableLong <= intParam || o.NullableLong != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == nullShortParam || o.Long < nullIntParam || o.Long > nullLongParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == nullShortParam || o.NullableLong <= nullIntParam || o.NullableLong != nullLongParam), "Int64"} + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + AssertTotalParameters( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + }); + } + } + + [Test] + public void CompareIntegralParametersWithFloatingPointColumns() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == shortParam || o.Decimal < intParam || o.Decimal > longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == shortParam || o.NullableDecimal <= intParam || o.NullableDecimal != longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullShortParam || o.Decimal < nullIntParam || o.Decimal > nullLongParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullShortParam || o.NullableDecimal <= nullIntParam || o.NullableDecimal != nullLongParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == shortParam || o.Single < intParam || o.Single > longParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == shortParam || o.NullableSingle <= intParam || o.NullableSingle != longParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == nullShortParam || o.Single < nullIntParam || o.Single > nullLongParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullShortParam || o.NullableSingle <= nullIntParam || o.NullableSingle != nullLongParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == shortParam || o.Double < intParam || o.Double > longParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == shortParam || o.NullableDouble <= intParam || o.NullableDouble != longParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == nullShortParam || o.Double < nullIntParam || o.Double > nullLongParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullShortParam || o.NullableDouble <= nullIntParam || o.NullableDouble != nullLongParam), "Double"}, + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + AssertTotalParameters( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + }); + } + } + + [Test] + public void CompareFloatingPointParametersAndColumns() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullDecimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullDecimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single <= singleParam || o.Single >= doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableSingle != doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single <= nullSingleParam || o.Single >= nullDoubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullSingleParam || o.NullableSingle != nullDoubleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double <= singleParam || o.Double >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == singleParam || o.NullableDouble != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double <= nullSingleParam || o.Double >= nullDoubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullSingleParam || o.NullableDouble != nullDoubleParam), "Double"}, + }; + + foreach (var pair in queriables) + { + var totalParameters = pair.Value == "Decimal" ? 1 : 2; + // Parameters should be pre-evaluated + AssertTotalParameters( + pair.Key, + totalParameters, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); + }); + } + } + + [Test] + public void CompareFloatingPointParametersWithIntegralColumns() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == decimalParam || o.Short != singleParam || o.Short <= doubleParam), + db.NumericEntities.Where(o => o.NullableShort <= decimalParam || o.NullableShort == singleParam || o.NullableShort >= doubleParam), + db.NumericEntities.Where(o => o.Short == nullDecimalParam || o.Short != nullSingleParam || o.Short <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableShort <= nullDecimalParam || o.NullableShort == nullSingleParam || o.NullableShort >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Integer == decimalParam || o.Integer != singleParam || o.Integer <= doubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= decimalParam || o.NullableInteger == singleParam || o.NullableInteger >= doubleParam), + db.NumericEntities.Where(o => o.Integer == nullDecimalParam || o.Integer != nullSingleParam || o.Integer <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= nullDecimalParam || o.NullableInteger == nullSingleParam || o.NullableInteger >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Long == decimalParam || o.Long != singleParam || o.Long <= doubleParam), + db.NumericEntities.Where(o => o.NullableLong <= decimalParam || o.NullableLong == singleParam || o.NullableLong >= doubleParam), + db.NumericEntities.Where(o => o.Long == nullDecimalParam || o.Long != nullSingleParam || o.Long <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableLong <= nullDecimalParam || o.NullableLong == nullSingleParam || o.NullableLong >= nullDoubleParam), + }; + + foreach (var query in queriables) + { + // Columns should be casted + AssertTotalParameters( + query, + 3, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(3)); + Assert.That(GetTotalOccurrences(sql, $"Type: Decimal"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Single"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Double"), Is.EqualTo(1)); + }); + } + } + + [Test] + public void CompareFloatingPointParameterWithIntegralAndFloatingPointColumns() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.NullableShort >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.Long >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.Integer != decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.NullableInteger == decimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableShort >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.Long >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Integer != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableInteger == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableShort >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Long >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Integer != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableInteger == doubleParam), "Double"} + }; + + foreach (var pair in queriables) + { + // Integral columns should be casted + AssertTotalParameters( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + }); + } + } + + [Test] + public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() + { + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Single == singleParam || o.Double >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Double >= singleParam || o.Single == singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableDouble <= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Double != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableDouble == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Single >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Single >= doubleParam || o.Double == doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableSingle >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Single != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableSingle == doubleParam), "Double"} + }; + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Single.SqlType, out var singleCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Double.SqlType, out var doubleCast) && + singleCast == doubleCast; + + foreach (var pair in queriables) + { + // Columns should be casted for Double parameter and parameters for Single parameter + AssertTotalParameters( + pair.Key, + 1, + sql => + { + var matches = pair.Value == "Double" + ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") + : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + }); + } + } + + [Test] + public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Double >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == shortParam || o.NullableDouble >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.Decimal != shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableSingle > shortParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == intParam || o.Double >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == intParam || o.NullableDouble >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.Decimal != intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.NullableSingle > intParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == longParam || o.Double >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == longParam || o.NullableDouble >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.Decimal != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.NullableSingle > longParam), "Int64"} + }; + + foreach (var pair in queriables) + { + // Parameters should be casted + AssertTotalParameters( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + }); + } + } + + [Test] + public void UsingValueTypeParameterOfDifferentType() + { + var value = 1; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == value), + db.NumericEntities.Where(o => o.Short != value), + db.NumericEntities.Where(o => o.Short >= value), + db.NumericEntities.Where(o => o.Short <= value), + db.NumericEntities.Where(o => o.Short > value), + db.NumericEntities.Where(o => o.Short < value), + + db.NumericEntities.Where(o => o.NullableShort == value), + db.NumericEntities.Where(o => o.NullableShort != value), + db.NumericEntities.Where(o => o.NullableShort >= value), + db.NumericEntities.Where(o => o.NullableShort <= value), + db.NumericEntities.Where(o => o.NullableShort > value), + db.NumericEntities.Where(o => o.NullableShort < value), + + db.NumericEntities.Where(o => o.NullableShort.Value == value), + db.NumericEntities.Where(o => o.NullableShort.Value != value), + db.NumericEntities.Where(o => o.NullableShort.Value >= value), + db.NumericEntities.Where(o => o.NullableShort.Value <= value), + db.NumericEntities.Where(o => o.NullableShort.Value > value), + db.NumericEntities.Where(o => o.NullableShort.Value < value) + }; + + foreach (var query in queriables) + { + AssertTotalParameters( + query, + 1, + sql => Assert.That(sql, Does.Not.Contain("cast"))); + } + + queriables = new List> + { + db.NumericEntities.Where(o => o.Short + value > value), + db.NumericEntities.Where(o => o.Short - value > value), + db.NumericEntities.Where(o => o.Short * value > value), + + db.NumericEntities.Where(o => o.NullableShort + value > value), + db.NumericEntities.Where(o => o.NullableShort - value > value), + db.NumericEntities.Where(o => o.NullableShort * value > value), + + db.NumericEntities.Where(o => o.NullableShort.Value + value > value), + db.NumericEntities.Where(o => o.NullableShort.Value - value > value), + db.NumericEntities.Where(o => o.NullableShort.Value * value > value), + }; + + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast == intCast; + foreach (var query in queriables) + { + AssertTotalParameters( + query, + 1, + sql => { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + }); + } + } + [Test] public void UsingValueTypeParameterTwiceOnNullableProperty() { short value = 1; AssertTotalParameters( - db.Users.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + 1, sql => { + + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); + }); + } + + [Test] + public void UsingValueTypeParameterOnDifferentProperties() + { + int value = 1; + AssertTotalParameters( + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Integer == value), + 1); + + AssertTotalParameters( + db.NumericEntities.Where(o => o.Integer == value && o.NullableShort == value && o.NullableShort != value), 1); } @@ -445,7 +817,12 @@ public void DMLDeleteShouldHaveSameCacheKeys() Assert.That(expression1.Key, Is.EqualTo(expression2.Key)); } - private void AssertTotalParameters(IQueryable query, int parameterNumber, int? linqParameterNumber = null) + private void AssertTotalParameters(IQueryable query, int parameterNumber, Action sqlAction) + { + AssertTotalParameters(query, parameterNumber, null, sqlAction); + } + + private void AssertTotalParameters(IQueryable query, int parameterNumber, int? linqParameterNumber = null, Action sqlAction = null) { using (var sqlSpy = new SqlLogSpy()) { @@ -464,6 +841,8 @@ private void AssertTotalParameters(IQueryable query, int parameterNumber, query.ToList(); + sqlAction?.Invoke(sqlSpy.GetWholeLog()); + // In case of arrays two query plans will be stored, one with an one without expended parameters Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable"); diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index ba4505dda8c..48cc2e14048 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -64,7 +64,7 @@ public void GreaterThanTest() AssertResults( new Dictionary> { - {"2.1", o => o is Int32Type} + {"2.1", o => o is DoubleType} }, db.Users.Where(o => o.Id > 2.1), db.Users.Where(o => 2.1 > o.Id) @@ -99,6 +99,82 @@ public void EqualsMethodStringTest() ); } + [Test] + public void BinaryIntShortTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is Int16Type} + }, + db.NumericEntities.Where(o => o.Short == 3), + db.NumericEntities.Where(o => 3 == o.Short), + db.NumericEntities.Where(o => o.Short < 3), + db.NumericEntities.Where(o => 3 < o.Short), + db.NumericEntities.Where(o => o.Short > 3), + db.NumericEntities.Where(o => 3 > o.Short), + db.NumericEntities.Where(o => o.Short >= 3), + db.NumericEntities.Where(o => 3 >= o.Short), + db.NumericEntities.Where(o => o.Short <= 3), + db.NumericEntities.Where(o => 3 <= o.Short), + db.NumericEntities.Where(o => o.Short != 3), + db.NumericEntities.Where(o => 3 != o.Short), + + db.NumericEntities.Where(o => o.NullableShort == 3), + db.NumericEntities.Where(o => 3 == o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort < 3), + db.NumericEntities.Where(o => 3 < o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort > 3), + db.NumericEntities.Where(o => 3 > o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort >= 3), + db.NumericEntities.Where(o => 3 >= o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort <= 3), + db.NumericEntities.Where(o => 3 <= o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort != 3), + db.NumericEntities.Where(o => 3 != o.NullableShort), + + db.NumericEntities.Where(o => o.NullableShort.Value == 3), + db.NumericEntities.Where(o => 3 == o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value < 3), + db.NumericEntities.Where(o => 3 < o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value > 3), + db.NumericEntities.Where(o => 3 > o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value >= 3), + db.NumericEntities.Where(o => 3 >= o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value <= 3), + db.NumericEntities.Where(o => 3 <= o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value != 3), + db.NumericEntities.Where(o => 3 != o.NullableShort.Value) + ); + } + + [Test] + public void BinaryNullableIntShortTest() + { + int? value = 3; + AssertResults( + new Dictionary> + { + {"3", o => o is Int16Type} + }, + db.NumericEntities.Where(o => o.Short == value), + db.NumericEntities.Where(o => value == o.Short), + db.NumericEntities.Where(o => o.Short < value), + db.NumericEntities.Where(o => value < o.Short), + + db.NumericEntities.Where(o => o.NullableShort == value), + db.NumericEntities.Where(o => value == o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort < value), + db.NumericEntities.Where(o => value < o.NullableShort), + + db.NumericEntities.Where(o => o.NullableShort.Value == value), + db.NumericEntities.Where(o => value == o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value < value), + db.NumericEntities.Where(o => value < o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value > value) + ); + } + [Test] public void ContainsStringEnumTest() { diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index f6a9e5de43f..a70109bd59c 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -141,7 +141,7 @@ protected override Expression VisitConstant(ConstantExpression expression) // We have a bit more information about the null parameter value. // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) - // In v5.3 types are calculated by ConstantTypeLocator, this logic is only for back compatibility. + // In v5.3 types are calculated by ParameterTypeLocator, this logic is only for back compatibility. // TODO 6.0: Remove if (expression.Value == null) type = NHibernateUtil.GuessType(expression.Type); diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index b58ca33a305..a36c7292182 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -310,16 +310,15 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { - // In .NET some numeric types do not have thier own operators (e.g. short == short is converted to (int) short == (int) short), - // in such case we dont want to add a sql cast - if (expression.Left.NodeType == ExpressionType.Convert && - expression.Right.NodeType == ExpressionType.Convert && - ((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable() == - ((UnaryExpression) expression.Right).Operand.Type.UnwrapIfNullable()) + // There are some cases where we do not want to add a sql cast: + // - When comparing numeric types that do not have thier own operator (e.g. short == short) + // - When comparing a member expression with a parameter of similar type (e.g. o.Short == intParameter) + var leftType = GetExpressionType(expression.Left); + var rightType = GetExpressionType(expression.Right); + if (leftType != null && leftType == rightType) { - var type = ((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable(); - _notCastableExpressions.Add(expression.Left, type); - _notCastableExpressions.Add(expression.Right, type); + _notCastableExpressions.Add(expression.Left, leftType); + _notCastableExpressions.Add(expression.Right, rightType); } if (expression.NodeType == ExpressionType.Equal) @@ -388,6 +387,23 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) throw new InvalidOperationException(); } + private System.Type GetExpressionType(Expression expression) + { + switch (expression.NodeType) + { + case ExpressionType.Constant: + return _parameters.ConstantToParameterMap.TryGetValue((ConstantExpression) expression, out var param) + ? param.Type?.ReturnedClass + : null; + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + var operand = ((UnaryExpression) expression).Operand; + return GetExpressionType(operand) ?? operand.Type.UnwrapIfNullable(); + } + + return null; + } + private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) { var lhs = VisitExpression(expression.Left).ToArithmeticExpression(); diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 457e04fcbdb..3c50e143010 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -13,6 +13,7 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; +using static NHibernate.Util.ExpressionsHelper; namespace NHibernate.Linq.Visitors { @@ -49,6 +50,26 @@ public static class ParameterTypeLocator ExpressionType.Conditional }; + + private static readonly HashSet IntegralNumericTypes = new HashSet + { + typeof(sbyte), + typeof(short), + typeof(int), + typeof(long), + typeof(byte), + typeof(ushort), + typeof(uint), + typeof(ulong) + }; + + private static readonly HashSet FloatingPointNumericTypes = new HashSet + { + typeof(decimal), + typeof(float), + typeof(double) + }; + /// /// Set query parameter types based on the given query model. /// @@ -80,55 +101,68 @@ internal static void SetParameterTypes( var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); queryModel.TransformExpressions(visitor.Visit); - foreach (var pair in visitor.ConstantExpressions) + var processedConstants = new HashSet(); + foreach (var pair in visitor.ParameterConstants) { - var type = pair.Value; - var constantExpression = pair.Key; - if (!parameters.TryGetValue(constantExpression, out var namedParameter)) + var namedParameter = pair.Key; + var constantExpressions = pair.Value; + // In case any of the constants has the type set, use it (e.g. MappedAs) + namedParameter.Type = constantExpressions.Select(o => visitor.ConstantExpressions[o]).FirstOrDefault(o => o != null); + if (namedParameter.Type != null) { continue; } - if (type != null) + var parameterRelatedExpressions = new List(); + foreach (var expression in constantExpressions) { - // MappedAs was used - namedParameter.Type = type; - continue; + if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) + { + parameterRelatedExpressions.AddRange(relatedExpressions); + } } + var candidateTypes = new HashSet(); // In order to get the actual type we have to check first the related member expressions, as // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. // By getting the type from a related member expression we also get the correct length in case of StringType // or precision when having a DecimalType. - if (visitor.RelatedExpressions.TryGetValue(constantExpression, out var memberExpressions)) + foreach (var relatedExpression in parameterRelatedExpressions) { - foreach (var memberExpression in memberExpressions) + if (TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) { - if (ExpressionsHelper.TryGetMappedType( - sessionFactory, - memberExpression, - out type, - out _, - out _, - out _)) + if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) { - if (type.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(memberExpression)) - { - var collection = (IQueryableCollection) ((IAssociationType) type).GetAssociatedJoinable(sessionFactory); - type = collection.ElementType; - } - - break; + var collection = (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory); + candidateType = collection.ElementType; } + + candidateTypes.Add(candidateType); } } + constantExpressions.Select(o => o.Type.UnwrapIfNullable()).Distinct().Single(); + var constantExpression = constantExpressions.First(); // TODO: check when types are different + var constantType = constantExpression.Type.UnwrapIfNullable(); + IType type = null; + if ( + candidateTypes.Count == 1 && + // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type + // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). + !(candidateTypes.Any(t => IntegralNumericTypes.Contains(t.ReturnedClass)) && FloatingPointNumericTypes.Contains(constantType)) + ) + { + type = candidateTypes.FirstOrDefault(); + } + // No related MemberExpressions was found, guess the type by value or its type when null. + // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) + // do not change the parameter type, but instead cast the parameter when comparing with different column types. if (type == null) { type = constantExpression.Value != null ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) - : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection); + : ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection); } namedParameter.Type = type; @@ -143,6 +177,8 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor private readonly ISessionFactoryImplementor _sessionFactory; public readonly Dictionary ConstantExpressions = new Dictionary(); + public readonly Dictionary> ParameterConstants = + new Dictionary>(); public readonly Dictionary> RelatedExpressions = new Dictionary>(); public readonly HashSet SequenceSelectorExpressions = new HashSet(); @@ -167,8 +203,8 @@ protected override Expression VisitBinary(BinaryExpression node) return node; } - var left = Unwrap(node.Left); - var right = Unwrap(node.Right); + var left = UnwrapUnary(node.Left); + var right = UnwrapUnary(node.Right); if (node.NodeType == ExpressionType.Assign) { VisitAssign(left, right); @@ -185,8 +221,8 @@ protected override Expression VisitBinary(BinaryExpression node) protected override Expression VisitConditional(ConditionalExpression node) { node = (ConditionalExpression) base.VisitConditional(node); - var ifTrue = Unwrap(node.IfTrue); - var ifFalse = Unwrap(node.IfFalse); + var ifTrue = UnwrapUnary(node.IfTrue); + var ifFalse = UnwrapUnary(node.IfFalse); AddRelatedExpression(node, ifTrue, ifFalse); AddRelatedExpression(node, ifFalse, ifTrue); @@ -232,13 +268,21 @@ protected override Expression VisitMethodCall(MethodCallExpression node) protected override Expression VisitConstant(ConstantExpression node) { - if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.ContainsKey(node)) + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.TryGetValue(node, out var param)) { return node; } RelatedExpressions.Add(node, new HashSet()); ConstantExpressions.Add(node, null); + if (!ParameterConstants.TryGetValue(param, out var set)) + { + set = new HashSet(); + ParameterConstants.Add(param, set); + } + + set.Add(node); + return node; } @@ -254,7 +298,7 @@ querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && mainFromClause.FromExpression is ConstantExpression constantExpression) { VisitConstant(constantExpression); - AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item))); + AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item))); // Copy all found MemberExpressions to the constant expression // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set)) @@ -357,16 +401,6 @@ private bool IsDynamicMember(Expression expression) return false; } } - - private static Expression Unwrap(Expression expression) - { - if (expression is UnaryExpression unaryExpression) - { - return unaryExpression.Operand; - } - - return expression; - } } } } diff --git a/src/NHibernate/Type/SingleType.cs b/src/NHibernate/Type/SingleType.cs index bf719e3832b..70ca434e04d 100644 --- a/src/NHibernate/Type/SingleType.cs +++ b/src/NHibernate/Type/SingleType.cs @@ -62,7 +62,7 @@ public override System.Type ReturnedClass public override void Set(DbCommand rs, object value, int index, ISessionImplementor session) { - rs.Parameters[index].Value = value; + rs.Parameters[index].Value = Convert.ToSingle(value); } // Since 5.2 diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 6f9f290aa72..1940fde78bf 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -59,6 +59,21 @@ constant.Value is CallSite site && } #endif + /// + /// Unwraps . + /// + /// The expression to unwrap. + /// The unwrapped expression. + internal static Expression UnwrapUnary(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + return UnwrapUnary(unaryExpression.Operand); + } + + return expression; + } + /// /// Check whether the given expression represent a variable. /// From 492d1bd3ca437ba6dd8dcc1d7ff7fd970ecffb69 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 22 Aug 2020 01:57:24 +0200 Subject: [PATCH 04/21] Fix build --- src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 3c50e143010..9a985e73047 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -255,8 +255,8 @@ protected override Expression VisitMethodCall(MethodCallExpression node) if (EqualsGenerator.Methods.Contains(node.Method) || CompareGenerator.IsCompareMethod(node.Method)) { node = (MethodCallExpression) base.VisitMethodCall(node); - var left = Unwrap(node.Method.IsStatic ? node.Arguments[0] : node.Object); - var right = Unwrap(node.Method.IsStatic ? node.Arguments[1] : node.Arguments[0]); + var left = UnwrapUnary(node.Method.IsStatic ? node.Arguments[0] : node.Object); + var right = UnwrapUnary(node.Method.IsStatic ? node.Arguments[1] : node.Arguments[0]); AddRelatedExpression(node, left, right); AddRelatedExpression(node, right, left); From 39cf3f553b848bf594e3d45ed8724487d13122a2 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 22 Aug 2020 03:42:24 +0200 Subject: [PATCH 05/21] Try fix firebird parameter regex --- .../Async/Linq/ParameterTests.cs | 70 ++++++++++++++++--- src/NHibernate.Test/Linq/ParameterTests.cs | 70 ++++++++++++++++--- src/NHibernate/Driver/FirebirdClientDriver.cs | 4 +- 3 files changed, 124 insertions(+), 20 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 008daa1f49e..5f3a763749d 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -162,7 +162,15 @@ public async Task CompareIntegralParametersAndColumnsAsync() 3, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); })); } @@ -203,7 +211,15 @@ public async Task CompareIntegralParametersWithFloatingPointColumnsAsync() 3, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); })); } @@ -245,7 +261,15 @@ public async Task CompareFloatingPointParametersAndColumnsAsync() totalParameters, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); })); } @@ -368,8 +392,17 @@ public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumns var matches = pair.Value == "Double" ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - // SQLiteDialect uses sql cast for transparentcast method - Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + if (Sfi.Dialect is FirebirdDialect) + { + // Additional casts are added by FirebirdClientDriver + Assert.That(matches.Count, Is.EqualTo(pair.Value == "Double" ? 1 : 2)); + } + else + { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); })); } @@ -408,7 +441,7 @@ public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsy sql => { var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(matches.Count, Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 2 : 1)); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); })); } @@ -447,7 +480,17 @@ public async Task UsingValueTypeParameterOfDifferentTypeAsync() await (AssertTotalParametersAsync( query, 1, - sql => Assert.That(sql, Does.Not.Contain("cast")))); + sql => + { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + })); } queriables = new List> @@ -476,7 +519,16 @@ public async Task UsingValueTypeParameterOfDifferentTypeAsync() sql => { // SQLiteDialect uses sql cast for transparentcast method Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + if (Sfi.Dialect is FirebirdDialect) + { + // Additional casts are added by FirebirdClientDriver + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(3)); + } + else + { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + } })); } } @@ -489,7 +541,7 @@ public async Task UsingValueTypeParameterTwiceOnNullablePropertyAsync() db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), 1, sql => { - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 3 : 0)); })); } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index c0f80819fec..b993b6bf73e 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -150,7 +150,15 @@ public void CompareIntegralParametersAndColumns() 3, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); }); } @@ -191,7 +199,15 @@ public void CompareIntegralParametersWithFloatingPointColumns() 3, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); }); } @@ -233,7 +249,15 @@ public void CompareFloatingPointParametersAndColumns() totalParameters, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); }); } @@ -356,8 +380,17 @@ public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() var matches = pair.Value == "Double" ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - // SQLiteDialect uses sql cast for transparentcast method - Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + if (Sfi.Dialect is FirebirdDialect) + { + // Additional casts are added by FirebirdClientDriver + Assert.That(matches.Count, Is.EqualTo(pair.Value == "Double" ? 1 : 2)); + } + else + { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + } + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); }); } @@ -396,7 +429,7 @@ public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() sql => { var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(matches.Count, Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 2 : 1)); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); }); } @@ -435,7 +468,17 @@ public void UsingValueTypeParameterOfDifferentType() AssertTotalParameters( query, 1, - sql => Assert.That(sql, Does.Not.Contain("cast"))); + sql => + { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.That(sql, Does.Contain("cast")); + } + else + { + Assert.That(sql, Does.Not.Contain("cast")); + } + }); } queriables = new List> @@ -464,7 +507,16 @@ public void UsingValueTypeParameterOfDifferentType() sql => { // SQLiteDialect uses sql cast for transparentcast method Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + if (Sfi.Dialect is FirebirdDialect) + { + // Additional casts are added by FirebirdClientDriver + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(3)); + } + else + { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + } }); } } @@ -477,7 +529,7 @@ public void UsingValueTypeParameterTwiceOnNullableProperty() db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), 1, sql => { - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 3 : 0)); }); } diff --git a/src/NHibernate/Driver/FirebirdClientDriver.cs b/src/NHibernate/Driver/FirebirdClientDriver.cs index 2cfa6bd448d..058c03bbd1a 100644 --- a/src/NHibernate/Driver/FirebirdClientDriver.cs +++ b/src/NHibernate/Driver/FirebirdClientDriver.cs @@ -24,7 +24,7 @@ public class FirebirdClientDriver : ReflectionBasedDriver // Zero-width negative look-behind: the match must not be preceded by @"(?]\s*" + + @"[=<>]\s" + // or a paging instruction, @"|\bfirst\s+|\bskip\s+" + // or a "between" condition, @@ -36,7 +36,7 @@ public class FirebirdClientDriver : ReflectionBasedDriver // Zero-width negative look-ahead: the match must not be followed by @"(?!" + // a comparison. - @"\s*[=<>])"; + @"\s[=<>])"; private static readonly Regex _statementRegEx = new Regex(SELECT_CLAUSE_EXP, RegexOptions.IgnoreCase); private static readonly Regex _castCandidateRegEx = new Regex(CAST_PARAMS_EXP, RegexOptions.IgnoreCase); private readonly FirebirdDialect _fbDialect = new FirebirdDialect(); From 1aeaf37880bffe5f99c3c2279d741d7c0a47bb70 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 22 Aug 2020 04:08:35 +0200 Subject: [PATCH 06/21] Revert "Try fix firebird parameter regex" This reverts commit 39cf3f553b848bf594e3d45ed8724487d13122a2. --- .../Async/Linq/ParameterTests.cs | 70 +++---------------- src/NHibernate.Test/Linq/ParameterTests.cs | 70 +++---------------- src/NHibernate/Driver/FirebirdClientDriver.cs | 4 +- 3 files changed, 20 insertions(+), 124 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 5f3a763749d..008daa1f49e 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -162,15 +162,7 @@ public async Task CompareIntegralParametersAndColumnsAsync() 3, sql => { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - + Assert.That(sql, Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); })); } @@ -211,15 +203,7 @@ public async Task CompareIntegralParametersWithFloatingPointColumnsAsync() 3, sql => { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - + Assert.That(sql, Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); })); } @@ -261,15 +245,7 @@ public async Task CompareFloatingPointParametersAndColumnsAsync() totalParameters, sql => { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - + Assert.That(sql, Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); })); } @@ -392,17 +368,8 @@ public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumns var matches = pair.Value == "Double" ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - if (Sfi.Dialect is FirebirdDialect) - { - // Additional casts are added by FirebirdClientDriver - Assert.That(matches.Count, Is.EqualTo(pair.Value == "Double" ? 1 : 2)); - } - else - { - // SQLiteDialect uses sql cast for transparentcast method - Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); - } - + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); })); } @@ -441,7 +408,7 @@ public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsy sql => { var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - Assert.That(matches.Count, Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 2 : 1)); + Assert.That(matches.Count, Is.EqualTo(1)); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); })); } @@ -480,17 +447,7 @@ public async Task UsingValueTypeParameterOfDifferentTypeAsync() await (AssertTotalParametersAsync( query, 1, - sql => - { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - })); + sql => Assert.That(sql, Does.Not.Contain("cast")))); } queriables = new List> @@ -519,16 +476,7 @@ public async Task UsingValueTypeParameterOfDifferentTypeAsync() sql => { // SQLiteDialect uses sql cast for transparentcast method Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); - if (Sfi.Dialect is FirebirdDialect) - { - // Additional casts are added by FirebirdClientDriver - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(3)); - } - else - { - // SQLiteDialect uses sql cast for transparentcast method - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); - } + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); })); } } @@ -541,7 +489,7 @@ public async Task UsingValueTypeParameterTwiceOnNullablePropertyAsync() db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), 1, sql => { - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 3 : 0)); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); })); } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index b993b6bf73e..c0f80819fec 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -150,15 +150,7 @@ public void CompareIntegralParametersAndColumns() 3, sql => { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - + Assert.That(sql, Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); }); } @@ -199,15 +191,7 @@ public void CompareIntegralParametersWithFloatingPointColumns() 3, sql => { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - + Assert.That(sql, Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); }); } @@ -249,15 +233,7 @@ public void CompareFloatingPointParametersAndColumns() totalParameters, sql => { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - + Assert.That(sql, Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); }); } @@ -380,17 +356,8 @@ public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() var matches = pair.Value == "Double" ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - if (Sfi.Dialect is FirebirdDialect) - { - // Additional casts are added by FirebirdClientDriver - Assert.That(matches.Count, Is.EqualTo(pair.Value == "Double" ? 1 : 2)); - } - else - { - // SQLiteDialect uses sql cast for transparentcast method - Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); - } - + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); }); } @@ -429,7 +396,7 @@ public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() sql => { var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); - Assert.That(matches.Count, Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 2 : 1)); + Assert.That(matches.Count, Is.EqualTo(1)); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); }); } @@ -468,17 +435,7 @@ public void UsingValueTypeParameterOfDifferentType() AssertTotalParameters( query, 1, - sql => - { - if (Sfi.Dialect is FirebirdDialect) - { - Assert.That(sql, Does.Contain("cast")); - } - else - { - Assert.That(sql, Does.Not.Contain("cast")); - } - }); + sql => Assert.That(sql, Does.Not.Contain("cast"))); } queriables = new List> @@ -507,16 +464,7 @@ public void UsingValueTypeParameterOfDifferentType() sql => { // SQLiteDialect uses sql cast for transparentcast method Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); - if (Sfi.Dialect is FirebirdDialect) - { - // Additional casts are added by FirebirdClientDriver - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(3)); - } - else - { - // SQLiteDialect uses sql cast for transparentcast method - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); - } + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); }); } } @@ -529,7 +477,7 @@ public void UsingValueTypeParameterTwiceOnNullableProperty() db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), 1, sql => { - Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(Sfi.Dialect is FirebirdDialect ? 3 : 0)); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); }); } diff --git a/src/NHibernate/Driver/FirebirdClientDriver.cs b/src/NHibernate/Driver/FirebirdClientDriver.cs index 058c03bbd1a..2cfa6bd448d 100644 --- a/src/NHibernate/Driver/FirebirdClientDriver.cs +++ b/src/NHibernate/Driver/FirebirdClientDriver.cs @@ -24,7 +24,7 @@ public class FirebirdClientDriver : ReflectionBasedDriver // Zero-width negative look-behind: the match must not be preceded by @"(?]\s" + + @"[=<>]\s*" + // or a paging instruction, @"|\bfirst\s+|\bskip\s+" + // or a "between" condition, @@ -36,7 +36,7 @@ public class FirebirdClientDriver : ReflectionBasedDriver // Zero-width negative look-ahead: the match must not be followed by @"(?!" + // a comparison. - @"\s[=<>])"; + @"\s*[=<>])"; private static readonly Regex _statementRegEx = new Regex(SELECT_CLAUSE_EXP, RegexOptions.IgnoreCase); private static readonly Regex _castCandidateRegEx = new Regex(CAST_PARAMS_EXP, RegexOptions.IgnoreCase); private readonly FirebirdDialect _fbDialect = new FirebirdDialect(); From c50939fb11d7a39c05d95c4c4ec86120dbabbb5d Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 22 Aug 2020 04:22:11 +0200 Subject: [PATCH 07/21] Skip failing tests for Firebird --- .../Async/Linq/NullComparisonTests.cs | 20 +++++++++---------- .../Async/Linq/ParameterTests.cs | 15 ++++++++++++++ .../Linq/NullComparisonTests.cs | 20 +++++++++---------- src/NHibernate.Test/Linq/ParameterTests.cs | 15 ++++++++++++++ 4 files changed, 50 insertions(+), 20 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 728d76649bf..827a674f594 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -474,6 +474,11 @@ public async Task NullEqualityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); @@ -584,6 +589,11 @@ public async Task NullInequalityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase)); + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.Short), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); @@ -612,21 +622,11 @@ private IResolveConstraint WithIsNullAndWithoutCast() return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; } - private IResolveConstraint WithIsNullAndWithCast() - { - return Does.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; - } - private IResolveConstraint WithoutIsNullAndWithoutCast() { return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; } - private IResolveConstraint WithoutIsNullAndWithCast() - { - return Does.Not.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; - } - [Test] public async Task NullEqualityInvertedAsync() { diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 008daa1f49e..769c3f28943 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -337,6 +337,11 @@ public async Task CompareFloatingPointParameterWithIntegralAndFloatingPointColum [Test] public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumnsAsync() { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + var singleParam = 2.2f; var doubleParam = 3.3d; var queriables = new Dictionary, string> @@ -378,6 +383,11 @@ public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumns [Test] public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsync() { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + short shortParam = 1; var intParam = 2; var longParam = 3L; @@ -450,6 +460,11 @@ public async Task UsingValueTypeParameterOfDifferentTypeAsync() sql => Assert.That(sql, Does.Not.Contain("cast")))); } + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex bug in FirebirdClientDriver, the parameters are not casted."); + } + queriables = new List> { db.NumericEntities.Where(o => o.Short + value > value), diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index 4b54bab5557..dcb0f9b5899 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -462,6 +462,11 @@ public void NullEquality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase); + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + Expect(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(db.NumericEntities.Where(o => o.Short == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(db.NumericEntities.Where(o => o.NullableShort == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); @@ -572,6 +577,11 @@ public void NullInequality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase); + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + Expect(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(db.NumericEntities.Where(o => o.Short != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); Expect(db.NumericEntities.Where(o => o.NullableShort != o.Short), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); @@ -600,21 +610,11 @@ private IResolveConstraint WithIsNullAndWithoutCast() return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; } - private IResolveConstraint WithIsNullAndWithCast() - { - return Does.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; - } - private IResolveConstraint WithoutIsNullAndWithoutCast() { return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; } - private IResolveConstraint WithoutIsNullAndWithCast() - { - return Does.Not.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase; - } - [Test] public void NullEqualityInverted() { diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index c0f80819fec..b7c361f088a 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -325,6 +325,11 @@ public void CompareFloatingPointParameterWithIntegralAndFloatingPointColumns() [Test] public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + var singleParam = 2.2f; var doubleParam = 3.3d; var queriables = new Dictionary, string> @@ -366,6 +371,11 @@ public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() [Test] public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + short shortParam = 1; var intParam = 2; var longParam = 3L; @@ -438,6 +448,11 @@ public void UsingValueTypeParameterOfDifferentType() sql => Assert.That(sql, Does.Not.Contain("cast"))); } + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex bug in FirebirdClientDriver, the parameters are not casted."); + } + queriables = new List> { db.NumericEntities.Where(o => o.Short + value > value), From 010cf850b3ac5d59e2c7b6373b7fe9b781a39a04 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 22 Aug 2020 13:53:52 +0200 Subject: [PATCH 08/21] Fix odbc tests --- src/NHibernate.Test/Async/Linq/ParameterTests.cs | 10 +++++++--- src/NHibernate.Test/Linq/ParameterTests.cs | 10 +++++++--- src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 4 ++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 769c3f28943..4bb8ca0f7f5 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -16,6 +16,7 @@ using System.Text.RegularExpressions; using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; @@ -318,6 +319,7 @@ public async Task CompareFloatingPointParameterWithIntegralAndFloatingPointColum {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Integer != doubleParam), "Double"}, {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableInteger == doubleParam), "Double"} }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; foreach (var pair in queriables) { @@ -329,7 +331,7 @@ public async Task CompareFloatingPointParameterWithIntegralAndFloatingPointColum { var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); Assert.That(matches.Count, Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); })); } } @@ -361,6 +363,7 @@ public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumns var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Single.SqlType, out var singleCast) && Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Double.SqlType, out var doubleCast) && singleCast == doubleCast; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; foreach (var pair in queriables) { @@ -375,7 +378,7 @@ public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumns : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); // SQLiteDialect uses sql cast for transparentcast method Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); - Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); })); } } @@ -408,6 +411,7 @@ public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsy {db.NumericEntities.Where(o => o.NullableLong == longParam || o.Decimal != longParam), "Int64"}, {db.NumericEntities.Where(o => o.NullableLong == longParam || o.NullableSingle > longParam), "Int64"} }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; foreach (var pair in queriables) { @@ -419,7 +423,7 @@ public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsy { var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); Assert.That(matches.Count, Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); })); } } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index b7c361f088a..73028ba8598 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -6,6 +6,7 @@ using System.Text.RegularExpressions; using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; @@ -306,6 +307,7 @@ public void CompareFloatingPointParameterWithIntegralAndFloatingPointColumns() {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Integer != doubleParam), "Double"}, {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableInteger == doubleParam), "Double"} }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; foreach (var pair in queriables) { @@ -317,7 +319,7 @@ public void CompareFloatingPointParameterWithIntegralAndFloatingPointColumns() { var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); Assert.That(matches.Count, Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); }); } } @@ -349,6 +351,7 @@ public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Single.SqlType, out var singleCast) && Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Double.SqlType, out var doubleCast) && singleCast == doubleCast; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; foreach (var pair in queriables) { @@ -363,7 +366,7 @@ public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); // SQLiteDialect uses sql cast for transparentcast method Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); - Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); }); } } @@ -396,6 +399,7 @@ public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() {db.NumericEntities.Where(o => o.NullableLong == longParam || o.Decimal != longParam), "Int64"}, {db.NumericEntities.Where(o => o.NullableLong == longParam || o.NullableSingle > longParam), "Int64"} }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; foreach (var pair in queriables) { @@ -407,7 +411,7 @@ public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() { var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); Assert.That(matches.Count, Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); }); } } diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 9a985e73047..81675b1ccbf 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -141,8 +141,8 @@ internal static void SetParameterTypes( } } - constantExpressions.Select(o => o.Type.UnwrapIfNullable()).Distinct().Single(); - var constantExpression = constantExpressions.First(); // TODO: check when types are different + // All constant expressions have the same type/value + var constantExpression = constantExpressions.First(); var constantType = constantExpression.Type.UnwrapIfNullable(); IType type = null; if ( From c9183272072e9b6618482a83b5c799fa79febdc1 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 25 Aug 2020 19:29:44 +0200 Subject: [PATCH 09/21] Code review changes --- .../Visitors/HqlGeneratorExpressionVisitor.cs | 11 ++++--- .../NhPartialEvaluatingExpressionVisitor.cs | 32 ++++++++++++------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index a36c7292182..9a8d897ea3d 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -526,11 +526,14 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - var notCastable = _notCastableExpressions.TryGetValue(expression, out var castType); - castType = castType ?? expression.Type; + var castable = !_notCastableExpressions.TryGetValue(expression, out var castType); + if (castable) + { + castType = expression.Type; + } - return IsCastRequired(expression.Operand, castType, out var existType) && !notCastable - ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + return IsCastRequired(expression.Operand, castType, out var existType) && castable + ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), castType) // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader : existType && HqlIdent.SupportsType(castType) ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), castType) diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index b198da9002f..475dbcdad31 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -80,16 +80,7 @@ public override Expression Visit(Expression expression) if (expression.NodeType == ExpressionType.Lambda || !_partialEvaluationInfo.IsEvaluatableExpression(expression) || #region NH additions // Variables should be evaluated only when they are part of an evaluatable expression (e.g. o => string.Format("...", variable)) - expression is UnaryExpression unaryExpression && - ( - ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _) || - // Check whether the variable is casted due to comparison with a nullable expression - // (e.g. o.NullableShort == shortVariable) - unaryExpression.Operand is UnaryExpression subUnaryExpression && - unaryExpression.Type.UnwrapIfNullable() == subUnaryExpression.Type && - ExpressionsHelper.IsVariable(subUnaryExpression.Operand, out _, out _) - ) - ) + ContainsVariable(expression)) #endregion return base.Visit(expression); @@ -162,8 +153,27 @@ private Expression EvaluateSubtree(Expression subtree) } } + #region NH additions + + private bool ContainsVariable(Expression expression) + { + if (!(expression is UnaryExpression unaryExpression)) + { + return false; + } + + return ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _) || + // Check whether the variable is casted due to comparison with a nullable expression + // (e.g. o.NullableShort == shortVariable) + unaryExpression.Operand is UnaryExpression subUnaryExpression && + unaryExpression.Type.UnwrapIfNullable() == subUnaryExpression.Type && + ExpressionsHelper.IsVariable(subUnaryExpression.Operand, out _, out _); + } + #endregion - + + #endregion + protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value is Expression value) From a4d4df01aa817fbe0bf31f6f5a4755baea07e2ec Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 25 Aug 2020 20:16:51 +0200 Subject: [PATCH 10/21] Code review changes --- .../Async/Linq/NullComparisonTests.cs | 40 +++++++++---------- .../Linq/NullComparisonTests.cs | 40 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 827a674f594..185dff330f2 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -479,19 +479,19 @@ public async Task NullEqualityAsync() return; } - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.Short), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.Short), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.NullableShort), WithoutIsNullAndWithoutCast())); short value = 3; - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort), WithoutIsNullAndWithoutCast())); - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value == value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort.Value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.Short), WithoutIsNullAndWithoutCast())); await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == 3L), WithoutIsNullAndWithoutCast())); await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.NullableShort), WithoutIsNullAndWithoutCast())); @@ -594,19 +594,19 @@ public async Task NullInequalityAsync() return; } - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.Short), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.Short), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.Short), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.NullableShort), WithIsNullAndWithoutCast())); short value = 3; - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort), WithIsNullAndWithoutCast())); - await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); - await (ExpectAsync(db.NumericEntities.Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"))); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value != value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort.Value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.Short), WithoutIsNullAndWithoutCast())); await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != 3L), WithIsNullAndWithoutCast())); await (ExpectAsync(db.NumericEntities.Where(o => 3 != o.NullableShort), WithIsNullAndWithoutCast())); diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index dcb0f9b5899..7eec9f09e22 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -467,19 +467,19 @@ public void NullEquality() return; } - Expect(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.Short == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.NullableShort == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.Short == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == o.Short), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.NullableShort == o.Short), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == o.NullableShort), WithoutIsNullAndWithoutCast()); short value = 3; - Expect(db.NumericEntities.Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort == value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value == o.NullableShort), WithoutIsNullAndWithoutCast()); - Expect(db.NumericEntities.Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort.Value == value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value == o.NullableShort.Value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value == o.Short), WithoutIsNullAndWithoutCast()); Expect(db.NumericEntities.Where(o => o.NullableShort == 3L), WithoutIsNullAndWithoutCast()); Expect(db.NumericEntities.Where(o => 3L == o.NullableShort), WithoutIsNullAndWithoutCast()); @@ -582,19 +582,19 @@ public void NullInequality() return; } - Expect(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.Short != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.NullableShort != o.Short), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.Short != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != o.Short), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.NullableShort != o.Short), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != o.NullableShort), WithIsNullAndWithoutCast()); short value = 3; - Expect(db.NumericEntities.Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort != value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value != o.NullableShort), WithIsNullAndWithoutCast()); - Expect(db.NumericEntities.Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); - Expect(db.NumericEntities.Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")); + Expect(db.NumericEntities.Where(o => o.NullableShort.Value != value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value != o.NullableShort.Value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value != o.Short), WithoutIsNullAndWithoutCast()); Expect(db.NumericEntities.Where(o => o.NullableShort != 3L), WithIsNullAndWithoutCast()); Expect(db.NumericEntities.Where(o => 3 != o.NullableShort), WithIsNullAndWithoutCast()); From 22c17ffc8847f667d62560f1418e4ed46307af39 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 25 Aug 2020 20:17:39 +0200 Subject: [PATCH 11/21] Update src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Frédéric Delaporte <12201973+fredericDelaporte@users.noreply.github.com> --- src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 9a8d897ea3d..6dc938bdae2 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -311,7 +311,7 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { // There are some cases where we do not want to add a sql cast: - // - When comparing numeric types that do not have thier own operator (e.g. short == short) + // - When comparing numeric types that do not have their own operator (e.g. short == short) // - When comparing a member expression with a parameter of similar type (e.g. o.Short == intParameter) var leftType = GetExpressionType(expression.Left); var rightType = GetExpressionType(expression.Right); From 0858dce87e908b356059b4779206ddbb3979b52c Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:08:38 +1200 Subject: [PATCH 12/21] Remove unused variable --- src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 81675b1ccbf..5f7e85d7aa5 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -101,7 +101,6 @@ internal static void SetParameterTypes( var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); queryModel.TransformExpressions(visitor.Visit); - var processedConstants = new HashSet(); foreach (var pair in visitor.ParameterConstants) { var namedParameter = pair.Key; From 6d76720a33bdfb3f463f5d2a648563fc2d3b1e73 Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:10:04 +1200 Subject: [PATCH 13/21] Extract GetParameterType and GetCandidateTypes methods --- .../Linq/Visitors/ParameterTypeLocator.cs | 107 +++++++++++------- 1 file changed, 64 insertions(+), 43 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 5f7e85d7aa5..5040b775b06 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -112,60 +112,81 @@ internal static void SetParameterTypes( continue; } - var parameterRelatedExpressions = new List(); - foreach (var expression in constantExpressions) + namedParameter.Type = GetParameterType(sessionFactory, constantExpressions, visitor, namedParameter); + } + } + + private static HashSet GetCandidateTypes( + ISessionFactoryImplementor sessionFactory, + IEnumerable constantExpressions, + ConstantTypeLocatorVisitor visitor) + { + var parameterRelatedExpressions = new List(); + foreach (var expression in constantExpressions) + { + if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) { - if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) - { - parameterRelatedExpressions.AddRange(relatedExpressions); - } + parameterRelatedExpressions.AddRange(relatedExpressions); } + } - var candidateTypes = new HashSet(); - // In order to get the actual type we have to check first the related member expressions, as - // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. - // By getting the type from a related member expression we also get the correct length in case of StringType - // or precision when having a DecimalType. - foreach (var relatedExpression in parameterRelatedExpressions) + var candidateTypes = new HashSet(); + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. + foreach (var relatedExpression in parameterRelatedExpressions) + { + if (TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) { - if (TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) + if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) { - if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) - { - var collection = (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory); - candidateType = collection.ElementType; - } - - candidateTypes.Add(candidateType); + var collection = + (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory); + candidateType = collection.ElementType; } - } - // All constant expressions have the same type/value - var constantExpression = constantExpressions.First(); - var constantType = constantExpression.Type.UnwrapIfNullable(); - IType type = null; - if ( - candidateTypes.Count == 1 && - // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type - // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). - !(candidateTypes.Any(t => IntegralNumericTypes.Contains(t.ReturnedClass)) && FloatingPointNumericTypes.Contains(constantType)) - ) - { - type = candidateTypes.FirstOrDefault(); + candidateTypes.Add(candidateType); } + } - // No related MemberExpressions was found, guess the type by value or its type when null. - // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) - // do not change the parameter type, but instead cast the parameter when comparing with different column types. - if (type == null) - { - type = constantExpression.Value != null - ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) - : ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection); - } + return candidateTypes; + } - namedParameter.Type = type; + private static IType GetParameterType( + ISessionFactoryImplementor sessionFactory, + HashSet constantExpressions, + ConstantTypeLocatorVisitor visitor, + NamedParameter namedParameter) + { + var candidateTypes = GetCandidateTypes(sessionFactory, constantExpressions, visitor); + + // All constant expressions have the same type/value + var constantExpression = constantExpressions.First(); + var constantType = constantExpression.Type.UnwrapIfNullable(); + IType type = null; + if ( + candidateTypes.Count == 1 && + // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type + // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). + !(candidateTypes.Any(t => IntegralNumericTypes.Contains(t.ReturnedClass)) && + FloatingPointNumericTypes.Contains(constantType)) + ) + { + type = candidateTypes.FirstOrDefault(); + } + + // No related MemberExpressions was found, guess the type by value or its type when null. + // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) + // do not change the parameter type, but instead cast the parameter when comparing with different column types. + if (type == null) + { + type = constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection); } + + return type; } private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor From 3bdb59864dd3a3cd26a2fb560cf00448095392d4 Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:12:06 +1200 Subject: [PATCH 14/21] Reduce complexity --- .../Linq/Visitors/ParameterTypeLocator.cs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 5040b775b06..60fd9e788d7 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -164,7 +164,6 @@ private static IType GetParameterType( // All constant expressions have the same type/value var constantExpression = constantExpressions.First(); var constantType = constantExpression.Type.UnwrapIfNullable(); - IType type = null; if ( candidateTypes.Count == 1 && // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type @@ -173,20 +172,15 @@ private static IType GetParameterType( FloatingPointNumericTypes.Contains(constantType)) ) { - type = candidateTypes.FirstOrDefault(); + return candidateTypes.First(); } // No related MemberExpressions was found, guess the type by value or its type when null. // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) // do not change the parameter type, but instead cast the parameter when comparing with different column types. - if (type == null) - { - type = constantExpression.Value != null - ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) - : ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection); - } - - return type; + return constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection); } private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor From b4888b77eddac87214a6d435c11e3d4b87ef4332 Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:20:58 +1200 Subject: [PATCH 15/21] Further reduce complexity --- .../Linq/Visitors/ParameterTypeLocator.cs | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 60fd9e788d7..859c9a2cd05 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -153,26 +153,44 @@ private static HashSet GetCandidateTypes( return candidateTypes; } + private static bool GetCandidateType( + ISessionFactoryImplementor sessionFactory, + IEnumerable constantExpressions, + ConstantTypeLocatorVisitor visitor, + System.Type constantType, + out IType candidateType) + { + var candidateTypes = GetCandidateTypes(sessionFactory, constantExpressions, visitor); + if (candidateTypes.Count == 1) + { + candidateType = candidateTypes.First(); + + // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type + // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). + if (!IntegralNumericTypes.Contains(candidateType.ReturnedClass) || + !FloatingPointNumericTypes.Contains(constantType)) + { + return true; + } + } + + candidateType = null; + return false; + } + private static IType GetParameterType( ISessionFactoryImplementor sessionFactory, HashSet constantExpressions, ConstantTypeLocatorVisitor visitor, NamedParameter namedParameter) { - var candidateTypes = GetCandidateTypes(sessionFactory, constantExpressions, visitor); // All constant expressions have the same type/value var constantExpression = constantExpressions.First(); var constantType = constantExpression.Type.UnwrapIfNullable(); - if ( - candidateTypes.Count == 1 && - // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type - // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). - !(candidateTypes.Any(t => IntegralNumericTypes.Contains(t.ReturnedClass)) && - FloatingPointNumericTypes.Contains(constantType)) - ) + if (GetCandidateType(sessionFactory, constantExpressions, visitor, constantType, out var candidateType)) { - return candidateTypes.First(); + return candidateType; } // No related MemberExpressions was found, guess the type by value or its type when null. From e31a9a561ea3746f158d36d943ccefd02e7fb124 Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:29:29 +1200 Subject: [PATCH 16/21] Move UnwrapUnary to ParameterTypeLocator as it is too specific --- .../Linq/Visitors/ParameterTypeLocator.cs | 18 ++++++++++++++++-- src/NHibernate/Util/ExpressionsHelper.cs | 15 --------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 859c9a2cd05..3e42690cb5c 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -13,7 +13,6 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; -using static NHibernate.Util.ExpressionsHelper; namespace NHibernate.Linq.Visitors { @@ -137,7 +136,7 @@ private static HashSet GetCandidateTypes( // or precision when having a DecimalType. foreach (var relatedExpression in parameterRelatedExpressions) { - if (TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) + if (ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) { if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) { @@ -434,5 +433,20 @@ private bool IsDynamicMember(Expression expression) } } } + + /// + /// Unwraps . + /// + /// The expression to unwrap. + /// The unwrapped expression. + private static Expression UnwrapUnary(Expression expression) + { + while (expression is UnaryExpression unaryExpression) + { + expression = unaryExpression.Operand; + } + + return expression; + } } } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 1940fde78bf..6f9f290aa72 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -59,21 +59,6 @@ constant.Value is CallSite site && } #endif - /// - /// Unwraps . - /// - /// The expression to unwrap. - /// The unwrapped expression. - internal static Expression UnwrapUnary(Expression expression) - { - if (expression is UnaryExpression unaryExpression) - { - return UnwrapUnary(unaryExpression.Operand); - } - - return expression; - } - /// /// Check whether the given expression represent a variable. /// From d10a057be159645fae8a287ec44d7e253cbb0d8d Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:31:23 +1200 Subject: [PATCH 17/21] Remove list creation --- .../Linq/Visitors/ParameterTypeLocator.cs | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 3e42690cb5c..76fc8dd5494 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -120,32 +120,28 @@ private static HashSet GetCandidateTypes( IEnumerable constantExpressions, ConstantTypeLocatorVisitor visitor) { - var parameterRelatedExpressions = new List(); + var candidateTypes = new HashSet(); foreach (var expression in constantExpressions) { if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) { - parameterRelatedExpressions.AddRange(relatedExpressions); - } - } - - var candidateTypes = new HashSet(); - // In order to get the actual type we have to check first the related member expressions, as - // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. - // By getting the type from a related member expression we also get the correct length in case of StringType - // or precision when having a DecimalType. - foreach (var relatedExpression in parameterRelatedExpressions) - { - if (ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) - { - if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. + foreach (var relatedExpression in relatedExpressions) { - var collection = - (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory); - candidateType = collection.ElementType; - } + if (ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) + { + if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) + { + var collection = (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory); + candidateType = collection.ElementType; + } - candidateTypes.Add(candidateType); + candidateTypes.Add(candidateType); + } + } } } From 371b8dfa6b08ba55cd4494b807ab05d8e1b9139f Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 27 Aug 2020 21:34:33 +1200 Subject: [PATCH 18/21] It seems comment belongs here --- src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 76fc8dd5494..88119c2dce5 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -123,12 +123,12 @@ private static HashSet GetCandidateTypes( var candidateTypes = new HashSet(); foreach (var expression in constantExpressions) { + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) { - // In order to get the actual type we have to check first the related member expressions, as - // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. - // By getting the type from a related member expression we also get the correct length in case of StringType - // or precision when having a DecimalType. foreach (var relatedExpression in relatedExpressions) { if (ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) From 9d0bbb55763ed3d0a306cfcaff545464fb3bdc0c Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Thu, 27 Aug 2020 12:58:24 +0300 Subject: [PATCH 19/21] Avoid HashSet creation for candidate type calculation --- .../Linq/Visitors/ParameterTypeLocator.cs | 61 +++++++------------ 1 file changed, 21 insertions(+), 40 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 88119c2dce5..0836fe5043b 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -115,62 +115,39 @@ internal static void SetParameterTypes( } } - private static HashSet GetCandidateTypes( + private static IType GetCandidateType( ISessionFactoryImplementor sessionFactory, IEnumerable constantExpressions, ConstantTypeLocatorVisitor visitor) { - var candidateTypes = new HashSet(); + IType candidateType = null; foreach (var expression in constantExpressions) { // In order to get the actual type we have to check first the related member expressions, as // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. // By getting the type from a related member expression we also get the correct length in case of StringType // or precision when having a DecimalType. - if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) + if (!visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) + continue; + foreach (var relatedExpression in relatedExpressions) { - foreach (var relatedExpression in relatedExpressions) - { - if (ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _)) - { - if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) - { - var collection = (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory); - candidateType = collection.ElementType; - } + if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _)) + continue; - candidateTypes.Add(candidateType); - } + if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) + { + var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory); + mappedType = collection.ElementType; } - } - } - - return candidateTypes; - } - - private static bool GetCandidateType( - ISessionFactoryImplementor sessionFactory, - IEnumerable constantExpressions, - ConstantTypeLocatorVisitor visitor, - System.Type constantType, - out IType candidateType) - { - var candidateTypes = GetCandidateTypes(sessionFactory, constantExpressions, visitor); - if (candidateTypes.Count == 1) - { - candidateType = candidateTypes.First(); - // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type - // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). - if (!IntegralNumericTypes.Contains(candidateType.ReturnedClass) || - !FloatingPointNumericTypes.Contains(constantType)) - { - return true; + if (candidateType == null) + candidateType = mappedType; + else if (!candidateType.Equals(mappedType)) + return null; } } - candidateType = null; - return false; + return candidateType; } private static IType GetParameterType( @@ -183,7 +160,11 @@ private static IType GetParameterType( // All constant expressions have the same type/value var constantExpression = constantExpressions.First(); var constantType = constantExpression.Type.UnwrapIfNullable(); - if (GetCandidateType(sessionFactory, constantExpressions, visitor, constantType, out var candidateType)) + var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor); + if (candidateType != null && + // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type + // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). + !(FloatingPointNumericTypes.Contains(constantType) && IntegralNumericTypes.Contains(candidateType.ReturnedClass))) { return candidateType; } From 3c10311e3ed035e80f3e10655a82591d375f3305 Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Fri, 28 Aug 2020 09:13:40 +1200 Subject: [PATCH 20/21] Move number type checks to TypeExtensions and extract additional method to make code more readable --- .../Linq/Visitors/ParameterTypeLocator.cs | 50 ++++++++++--------- src/NHibernate/Util/TypeExtensions.cs | 25 ++++++++++ 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 0836fe5043b..e3104d2e841 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -13,6 +13,7 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; +using TypeExtensions = NHibernate.Util.TypeExtensions; namespace NHibernate.Linq.Visitors { @@ -50,25 +51,6 @@ public static class ParameterTypeLocator }; - private static readonly HashSet IntegralNumericTypes = new HashSet - { - typeof(sbyte), - typeof(short), - typeof(int), - typeof(long), - typeof(byte), - typeof(ushort), - typeof(uint), - typeof(ulong) - }; - - private static readonly HashSet FloatingPointNumericTypes = new HashSet - { - typeof(decimal), - typeof(float), - typeof(double) - }; - /// /// Set query parameter types based on the given query model. /// @@ -150,6 +132,29 @@ private static IType GetCandidateType( return candidateType; } + private static IType GetCandidateType( + ISessionFactoryImplementor sessionFactory, + HashSet constantExpressions, + ConstantTypeLocatorVisitor visitor, + System.Type constantType) + { + var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor); + + if (candidateType == null) + { + return null; + } + + // When comparing an integral column with a real parameter, the parameter type must remain real type + // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). + if (constantType.IsRealNumberType() && candidateType.ReturnedClass.IsIntegralNumberType()) + { + return null; + } + + return candidateType; + } + private static IType GetParameterType( ISessionFactoryImplementor sessionFactory, HashSet constantExpressions, @@ -160,11 +165,8 @@ private static IType GetParameterType( // All constant expressions have the same type/value var constantExpression = constantExpressions.First(); var constantType = constantExpression.Type.UnwrapIfNullable(); - var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor); - if (candidateType != null && - // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type - // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). - !(FloatingPointNumericTypes.Contains(constantType) && IntegralNumericTypes.Contains(candidateType.ReturnedClass))) + var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor, constantType); + if (candidateType != null) { return candidateType; } diff --git a/src/NHibernate/Util/TypeExtensions.cs b/src/NHibernate/Util/TypeExtensions.cs index 71bf854957f..1637f30d420 100644 --- a/src/NHibernate/Util/TypeExtensions.cs +++ b/src/NHibernate/Util/TypeExtensions.cs @@ -48,5 +48,30 @@ internal static System.Type UnwrapIfNullable(this System.Type type) return type; } + + internal static bool IsIntegralNumberType(this System.Type type) + { + var code = System.Type.GetTypeCode(type); + if (code == TypeCode.SByte || code == TypeCode.Byte || + code == TypeCode.Int16 || code == TypeCode.UInt16 || + code == TypeCode.Int32 || code == TypeCode.UInt32 || + code == TypeCode.Int64 || code == TypeCode.UInt64) + { + return true; + } + + return false; + } + + internal static bool IsRealNumberType(this System.Type type) + { + var code = System.Type.GetTypeCode(type); + if (code == TypeCode.Decimal || code == TypeCode.Single || code == TypeCode.Double) + { + return true; + } + + return false; + } } } From 503aec21cdba03d0b04553a0bffc388ef52a739f Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Fri, 28 Aug 2020 09:36:16 +0300 Subject: [PATCH 21/21] Merge to single GetCandidateType --- .../Linq/Visitors/ParameterTypeLocator.cs | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index e3104d2e841..37e3da19852 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -100,7 +100,8 @@ internal static void SetParameterTypes( private static IType GetCandidateType( ISessionFactoryImplementor sessionFactory, IEnumerable constantExpressions, - ConstantTypeLocatorVisitor visitor) + ConstantTypeLocatorVisitor visitor, + System.Type constantType) { IType candidateType = null; foreach (var expression in constantExpressions) @@ -129,28 +130,13 @@ private static IType GetCandidateType( } } - return candidateType; - } - - private static IType GetCandidateType( - ISessionFactoryImplementor sessionFactory, - HashSet constantExpressions, - ConstantTypeLocatorVisitor visitor, - System.Type constantType) - { - var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor); - if (candidateType == null) - { return null; - } - + // When comparing an integral column with a real parameter, the parameter type must remain real type // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). if (constantType.IsRealNumberType() && candidateType.ReturnedClass.IsIntegralNumberType()) - { return null; - } return candidateType; } @@ -161,7 +147,6 @@ private static IType GetParameterType( ConstantTypeLocatorVisitor visitor, NamedParameter namedParameter) { - // All constant expressions have the same type/value var constantExpression = constantExpressions.First(); var constantType = constantExpression.Type.UnwrapIfNullable();