diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index 4551ce0e9d8..670b94b5bc5 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -69,6 +69,11 @@ public IQueryable Users get { return _session.Query(); } } + public IQueryable NumericEntities + { + get { return _session.Query(); } + } + public IQueryable DynamicUsers { get { return _session.Query(); } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs b/src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs new file mode 100644 index 00000000000..f38641303b1 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/NumericEntity.cs @@ -0,0 +1,29 @@ +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class NumericEntity + { + public virtual short Short { get; set; } + + public virtual short? NullableShort { get; set; } + + public virtual int Integer { get; set; } + + public virtual int? NullableInteger { get; set; } + + public virtual long Long { get; set; } + + public virtual long? NullableLong { get; set; } + + public virtual decimal Decimal { get; set; } + + public virtual decimal? NullableDecimal { get; set; } + + public virtual float Single { get; set; } + + public virtual float? NullableSingle { get; set; } + + public virtual double Double { get; set; } + + public virtual double? NullableDouble { get; set; } + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml new file mode 100644 index 00000000000..2eb696e1acd --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/NumericEntity.hbm.xml @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 6a5c75091c7..185dff330f2 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -12,6 +12,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using NHibernate.Dialect; using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; @@ -472,6 +473,33 @@ public async Task NullEqualityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.Short), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == o.Short), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == o.NullableShort), WithoutIsNullAndWithoutCast())); + + short value = 3; + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort), WithoutIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value == value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.NullableShort.Value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value == o.Short), WithoutIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort == 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.NullableShort), WithoutIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value == 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.NullableShort.Value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short == 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L == o.Short), WithoutIsNullAndWithoutCast())); } [Test] @@ -560,6 +588,43 @@ public async Task NullInequalityAsync() await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase)); + + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.Short), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != o.Short), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != o.NullableShort), WithIsNullAndWithoutCast())); + + short value = 3; + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort), WithIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value != value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.NullableShort.Value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != value), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => value != o.Short), WithoutIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort != 3L), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3 != o.NullableShort), WithIsNullAndWithoutCast())); + + await (ExpectAsync(db.NumericEntities.Where(o => o.NullableShort.Value != 3L), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L != o.NullableShort.Value), WithIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => o.Short != 3L), WithoutIsNullAndWithoutCast())); + await (ExpectAsync(db.NumericEntities.Where(o => 3L != o.Short), WithoutIsNullAndWithoutCast())); + } + + private IResolveConstraint WithIsNullAndWithoutCast() + { + return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithoutIsNullAndWithoutCast() + { + return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; } [Test] diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index ad1c885dc4f..4bb8ca0f7f5 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -14,11 +14,14 @@ using System.Linq.Expressions; using System.Reflection; using System.Text.RegularExpressions; +using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { @@ -125,6 +128,403 @@ public async Task UsingValueTypeParameterTwiceAsync() 1)); } + [Test] + public async Task CompareIntegralParametersAndColumnsAsync() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Short < intParam || o.Short > longParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableShort <= intParam || o.NullableShort != longParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == nullShortParam || o.Short < nullIntParam || o.Short > nullLongParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == nullShortParam || o.NullableShort <= nullIntParam || o.NullableShort != nullLongParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == shortParam || o.Integer < intParam || o.Integer > longParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == shortParam || o.NullableInteger <= intParam || o.NullableInteger != longParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == nullShortParam || o.Integer < nullIntParam || o.Integer > nullLongParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == nullShortParam || o.NullableInteger <= nullIntParam || o.NullableInteger != nullLongParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == shortParam || o.Long < intParam || o.Long > longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == shortParam || o.NullableLong <= intParam || o.NullableLong != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == nullShortParam || o.Long < nullIntParam || o.Long > nullLongParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == nullShortParam || o.NullableLong <= nullIntParam || o.NullableLong != nullLongParam), "Int64"} + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + await (AssertTotalParametersAsync( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + })); + } + } + + [Test] + public async Task CompareIntegralParametersWithFloatingPointColumnsAsync() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == shortParam || o.Decimal < intParam || o.Decimal > longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == shortParam || o.NullableDecimal <= intParam || o.NullableDecimal != longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullShortParam || o.Decimal < nullIntParam || o.Decimal > nullLongParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullShortParam || o.NullableDecimal <= nullIntParam || o.NullableDecimal != nullLongParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == shortParam || o.Single < intParam || o.Single > longParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == shortParam || o.NullableSingle <= intParam || o.NullableSingle != longParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == nullShortParam || o.Single < nullIntParam || o.Single > nullLongParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullShortParam || o.NullableSingle <= nullIntParam || o.NullableSingle != nullLongParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == shortParam || o.Double < intParam || o.Double > longParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == shortParam || o.NullableDouble <= intParam || o.NullableDouble != longParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == nullShortParam || o.Double < nullIntParam || o.Double > nullLongParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullShortParam || o.NullableDouble <= nullIntParam || o.NullableDouble != nullLongParam), "Double"}, + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + await (AssertTotalParametersAsync( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParametersAndColumnsAsync() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullDecimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullDecimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single <= singleParam || o.Single >= doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableSingle != doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single <= nullSingleParam || o.Single >= nullDoubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullSingleParam || o.NullableSingle != nullDoubleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double <= singleParam || o.Double >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == singleParam || o.NullableDouble != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double <= nullSingleParam || o.Double >= nullDoubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullSingleParam || o.NullableDouble != nullDoubleParam), "Double"}, + }; + + foreach (var pair in queriables) + { + var totalParameters = pair.Value == "Decimal" ? 1 : 2; + // Parameters should be pre-evaluated + await (AssertTotalParametersAsync( + pair.Key, + totalParameters, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParametersWithIntegralColumnsAsync() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == decimalParam || o.Short != singleParam || o.Short <= doubleParam), + db.NumericEntities.Where(o => o.NullableShort <= decimalParam || o.NullableShort == singleParam || o.NullableShort >= doubleParam), + db.NumericEntities.Where(o => o.Short == nullDecimalParam || o.Short != nullSingleParam || o.Short <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableShort <= nullDecimalParam || o.NullableShort == nullSingleParam || o.NullableShort >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Integer == decimalParam || o.Integer != singleParam || o.Integer <= doubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= decimalParam || o.NullableInteger == singleParam || o.NullableInteger >= doubleParam), + db.NumericEntities.Where(o => o.Integer == nullDecimalParam || o.Integer != nullSingleParam || o.Integer <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= nullDecimalParam || o.NullableInteger == nullSingleParam || o.NullableInteger >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Long == decimalParam || o.Long != singleParam || o.Long <= doubleParam), + db.NumericEntities.Where(o => o.NullableLong <= decimalParam || o.NullableLong == singleParam || o.NullableLong >= doubleParam), + db.NumericEntities.Where(o => o.Long == nullDecimalParam || o.Long != nullSingleParam || o.Long <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableLong <= nullDecimalParam || o.NullableLong == nullSingleParam || o.NullableLong >= nullDoubleParam), + }; + + foreach (var query in queriables) + { + // Columns should be casted + await (AssertTotalParametersAsync( + query, + 3, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(3)); + Assert.That(GetTotalOccurrences(sql, $"Type: Decimal"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Single"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Double"), Is.EqualTo(1)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParameterWithIntegralAndFloatingPointColumnsAsync() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.NullableShort >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.Long >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.Integer != decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.NullableInteger == decimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableShort >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.Long >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Integer != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableInteger == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableShort >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Long >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Integer != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableInteger == doubleParam), "Double"} + }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; + + foreach (var pair in queriables) + { + // Integral columns should be casted + await (AssertTotalParametersAsync( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); + })); + } + } + + [Test] + public async Task CompareFloatingPointParameterWithDifferentFloatingPointColumnsAsync() + { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Single == singleParam || o.Double >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Double >= singleParam || o.Single == singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableDouble <= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Double != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableDouble == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Single >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Single >= doubleParam || o.Double == doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableSingle >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Single != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableSingle == doubleParam), "Double"} + }; + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Single.SqlType, out var singleCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Double.SqlType, out var doubleCast) && + singleCast == doubleCast; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; + + foreach (var pair in queriables) + { + // Columns should be casted for Double parameter and parameters for Single parameter + await (AssertTotalParametersAsync( + pair.Key, + 1, + sql => + { + var matches = pair.Value == "Double" + ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") + : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); + })); + } + } + + [Test] + public async Task CompareIntegralParameterWithIntegralAndFloatingPointColumnsAsync() + { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Double >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == shortParam || o.NullableDouble >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.Decimal != shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableSingle > shortParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == intParam || o.Double >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == intParam || o.NullableDouble >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.Decimal != intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.NullableSingle > intParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == longParam || o.Double >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == longParam || o.NullableDouble >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.Decimal != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.NullableSingle > longParam), "Int64"} + }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; + + foreach (var pair in queriables) + { + // Parameters should be casted + await (AssertTotalParametersAsync( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); + })); + } + } + + [Test] + public async Task UsingValueTypeParameterOfDifferentTypeAsync() + { + var value = 1; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == value), + db.NumericEntities.Where(o => o.Short != value), + db.NumericEntities.Where(o => o.Short >= value), + db.NumericEntities.Where(o => o.Short <= value), + db.NumericEntities.Where(o => o.Short > value), + db.NumericEntities.Where(o => o.Short < value), + + db.NumericEntities.Where(o => o.NullableShort == value), + db.NumericEntities.Where(o => o.NullableShort != value), + db.NumericEntities.Where(o => o.NullableShort >= value), + db.NumericEntities.Where(o => o.NullableShort <= value), + db.NumericEntities.Where(o => o.NullableShort > value), + db.NumericEntities.Where(o => o.NullableShort < value), + + db.NumericEntities.Where(o => o.NullableShort.Value == value), + db.NumericEntities.Where(o => o.NullableShort.Value != value), + db.NumericEntities.Where(o => o.NullableShort.Value >= value), + db.NumericEntities.Where(o => o.NullableShort.Value <= value), + db.NumericEntities.Where(o => o.NullableShort.Value > value), + db.NumericEntities.Where(o => o.NullableShort.Value < value) + }; + + foreach (var query in queriables) + { + await (AssertTotalParametersAsync( + query, + 1, + sql => Assert.That(sql, Does.Not.Contain("cast")))); + } + + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex bug in FirebirdClientDriver, the parameters are not casted."); + } + + queriables = new List> + { + db.NumericEntities.Where(o => o.Short + value > value), + db.NumericEntities.Where(o => o.Short - value > value), + db.NumericEntities.Where(o => o.Short * value > value), + + db.NumericEntities.Where(o => o.NullableShort + value > value), + db.NumericEntities.Where(o => o.NullableShort - value > value), + db.NumericEntities.Where(o => o.NullableShort * value > value), + + db.NumericEntities.Where(o => o.NullableShort.Value + value > value), + db.NumericEntities.Where(o => o.NullableShort.Value - value > value), + db.NumericEntities.Where(o => o.NullableShort.Value * value > value), + }; + + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast == intCast; + foreach (var query in queriables) + { + await (AssertTotalParametersAsync( + query, + 1, + sql => { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + })); + } + } + + [Test] + public async Task UsingValueTypeParameterTwiceOnNullablePropertyAsync() + { + short value = 1; + await (AssertTotalParametersAsync( + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + 1, sql => { + + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); + })); + } + + [Test] + public async Task UsingValueTypeParameterOnDifferentPropertiesAsync() + { + int value = 1; + await (AssertTotalParametersAsync( + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Integer == value), + 1)); + + await (AssertTotalParametersAsync( + db.NumericEntities.Where(o => o.Integer == value && o.NullableShort == value && o.NullableShort != value), + 1)); + } + [Test] public async Task UsingParameterInEvaluatableExpressionAsync() { @@ -375,7 +775,12 @@ public async Task UsingTwoParametersInDMLDeleteAsync() 2)); } - private async Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, int? linqParameterNumber = null, CancellationToken cancellationToken = default(CancellationToken)) + private Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, Action sqlAction, CancellationToken cancellationToken = default(CancellationToken)) + { + return AssertTotalParametersAsync(query, parameterNumber, null, sqlAction, cancellationToken); + } + + private async Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, int? linqParameterNumber = null, Action sqlAction = null, CancellationToken cancellationToken = default(CancellationToken)) { using (var sqlSpy = new SqlLogSpy()) { @@ -394,6 +799,8 @@ public async Task UsingTwoParametersInDMLDeleteAsync() await (query.ToListAsync(cancellationToken)); + sqlAction?.Invoke(sqlSpy.GetWholeLog()); + // In case of arrays two query plans will be stored, one with an one without expended parameters Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable"); diff --git a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql index 12235095e0d..cb42443637d 100644 Binary files a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql and b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql differ diff --git a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyDropScript.sql b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyDropScript.sql index 54136ab8381..766d9f34375 100644 Binary files a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyDropScript.sql and b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyDropScript.sql differ diff --git a/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql index 4d15d18a6e1..ec1b580e231 100644 --- a/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql +++ b/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyCreateScript.sql @@ -54,6 +54,25 @@ PRIMARY KEY CLUSTERED )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] ) ON [PRIMARY] GO +CREATE TABLE [dbo].[NumericEntity]( + [Short] [smallint] IDENTITY(1,1) NOT NULL, + [NullableShort] [smallint] NULL, + [Integer] [int] NOT NULL, + [NullableInteger] [int] NULL, + [Long] [bigint] NOT NULL, + [NullableLong] [bigint] NULL, + [Decimal] [decimal](19, 5) NOT NULL, + [NullableDecimal] [decimal](19, 5) NULL, + [Single] [real] NOT NULL, + [NullableSingle] [real] NULL, + [Double] [float] NOT NULL, + [NullableDouble] [float] NULL, +PRIMARY KEY CLUSTERED +( + [Short] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +GO INSERT [dbo].[Suppliers] ([SupplierId], [CompanyName], [ContactName], [ContactTitle], [HomePage], [Address], [City], [Region], [PostalCode], [Country], [Phone], [Fax]) VALUES (1, N'Exotic Liquids', N'Charlotte Cooper', N'Purchasing Manager', N'', N'49 Gilbert St.', N'London', N'', N'EC1 4SD', N'UK', N'(171) 555-2222', N'') INSERT [dbo].[Suppliers] ([SupplierId], [CompanyName], [ContactName], [ContactTitle], [HomePage], [Address], [City], [Region], [PostalCode], [Country], [Phone], [Fax]) VALUES (2, N'New Orleans Cajun Delights', N'Shelley Burke', N'Order Administrator', N'#CAJUN.HTM#', N'P.O. Box 78934', N'New Orleans', N'LA', N'70117', N'USA', N'(100) 555-4822', N'') INSERT [dbo].[Suppliers] ([SupplierId], [CompanyName], [ContactName], [ContactTitle], [HomePage], [Address], [City], [Region], [PostalCode], [Country], [Phone], [Fax]) VALUES (3, N'Grandma Kelly''s Homestead', N'Regina Murphy', N'Sales Representative', N'', N'707 Oxford Rd.', N'Ann Arbor', N'MI', N'48104', N'USA', N'(313) 555-5735', N'(313) 555-3349') diff --git a/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyDropScript.sql b/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyDropScript.sql index facae91e28c..bab84a61127 100644 --- a/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyDropScript.sql +++ b/src/NHibernate.Test/DbScripts/MsSql2012DialectLinqReadonlyDropScript.sql @@ -53,3 +53,4 @@ DROP TABLE [dbo].[States] DROP TABLE [dbo].[Suppliers] DROP TABLE [dbo].[Region] DROP TABLE [dbo].[Physicians] +DROP TABLE [dbo].[NumericEntity] diff --git a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql index dd9d8716a71..0ac95140fc2 100644 Binary files a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql and b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyCreateScript.sql differ diff --git a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyDropScript.sql b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyDropScript.sql index 27ff0b0e394..a9b1f2642ee 100644 Binary files a/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyDropScript.sql and b/src/NHibernate.Test/DbScripts/PostgreSQL83DialectLinqReadonlyDropScript.sql differ diff --git a/src/NHibernate.Test/Linq/LinqReadonlyTestsContext.cs b/src/NHibernate.Test/Linq/LinqReadonlyTestsContext.cs index 399eec6df39..d5d0a6bd71f 100644 --- a/src/NHibernate.Test/Linq/LinqReadonlyTestsContext.cs +++ b/src/NHibernate.Test/Linq/LinqReadonlyTestsContext.cs @@ -43,7 +43,8 @@ private IEnumerable Mappings "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", - "Northwind.Mappings.Patient.hbm.xml" + "Northwind.Mappings.Patient.hbm.xml", + "Northwind.Mappings.NumericEntity.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index daf14b9cd18..ecd9c7690e3 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -35,7 +35,8 @@ protected override string[] Mappings "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", "Northwind.Mappings.Patient.hbm.xml", - "Northwind.Mappings.DynamicUser.hbm.xml" + "Northwind.Mappings.DynamicUser.hbm.xml", + "Northwind.Mappings.NumericEntity.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index 0ed569813f3..7eec9f09e22 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using NHibernate.Dialect; using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; @@ -460,6 +461,33 @@ public void NullEquality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase); + + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + + Expect(db.NumericEntities.Where(o => o.NullableShort == o.NullableShort), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == o.Short), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.NullableShort == o.Short), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == o.NullableShort), WithoutIsNullAndWithoutCast()); + + short value = 3; + Expect(db.NumericEntities.Where(o => o.NullableShort == value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value == o.NullableShort), WithoutIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value == value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value == o.NullableShort.Value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value == o.Short), WithoutIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort == 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L == o.NullableShort), WithoutIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value == 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L == o.NullableShort.Value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short == 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L == o.Short), WithoutIsNullAndWithoutCast()); } [Test] @@ -548,6 +576,43 @@ public void NullInequality() Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase); + + if (Sfi.Dialect is FirebirdDialect) + { + return; + } + + Expect(db.NumericEntities.Where(o => o.NullableShort != o.NullableShort), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != o.Short), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.NullableShort != o.Short), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != o.NullableShort), WithIsNullAndWithoutCast()); + + short value = 3; + Expect(db.NumericEntities.Where(o => o.NullableShort != value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value != o.NullableShort), WithIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value != value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value != o.NullableShort.Value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != value), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => value != o.Short), WithoutIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort != 3L), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3 != o.NullableShort), WithIsNullAndWithoutCast()); + + Expect(db.NumericEntities.Where(o => o.NullableShort.Value != 3L), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L != o.NullableShort.Value), WithIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => o.Short != 3L), WithoutIsNullAndWithoutCast()); + Expect(db.NumericEntities.Where(o => 3L != o.Short), WithoutIsNullAndWithoutCast()); + } + + private IResolveConstraint WithIsNullAndWithoutCast() + { + return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; + } + + private IResolveConstraint WithoutIsNullAndWithoutCast() + { + return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase; } [Test] diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 97da1e0a079..73028ba8598 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -4,11 +4,14 @@ using System.Linq.Expressions; using System.Reflection; using System.Text.RegularExpressions; +using NHibernate.Dialect; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.Util; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { @@ -113,6 +116,403 @@ public void UsingValueTypeParameterTwice() 1); } + [Test] + public void CompareIntegralParametersAndColumns() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Short < intParam || o.Short > longParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableShort <= intParam || o.NullableShort != longParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == nullShortParam || o.Short < nullIntParam || o.Short > nullLongParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == nullShortParam || o.NullableShort <= nullIntParam || o.NullableShort != nullLongParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == shortParam || o.Integer < intParam || o.Integer > longParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == shortParam || o.NullableInteger <= intParam || o.NullableInteger != longParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == nullShortParam || o.Integer < nullIntParam || o.Integer > nullLongParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == nullShortParam || o.NullableInteger <= nullIntParam || o.NullableInteger != nullLongParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == shortParam || o.Long < intParam || o.Long > longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == shortParam || o.NullableLong <= intParam || o.NullableLong != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == nullShortParam || o.Long < nullIntParam || o.Long > nullLongParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == nullShortParam || o.NullableLong <= nullIntParam || o.NullableLong != nullLongParam), "Int64"} + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + AssertTotalParameters( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + }); + } + } + + [Test] + public void CompareIntegralParametersWithFloatingPointColumns() + { + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + short? nullShortParam = 1; + int? nullIntParam = 2; + long? nullLongParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == shortParam || o.Decimal < intParam || o.Decimal > longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == shortParam || o.NullableDecimal <= intParam || o.NullableDecimal != longParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullShortParam || o.Decimal < nullIntParam || o.Decimal > nullLongParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullShortParam || o.NullableDecimal <= nullIntParam || o.NullableDecimal != nullLongParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == shortParam || o.Single < intParam || o.Single > longParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == shortParam || o.NullableSingle <= intParam || o.NullableSingle != longParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == nullShortParam || o.Single < nullIntParam || o.Single > nullLongParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullShortParam || o.NullableSingle <= nullIntParam || o.NullableSingle != nullLongParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == shortParam || o.Double < intParam || o.Double > longParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == shortParam || o.NullableDouble <= intParam || o.NullableDouble != longParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == nullShortParam || o.Double < nullIntParam || o.Double > nullLongParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullShortParam || o.NullableDouble <= nullIntParam || o.NullableDouble != nullLongParam), "Double"}, + }; + + foreach (var pair in queriables) + { + // Parameters should be pre-evaluated + AssertTotalParameters( + pair.Key, + 3, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(3)); + }); + } + } + + [Test] + public void CompareFloatingPointParametersAndColumns() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == nullDecimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == nullDecimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single <= singleParam || o.Single >= doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableSingle != doubleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single <= nullSingleParam || o.Single >= nullDoubleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == nullSingleParam || o.NullableSingle != nullDoubleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double <= singleParam || o.Double >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == singleParam || o.NullableDouble != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double <= nullSingleParam || o.Double >= nullDoubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == nullSingleParam || o.NullableDouble != nullDoubleParam), "Double"}, + }; + + foreach (var pair in queriables) + { + var totalParameters = pair.Value == "Decimal" ? 1 : 2; + // Parameters should be pre-evaluated + AssertTotalParameters( + pair.Key, + totalParameters, + sql => + { + Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); + }); + } + } + + [Test] + public void CompareFloatingPointParametersWithIntegralColumns() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + decimal? nullDecimalParam = 1.1m; + float? nullSingleParam = 2.2f; + double? nullDoubleParam = 3.3d; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == decimalParam || o.Short != singleParam || o.Short <= doubleParam), + db.NumericEntities.Where(o => o.NullableShort <= decimalParam || o.NullableShort == singleParam || o.NullableShort >= doubleParam), + db.NumericEntities.Where(o => o.Short == nullDecimalParam || o.Short != nullSingleParam || o.Short <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableShort <= nullDecimalParam || o.NullableShort == nullSingleParam || o.NullableShort >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Integer == decimalParam || o.Integer != singleParam || o.Integer <= doubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= decimalParam || o.NullableInteger == singleParam || o.NullableInteger >= doubleParam), + db.NumericEntities.Where(o => o.Integer == nullDecimalParam || o.Integer != nullSingleParam || o.Integer <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableInteger <= nullDecimalParam || o.NullableInteger == nullSingleParam || o.NullableInteger >= nullDoubleParam), + + db.NumericEntities.Where(o => o.Long == decimalParam || o.Long != singleParam || o.Long <= doubleParam), + db.NumericEntities.Where(o => o.NullableLong <= decimalParam || o.NullableLong == singleParam || o.NullableLong >= doubleParam), + db.NumericEntities.Where(o => o.Long == nullDecimalParam || o.Long != nullSingleParam || o.Long <= nullDoubleParam), + db.NumericEntities.Where(o => o.NullableLong <= nullDecimalParam || o.NullableLong == nullSingleParam || o.NullableLong >= nullDoubleParam), + }; + + foreach (var query in queriables) + { + // Columns should be casted + AssertTotalParameters( + query, + 3, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(3)); + Assert.That(GetTotalOccurrences(sql, $"Type: Decimal"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Single"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: Double"), Is.EqualTo(1)); + }); + } + } + + [Test] + public void CompareFloatingPointParameterWithIntegralAndFloatingPointColumns() + { + var decimalParam = 1.1m; + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.NullableShort >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.Decimal == decimalParam || o.Long >= decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.Integer != decimalParam), "Decimal"}, + {db.NumericEntities.Where(o => o.NullableDecimal == decimalParam || o.NullableInteger == decimalParam), "Decimal"}, + + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableShort >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.Long >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Integer != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableInteger == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableShort >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Long >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Integer != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableInteger == doubleParam), "Double"} + }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; + + foreach (var pair in queriables) + { + // Integral columns should be casted + AssertTotalParameters( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\([\w\d]+\..+\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); + }); + } + } + + [Test] + public void CompareFloatingPointParameterWithDifferentFloatingPointColumns() + { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + + var singleParam = 2.2f; + var doubleParam = 3.3d; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Single == singleParam || o.Double >= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Double >= singleParam || o.Single == singleParam), "Single"}, + {db.NumericEntities.Where(o => o.Single == singleParam || o.NullableDouble <= singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.Double != singleParam), "Single"}, + {db.NumericEntities.Where(o => o.NullableSingle == singleParam || o.NullableDouble == singleParam), "Single"}, + + {db.NumericEntities.Where(o => o.Double == doubleParam || o.Single >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Single >= doubleParam || o.Double == doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.Double == doubleParam || o.NullableSingle >= doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.Single != doubleParam), "Double"}, + {db.NumericEntities.Where(o => o.NullableDouble == doubleParam || o.NullableSingle == doubleParam), "Double"} + }; + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Single.SqlType, out var singleCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Double.SqlType, out var doubleCast) && + singleCast == doubleCast; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; + + foreach (var pair in queriables) + { + // Columns should be casted for Double parameter and parameters for Single parameter + AssertTotalParameters( + pair.Key, + 1, + sql => + { + var matches = pair.Value == "Double" + ? Regex.Matches(sql, @"cast\([\w\d]+\..+\)") + : Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(matches.Count, Is.EqualTo(sameType && !(Sfi.Dialect is SQLiteDialect) ? 0 : 1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); + }); + } + } + + [Test] + public void CompareIntegralParameterWithIntegralAndFloatingPointColumns() + { + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex hack in FirebirdClientDriver, the parameters can be casted twice."); + } + + short shortParam = 1; + var intParam = 2; + var longParam = 3L; + var queriables = new Dictionary, string> + { + {db.NumericEntities.Where(o => o.Short == shortParam || o.Double >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.Short == shortParam || o.NullableDouble >= shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.Decimal != shortParam), "Int16"}, + {db.NumericEntities.Where(o => o.NullableShort == shortParam || o.NullableSingle > shortParam), "Int16"}, + + {db.NumericEntities.Where(o => o.Integer == intParam || o.Double >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.Integer == intParam || o.NullableDouble >= intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.Decimal != intParam), "Int32"}, + {db.NumericEntities.Where(o => o.NullableInteger == intParam || o.NullableSingle > intParam), "Int32"}, + + {db.NumericEntities.Where(o => o.Long == longParam || o.Double >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.Long == longParam || o.NullableDouble >= longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.Decimal != longParam), "Int64"}, + {db.NumericEntities.Where(o => o.NullableLong == longParam || o.NullableSingle > longParam), "Int64"} + }; + var odbcDriver = Sfi.ConnectionProvider.Driver is OdbcDriver; + + foreach (var pair in queriables) + { + // Parameters should be casted + AssertTotalParameters( + pair.Key, + 1, + sql => + { + var matches = Regex.Matches(sql, @"cast\(((@|\?|:)p\d+|\?)\s+as.*\)"); + Assert.That(matches.Count, Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(odbcDriver ? 2 : 1)); + }); + } + } + + [Test] + public void UsingValueTypeParameterOfDifferentType() + { + var value = 1; + var queriables = new List> + { + db.NumericEntities.Where(o => o.Short == value), + db.NumericEntities.Where(o => o.Short != value), + db.NumericEntities.Where(o => o.Short >= value), + db.NumericEntities.Where(o => o.Short <= value), + db.NumericEntities.Where(o => o.Short > value), + db.NumericEntities.Where(o => o.Short < value), + + db.NumericEntities.Where(o => o.NullableShort == value), + db.NumericEntities.Where(o => o.NullableShort != value), + db.NumericEntities.Where(o => o.NullableShort >= value), + db.NumericEntities.Where(o => o.NullableShort <= value), + db.NumericEntities.Where(o => o.NullableShort > value), + db.NumericEntities.Where(o => o.NullableShort < value), + + db.NumericEntities.Where(o => o.NullableShort.Value == value), + db.NumericEntities.Where(o => o.NullableShort.Value != value), + db.NumericEntities.Where(o => o.NullableShort.Value >= value), + db.NumericEntities.Where(o => o.NullableShort.Value <= value), + db.NumericEntities.Where(o => o.NullableShort.Value > value), + db.NumericEntities.Where(o => o.NullableShort.Value < value) + }; + + foreach (var query in queriables) + { + AssertTotalParameters( + query, + 1, + sql => Assert.That(sql, Does.Not.Contain("cast"))); + } + + if (Sfi.Dialect is FirebirdDialect) + { + Assert.Ignore("Due to the regex bug in FirebirdClientDriver, the parameters are not casted."); + } + + queriables = new List> + { + db.NumericEntities.Where(o => o.Short + value > value), + db.NumericEntities.Where(o => o.Short - value > value), + db.NumericEntities.Where(o => o.Short * value > value), + + db.NumericEntities.Where(o => o.NullableShort + value > value), + db.NumericEntities.Where(o => o.NullableShort - value > value), + db.NumericEntities.Where(o => o.NullableShort * value > value), + + db.NumericEntities.Where(o => o.NullableShort.Value + value > value), + db.NumericEntities.Where(o => o.NullableShort.Value - value > value), + db.NumericEntities.Where(o => o.NullableShort.Value * value > value), + }; + + var sameType = Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) && + Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) && + shortCast == intCast; + foreach (var query in queriables) + { + AssertTotalParameters( + query, + 1, + sql => { + // SQLiteDialect uses sql cast for transparentcast method + Assert.That(sql, !sameType || Sfi.Dialect is SQLiteDialect ? Does.Match("where\\s+cast") : (IResolveConstraint)Does.Not.Contain("cast")); + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(!sameType || Sfi.Dialect is SQLiteDialect ? 1 : 0)); + }); + } + } + + [Test] + public void UsingValueTypeParameterTwiceOnNullableProperty() + { + short value = 1; + AssertTotalParameters( + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value), + 1, sql => { + + Assert.That(GetTotalOccurrences(sql, "cast"), Is.EqualTo(0)); + }); + } + + [Test] + public void UsingValueTypeParameterOnDifferentProperties() + { + int value = 1; + AssertTotalParameters( + db.NumericEntities.Where(o => o.NullableShort == value && o.NullableShort != value && o.Integer == value), + 1); + + AssertTotalParameters( + db.NumericEntities.Where(o => o.Integer == value && o.NullableShort == value && o.NullableShort != value), + 1); + } + [Test] public void UsingParameterInEvaluatableExpression() { @@ -436,7 +836,12 @@ public void DMLDeleteShouldHaveSameCacheKeys() Assert.That(expression1.Key, Is.EqualTo(expression2.Key)); } - private void AssertTotalParameters(IQueryable query, int parameterNumber, int? linqParameterNumber = null) + private void AssertTotalParameters(IQueryable query, int parameterNumber, Action sqlAction) + { + AssertTotalParameters(query, parameterNumber, null, sqlAction); + } + + private void AssertTotalParameters(IQueryable query, int parameterNumber, int? linqParameterNumber = null, Action sqlAction = null) { using (var sqlSpy = new SqlLogSpy()) { @@ -455,6 +860,8 @@ private void AssertTotalParameters(IQueryable query, int parameterNumber, query.ToList(); + sqlAction?.Invoke(sqlSpy.GetWholeLog()); + // In case of arrays two query plans will be stored, one with an one without expended parameters Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable"); diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index ba4505dda8c..48cc2e14048 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -64,7 +64,7 @@ public void GreaterThanTest() AssertResults( new Dictionary> { - {"2.1", o => o is Int32Type} + {"2.1", o => o is DoubleType} }, db.Users.Where(o => o.Id > 2.1), db.Users.Where(o => 2.1 > o.Id) @@ -99,6 +99,82 @@ public void EqualsMethodStringTest() ); } + [Test] + public void BinaryIntShortTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is Int16Type} + }, + db.NumericEntities.Where(o => o.Short == 3), + db.NumericEntities.Where(o => 3 == o.Short), + db.NumericEntities.Where(o => o.Short < 3), + db.NumericEntities.Where(o => 3 < o.Short), + db.NumericEntities.Where(o => o.Short > 3), + db.NumericEntities.Where(o => 3 > o.Short), + db.NumericEntities.Where(o => o.Short >= 3), + db.NumericEntities.Where(o => 3 >= o.Short), + db.NumericEntities.Where(o => o.Short <= 3), + db.NumericEntities.Where(o => 3 <= o.Short), + db.NumericEntities.Where(o => o.Short != 3), + db.NumericEntities.Where(o => 3 != o.Short), + + db.NumericEntities.Where(o => o.NullableShort == 3), + db.NumericEntities.Where(o => 3 == o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort < 3), + db.NumericEntities.Where(o => 3 < o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort > 3), + db.NumericEntities.Where(o => 3 > o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort >= 3), + db.NumericEntities.Where(o => 3 >= o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort <= 3), + db.NumericEntities.Where(o => 3 <= o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort != 3), + db.NumericEntities.Where(o => 3 != o.NullableShort), + + db.NumericEntities.Where(o => o.NullableShort.Value == 3), + db.NumericEntities.Where(o => 3 == o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value < 3), + db.NumericEntities.Where(o => 3 < o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value > 3), + db.NumericEntities.Where(o => 3 > o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value >= 3), + db.NumericEntities.Where(o => 3 >= o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value <= 3), + db.NumericEntities.Where(o => 3 <= o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value != 3), + db.NumericEntities.Where(o => 3 != o.NullableShort.Value) + ); + } + + [Test] + public void BinaryNullableIntShortTest() + { + int? value = 3; + AssertResults( + new Dictionary> + { + {"3", o => o is Int16Type} + }, + db.NumericEntities.Where(o => o.Short == value), + db.NumericEntities.Where(o => value == o.Short), + db.NumericEntities.Where(o => o.Short < value), + db.NumericEntities.Where(o => value < o.Short), + + db.NumericEntities.Where(o => o.NullableShort == value), + db.NumericEntities.Where(o => value == o.NullableShort), + db.NumericEntities.Where(o => o.NullableShort < value), + db.NumericEntities.Where(o => value < o.NullableShort), + + db.NumericEntities.Where(o => o.NullableShort.Value == value), + db.NumericEntities.Where(o => value == o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value < value), + db.NumericEntities.Where(o => value < o.NullableShort.Value), + db.NumericEntities.Where(o => o.NullableShort.Value > value) + ); + } + [Test] public void ContainsStringEnumTest() { diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs index 538e46cb828..8feac321a1c 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs @@ -1,4 +1,5 @@ using System.Linq.Expressions; +using NHibernate.Util; using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers @@ -26,6 +27,13 @@ public Expression Transform(UnaryExpression expression) return expression.Operand; } + // Reduce double casting (e.g. (long?)(long)3 => (long?)3) + if (expression.Operand.NodeType == ExpressionType.Convert && + expression.Type.UnwrapIfNullable() == expression.Operand.Type) + { + return Expression.Convert(((UnaryExpression) expression.Operand).Operand, expression.Type); + } + return expression; } @@ -34,4 +42,4 @@ public ExpressionType[] SupportedExpressionTypes get { return _supportedExpressionTypes; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index f6a9e5de43f..a70109bd59c 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -141,7 +141,7 @@ protected override Expression VisitConstant(ConstantExpression expression) // We have a bit more information about the null parameter value. // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) - // In v5.3 types are calculated by ConstantTypeLocator, this logic is only for back compatibility. + // In v5.3 types are calculated by ParameterTypeLocator, this logic is only for back compatibility. // TODO 6.0: Remove if (expression.Value == null) type = NHibernateUtil.GuessType(expression.Type); diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 34315c240d2..6dc938bdae2 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Data; using System.Dynamic; using System.Linq; @@ -23,6 +24,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; private readonly NullableExpressionDetector _nullableExpressionDetector; + private readonly Dictionary _notCastableExpressions = new Dictionary(); public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { @@ -308,6 +310,17 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { + // There are some cases where we do not want to add a sql cast: + // - When comparing numeric types that do not have their own operator (e.g. short == short) + // - When comparing a member expression with a parameter of similar type (e.g. o.Short == intParameter) + var leftType = GetExpressionType(expression.Left); + var rightType = GetExpressionType(expression.Right); + if (leftType != null && leftType == rightType) + { + _notCastableExpressions.Add(expression.Left, leftType); + _notCastableExpressions.Add(expression.Right, rightType); + } + if (expression.NodeType == ExpressionType.Equal) { return TranslateEqualityComparison(expression); @@ -374,6 +387,23 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) throw new InvalidOperationException(); } + private System.Type GetExpressionType(Expression expression) + { + switch (expression.NodeType) + { + case ExpressionType.Constant: + return _parameters.ConstantToParameterMap.TryGetValue((ConstantExpression) expression, out var param) + ? param.Type?.ReturnedClass + : null; + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + var operand = ((UnaryExpression) expression).Operand; + return GetExpressionType(operand) ?? operand.Type.UnwrapIfNullable(); + } + + return null; + } + private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) { var lhs = VisitExpression(expression.Left).ToArithmeticExpression(); @@ -496,11 +526,17 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - return IsCastRequired(expression.Operand, expression.Type, out var existType) - ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + var castable = !_notCastableExpressions.TryGetValue(expression, out var castType); + if (castable) + { + castType = expression.Type; + } + + return IsCastRequired(expression.Operand, castType, out var existType) && castable + ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), castType) // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader - : existType && HqlIdent.SupportsType(expression.Type) - ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + : existType && HqlIdent.SupportsType(castType) + ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), castType) : VisitExpression(expression.Operand); } diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index 9ee40092e6b..475dbcdad31 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -80,8 +80,7 @@ public override Expression Visit(Expression expression) if (expression.NodeType == ExpressionType.Lambda || !_partialEvaluationInfo.IsEvaluatableExpression(expression) || #region NH additions // Variables should be evaluated only when they are part of an evaluatable expression (e.g. o => string.Format("...", variable)) - expression is UnaryExpression unaryExpression && - ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _)) + ContainsVariable(expression)) #endregion return base.Visit(expression); @@ -154,8 +153,27 @@ private Expression EvaluateSubtree(Expression subtree) } } + #region NH additions + + private bool ContainsVariable(Expression expression) + { + if (!(expression is UnaryExpression unaryExpression)) + { + return false; + } + + return ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _) || + // Check whether the variable is casted due to comparison with a nullable expression + // (e.g. o.NullableShort == shortVariable) + unaryExpression.Operand is UnaryExpression subUnaryExpression && + unaryExpression.Type.UnwrapIfNullable() == subUnaryExpression.Type && + ExpressionsHelper.IsVariable(subUnaryExpression.Operand, out _, out _); + } + #endregion - + + #endregion + protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value is Expression value) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 457e04fcbdb..37e3da19852 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -13,6 +13,7 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; +using TypeExtensions = NHibernate.Util.TypeExtensions; namespace NHibernate.Linq.Visitors { @@ -49,6 +50,7 @@ public static class ParameterTypeLocator ExpressionType.Conditional }; + /// /// Set query parameter types based on the given query model. /// @@ -80,59 +82,86 @@ internal static void SetParameterTypes( var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); queryModel.TransformExpressions(visitor.Visit); - foreach (var pair in visitor.ConstantExpressions) + foreach (var pair in visitor.ParameterConstants) { - var type = pair.Value; - var constantExpression = pair.Key; - if (!parameters.TryGetValue(constantExpression, out var namedParameter)) + var namedParameter = pair.Key; + var constantExpressions = pair.Value; + // In case any of the constants has the type set, use it (e.g. MappedAs) + namedParameter.Type = constantExpressions.Select(o => visitor.ConstantExpressions[o]).FirstOrDefault(o => o != null); + if (namedParameter.Type != null) { continue; } - if (type != null) - { - // MappedAs was used - namedParameter.Type = type; - continue; - } + namedParameter.Type = GetParameterType(sessionFactory, constantExpressions, visitor, namedParameter); + } + } + private static IType GetCandidateType( + ISessionFactoryImplementor sessionFactory, + IEnumerable constantExpressions, + ConstantTypeLocatorVisitor visitor, + System.Type constantType) + { + IType candidateType = null; + foreach (var expression in constantExpressions) + { // In order to get the actual type we have to check first the related member expressions, as // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. // By getting the type from a related member expression we also get the correct length in case of StringType // or precision when having a DecimalType. - if (visitor.RelatedExpressions.TryGetValue(constantExpression, out var memberExpressions)) + if (!visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions)) + continue; + foreach (var relatedExpression in relatedExpressions) { - foreach (var memberExpression in memberExpressions) - { - if (ExpressionsHelper.TryGetMappedType( - sessionFactory, - memberExpression, - out type, - out _, - out _, - out _)) - { - if (type.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(memberExpression)) - { - var collection = (IQueryableCollection) ((IAssociationType) type).GetAssociatedJoinable(sessionFactory); - type = collection.ElementType; - } + if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _)) + continue; - break; - } + if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression)) + { + var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory); + mappedType = collection.ElementType; } - } - // No related MemberExpressions was found, guess the type by value or its type when null. - if (type == null) - { - type = constantExpression.Value != null - ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) - : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection); + if (candidateType == null) + candidateType = mappedType; + else if (!candidateType.Equals(mappedType)) + return null; } + } + + if (candidateType == null) + return null; + + // When comparing an integral column with a real parameter, the parameter type must remain real type + // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)). + if (constantType.IsRealNumberType() && candidateType.ReturnedClass.IsIntegralNumberType()) + return null; - namedParameter.Type = type; + return candidateType; + } + + private static IType GetParameterType( + ISessionFactoryImplementor sessionFactory, + HashSet constantExpressions, + ConstantTypeLocatorVisitor visitor, + NamedParameter namedParameter) + { + // All constant expressions have the same type/value + var constantExpression = constantExpressions.First(); + var constantType = constantExpression.Type.UnwrapIfNullable(); + var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor, constantType); + if (candidateType != null) + { + return candidateType; } + + // No related MemberExpressions was found, guess the type by value or its type when null. + // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) + // do not change the parameter type, but instead cast the parameter when comparing with different column types. + return constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection); } private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor @@ -143,6 +172,8 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor private readonly ISessionFactoryImplementor _sessionFactory; public readonly Dictionary ConstantExpressions = new Dictionary(); + public readonly Dictionary> ParameterConstants = + new Dictionary>(); public readonly Dictionary> RelatedExpressions = new Dictionary>(); public readonly HashSet SequenceSelectorExpressions = new HashSet(); @@ -167,8 +198,8 @@ protected override Expression VisitBinary(BinaryExpression node) return node; } - var left = Unwrap(node.Left); - var right = Unwrap(node.Right); + var left = UnwrapUnary(node.Left); + var right = UnwrapUnary(node.Right); if (node.NodeType == ExpressionType.Assign) { VisitAssign(left, right); @@ -185,8 +216,8 @@ protected override Expression VisitBinary(BinaryExpression node) protected override Expression VisitConditional(ConditionalExpression node) { node = (ConditionalExpression) base.VisitConditional(node); - var ifTrue = Unwrap(node.IfTrue); - var ifFalse = Unwrap(node.IfFalse); + var ifTrue = UnwrapUnary(node.IfTrue); + var ifFalse = UnwrapUnary(node.IfFalse); AddRelatedExpression(node, ifTrue, ifFalse); AddRelatedExpression(node, ifFalse, ifTrue); @@ -219,8 +250,8 @@ protected override Expression VisitMethodCall(MethodCallExpression node) if (EqualsGenerator.Methods.Contains(node.Method) || CompareGenerator.IsCompareMethod(node.Method)) { node = (MethodCallExpression) base.VisitMethodCall(node); - var left = Unwrap(node.Method.IsStatic ? node.Arguments[0] : node.Object); - var right = Unwrap(node.Method.IsStatic ? node.Arguments[1] : node.Arguments[0]); + var left = UnwrapUnary(node.Method.IsStatic ? node.Arguments[0] : node.Object); + var right = UnwrapUnary(node.Method.IsStatic ? node.Arguments[1] : node.Arguments[0]); AddRelatedExpression(node, left, right); AddRelatedExpression(node, right, left); @@ -232,13 +263,21 @@ protected override Expression VisitMethodCall(MethodCallExpression node) protected override Expression VisitConstant(ConstantExpression node) { - if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.ContainsKey(node)) + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.TryGetValue(node, out var param)) { return node; } RelatedExpressions.Add(node, new HashSet()); ConstantExpressions.Add(node, null); + if (!ParameterConstants.TryGetValue(param, out var set)) + { + set = new HashSet(); + ParameterConstants.Add(param, set); + } + + set.Add(node); + return node; } @@ -254,7 +293,7 @@ querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && mainFromClause.FromExpression is ConstantExpression constantExpression) { VisitConstant(constantExpression); - AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item))); + AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item))); // Copy all found MemberExpressions to the constant expression // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set)) @@ -357,16 +396,21 @@ private bool IsDynamicMember(Expression expression) return false; } } + } - private static Expression Unwrap(Expression expression) + /// + /// Unwraps . + /// + /// The expression to unwrap. + /// The unwrapped expression. + private static Expression UnwrapUnary(Expression expression) + { + while (expression is UnaryExpression unaryExpression) { - if (expression is UnaryExpression unaryExpression) - { - return unaryExpression.Operand; - } - - return expression; + expression = unaryExpression.Operand; } + + return expression; } } } diff --git a/src/NHibernate/Type/SingleType.cs b/src/NHibernate/Type/SingleType.cs index bf719e3832b..70ca434e04d 100644 --- a/src/NHibernate/Type/SingleType.cs +++ b/src/NHibernate/Type/SingleType.cs @@ -62,7 +62,7 @@ public override System.Type ReturnedClass public override void Set(DbCommand rs, object value, int index, ISessionImplementor session) { - rs.Parameters[index].Value = value; + rs.Parameters[index].Value = Convert.ToSingle(value); } // Since 5.2 diff --git a/src/NHibernate/Util/TypeExtensions.cs b/src/NHibernate/Util/TypeExtensions.cs index 71bf854957f..1637f30d420 100644 --- a/src/NHibernate/Util/TypeExtensions.cs +++ b/src/NHibernate/Util/TypeExtensions.cs @@ -48,5 +48,30 @@ internal static System.Type UnwrapIfNullable(this System.Type type) return type; } + + internal static bool IsIntegralNumberType(this System.Type type) + { + var code = System.Type.GetTypeCode(type); + if (code == TypeCode.SByte || code == TypeCode.Byte || + code == TypeCode.Int16 || code == TypeCode.UInt16 || + code == TypeCode.Int32 || code == TypeCode.UInt32 || + code == TypeCode.Int64 || code == TypeCode.UInt64) + { + return true; + } + + return false; + } + + internal static bool IsRealNumberType(this System.Type type) + { + var code = System.Type.GetTypeCode(type); + if (code == TypeCode.Decimal || code == TypeCode.Single || code == TypeCode.Double) + { + return true; + } + + return false; + } } }