diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2437/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2437/Fixture.cs index 2dea6eb573e..3aefd7ae543 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/GH2437/Fixture.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2437/Fixture.cs @@ -85,6 +85,18 @@ public async Task Get_DateCustomType_NullableDateValueEqualsAsync() } } + [Test] + public async Task Get_DateCustomType_NullableDateValueEqualsMethodAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + { + var sessions = await (session.Query().Where(x => x.OpenDate.Value.Equals(DateTime.Now)).ToListAsync()); + + Assert.That(sessions, Has.Count.EqualTo(10)); + } + } + [Test] public async Task Get_DateTimeCustomType_NullableDateValueEqualsAsync() { diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index 39cb2d22d74..ba4505dda8c 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -84,6 +84,21 @@ public void EqualStringEnumTest() ); } + [Test] + public void EqualsMethodStringTest() + { + AssertResults( + new Dictionary> + { + {"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15} + }, + db.Orders.Where(o => o.ShippingAddress.City.Equals("London")), + db.Orders.Where(o => "London".Equals(o.ShippingAddress.City)), + db.Orders.Where(o => string.Equals("London", o.ShippingAddress.City)), + db.Orders.Where(o => string.Equals(o.ShippingAddress.City, "London")) + ); + } + [Test] public void ContainsStringEnumTest() { @@ -158,6 +173,22 @@ public void EqualStringTest() ); } + [Test] + public void CompareToStringTest() + { + AssertResults( + new Dictionary> + { + {"1", o => o is Int32Type}, + {"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15} + }, + db.Orders.Where(o => o.ShippingAddress.City.CompareTo("London") > 1), + db.Orders.Where(o => "London".CompareTo(o.ShippingAddress.City) > 1), + db.Orders.Where(o => string.Compare("London", o.ShippingAddress.City) > 1), + db.Orders.Where(o => string.Compare(o.ShippingAddress.City, "London") > 1) + ); + } + [Test] public void EqualEntityTest() { diff --git a/src/NHibernate.Test/NHSpecificTest/GH2437/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2437/Fixture.cs index 05cc50f83f1..bb8354e2b44 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2437/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2437/Fixture.cs @@ -73,6 +73,18 @@ public void Get_DateCustomType_NullableDateValueEquals() } } + [Test] + public void Get_DateCustomType_NullableDateValueEqualsMethod() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + { + var sessions = session.Query().Where(x => x.OpenDate.Value.Equals(DateTime.Now)).ToList(); + + Assert.That(sessions, Has.Count.EqualTo(10)); + } + } + [Test] public void Get_DateTimeCustomType_NullableDateValueEquals() { diff --git a/src/NHibernate/Linq/Functions/EqualsGenerator.cs b/src/NHibernate/Linq/Functions/EqualsGenerator.cs index 1728f0b706f..82c15c87190 100644 --- a/src/NHibernate/Linq/Functions/EqualsGenerator.cs +++ b/src/NHibernate/Linq/Functions/EqualsGenerator.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq.Expressions; using System.Reflection; @@ -10,57 +11,59 @@ namespace NHibernate.Linq.Functions { public class EqualsGenerator : BaseHqlGeneratorForMethod { - public EqualsGenerator() + internal static HashSet Methods = new HashSet { - SupportedMethods = new[] - { - ReflectHelper.FastGetMethod(string.Equals, default(string), default(string)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.FastGetMethod(string.Equals, default(string), default(string)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.FastGetMethod(decimal.Equals, default(decimal), default(decimal)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - - ReflectHelper.FastGetMethod(decimal.Equals, default(decimal), default(decimal)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + ReflectHelper.GetMethodDefinition(x => x.Equals(default(bool))), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(x)), - ReflectHelper.GetMethodDefinition(x => x.Equals(default(bool))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(string))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(char))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(sbyte))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(byte))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(short))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(ushort))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(int))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(uint))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(long))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(ulong))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(float))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(double))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(decimal))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(Guid))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(DateTime))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(DateTimeOffset))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(TimeSpan))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(bool))) + }; - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(string))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(char))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(sbyte))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(byte))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(short))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(ushort))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(int))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(uint))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(long))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(ulong))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(float))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(double))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(decimal))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(Guid))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(DateTime))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(DateTimeOffset))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(TimeSpan))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(bool))) - }; + public EqualsGenerator() + { + SupportedMethods = Methods; } public override bool AllowsNullableReturnType(MethodInfo method) => false; diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 41bc2547f31..457e04fcbdb 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Linq.Expressions; using NHibernate.Engine; +using NHibernate.Linq.Functions; using NHibernate.Param; using NHibernate.Persister.Collection; using NHibernate.Type; @@ -215,6 +216,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node) : node; } + if (EqualsGenerator.Methods.Contains(node.Method) || CompareGenerator.IsCompareMethod(node.Method)) + { + node = (MethodCallExpression) base.VisitMethodCall(node); + var left = Unwrap(node.Method.IsStatic ? node.Arguments[0] : node.Object); + var right = Unwrap(node.Method.IsStatic ? node.Arguments[1] : node.Arguments[0]); + AddRelatedExpression(node, left, right); + AddRelatedExpression(node, right, left); + + return node; + } + return base.VisitMethodCall(node); }