Skip to content

Commit 46a18cc

Browse files
maca88fredericDelaporte
authored andcommitted
Fix parameter detection for Equals and ComapreTo methods for Linq provider
1 parent b131556 commit 46a18cc

File tree

5 files changed

+113
-43
lines changed

5 files changed

+113
-43
lines changed

src/NHibernate.Test/Async/NHSpecificTest/GH2437/Fixture.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ public async Task Get_DateCustomType_NullableDateValueEqualsAsync()
8585
}
8686
}
8787

88+
[Test]
89+
public async Task Get_DateCustomType_NullableDateValueEqualsMethodAsync()
90+
{
91+
using (var session = OpenSession())
92+
using (session.BeginTransaction())
93+
{
94+
var sessions = await (session.Query<UserSession>().Where(x => x.OpenDate.Value.Equals(DateTime.Now)).ToListAsync());
95+
96+
Assert.That(sessions, Has.Count.EqualTo(10));
97+
}
98+
}
99+
88100
[Test]
89101
public async Task Get_DateTimeCustomType_NullableDateValueEqualsAsync()
90102
{

src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,21 @@ public void EqualStringEnumTest()
8484
);
8585
}
8686

87+
[Test]
88+
public void EqualsMethodStringTest()
89+
{
90+
AssertResults(
91+
new Dictionary<string, Predicate<IType>>
92+
{
93+
{"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15}
94+
},
95+
db.Orders.Where(o => o.ShippingAddress.City.Equals("London")),
96+
db.Orders.Where(o => "London".Equals(o.ShippingAddress.City)),
97+
db.Orders.Where(o => string.Equals("London", o.ShippingAddress.City)),
98+
db.Orders.Where(o => string.Equals(o.ShippingAddress.City, "London"))
99+
);
100+
}
101+
87102
[Test]
88103
public void ContainsStringEnumTest()
89104
{
@@ -158,6 +173,22 @@ public void EqualStringTest()
158173
);
159174
}
160175

176+
[Test]
177+
public void CompareToStringTest()
178+
{
179+
AssertResults(
180+
new Dictionary<string, Predicate<IType>>
181+
{
182+
{"1", o => o is Int32Type},
183+
{"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15}
184+
},
185+
db.Orders.Where(o => o.ShippingAddress.City.CompareTo("London") > 1),
186+
db.Orders.Where(o => "London".CompareTo(o.ShippingAddress.City) > 1),
187+
db.Orders.Where(o => string.Compare("London", o.ShippingAddress.City) > 1),
188+
db.Orders.Where(o => string.Compare(o.ShippingAddress.City, "London") > 1)
189+
);
190+
}
191+
161192
[Test]
162193
public void EqualEntityTest()
163194
{

src/NHibernate.Test/NHSpecificTest/GH2437/Fixture.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ public void Get_DateCustomType_NullableDateValueEquals()
7373
}
7474
}
7575

76+
[Test]
77+
public void Get_DateCustomType_NullableDateValueEqualsMethod()
78+
{
79+
using (var session = OpenSession())
80+
using (session.BeginTransaction())
81+
{
82+
var sessions = session.Query<UserSession>().Where(x => x.OpenDate.Value.Equals(DateTime.Now)).ToList();
83+
84+
Assert.That(sessions, Has.Count.EqualTo(10));
85+
}
86+
}
87+
7688
[Test]
7789
public void Get_DateTimeCustomType_NullableDateValueEquals()
7890
{

src/NHibernate/Linq/Functions/EqualsGenerator.cs

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Collections.ObjectModel;
34
using System.Linq.Expressions;
45
using System.Reflection;
@@ -10,57 +11,59 @@ namespace NHibernate.Linq.Functions
1011
{
1112
public class EqualsGenerator : BaseHqlGeneratorForMethod
1213
{
13-
public EqualsGenerator()
14+
internal static HashSet<MethodInfo> Methods = new HashSet<MethodInfo>
1415
{
15-
SupportedMethods = new[]
16-
{
17-
ReflectHelper.FastGetMethod(string.Equals, default(string), default(string)),
18-
ReflectHelper.GetMethodDefinition<string>(x => x.Equals(x)),
19-
ReflectHelper.GetMethodDefinition<char>(x => x.Equals(x)),
16+
ReflectHelper.FastGetMethod(string.Equals, default(string), default(string)),
17+
ReflectHelper.GetMethodDefinition<string>(x => x.Equals(x)),
18+
ReflectHelper.GetMethodDefinition<char>(x => x.Equals(x)),
19+
20+
ReflectHelper.GetMethodDefinition<sbyte>(x => x.Equals(x)),
21+
ReflectHelper.GetMethodDefinition<byte>(x => x.Equals(x)),
22+
23+
ReflectHelper.GetMethodDefinition<short>(x => x.Equals(x)),
24+
ReflectHelper.GetMethodDefinition<ushort>(x => x.Equals(x)),
2025

21-
ReflectHelper.GetMethodDefinition<sbyte>(x => x.Equals(x)),
22-
ReflectHelper.GetMethodDefinition<byte>(x => x.Equals(x)),
26+
ReflectHelper.GetMethodDefinition<int>(x => x.Equals(x)),
27+
ReflectHelper.GetMethodDefinition<uint>(x => x.Equals(x)),
2328

24-
ReflectHelper.GetMethodDefinition<short>(x => x.Equals(x)),
25-
ReflectHelper.GetMethodDefinition<ushort>(x => x.Equals(x)),
29+
ReflectHelper.GetMethodDefinition<long>(x => x.Equals(x)),
30+
ReflectHelper.GetMethodDefinition<ulong>(x => x.Equals(x)),
2631

27-
ReflectHelper.GetMethodDefinition<int>(x => x.Equals(x)),
28-
ReflectHelper.GetMethodDefinition<uint>(x => x.Equals(x)),
32+
ReflectHelper.GetMethodDefinition<float>(x => x.Equals(x)),
33+
ReflectHelper.GetMethodDefinition<double>(x => x.Equals(x)),
2934

30-
ReflectHelper.GetMethodDefinition<long>(x => x.Equals(x)),
31-
ReflectHelper.GetMethodDefinition<ulong>(x => x.Equals(x)),
35+
ReflectHelper.FastGetMethod(decimal.Equals, default(decimal), default(decimal)),
36+
ReflectHelper.GetMethodDefinition<decimal>(x => x.Equals(x)),
3237

33-
ReflectHelper.GetMethodDefinition<float>(x => x.Equals(x)),
34-
ReflectHelper.GetMethodDefinition<double>(x => x.Equals(x)),
35-
36-
ReflectHelper.FastGetMethod(decimal.Equals, default(decimal), default(decimal)),
37-
ReflectHelper.GetMethodDefinition<decimal>(x => x.Equals(x)),
38+
ReflectHelper.GetMethodDefinition<Guid>(x => x.Equals(x)),
39+
ReflectHelper.GetMethodDefinition<DateTime>(x => x.Equals(x)),
40+
ReflectHelper.GetMethodDefinition<DateTimeOffset>(x => x.Equals(x)),
41+
ReflectHelper.GetMethodDefinition<TimeSpan>(x => x.Equals(x)),
42+
ReflectHelper.GetMethodDefinition<bool>(x => x.Equals(default(bool))),
3843

39-
ReflectHelper.GetMethodDefinition<Guid>(x => x.Equals(x)),
40-
ReflectHelper.GetMethodDefinition<DateTime>(x => x.Equals(x)),
41-
ReflectHelper.GetMethodDefinition<DateTimeOffset>(x => x.Equals(x)),
42-
ReflectHelper.GetMethodDefinition<TimeSpan>(x => x.Equals(x)),
43-
ReflectHelper.GetMethodDefinition<bool>(x => x.Equals(default(bool))),
44+
ReflectHelper.GetMethodDefinition<IEquatable<string>>(x => x.Equals(default(string))),
45+
ReflectHelper.GetMethodDefinition<IEquatable<char>>(x => x.Equals(default(char))),
46+
ReflectHelper.GetMethodDefinition<IEquatable<sbyte>>(x => x.Equals(default(sbyte))),
47+
ReflectHelper.GetMethodDefinition<IEquatable<byte>>(x => x.Equals(default(byte))),
48+
ReflectHelper.GetMethodDefinition<IEquatable<short>>(x => x.Equals(default(short))),
49+
ReflectHelper.GetMethodDefinition<IEquatable<ushort>>(x => x.Equals(default(ushort))),
50+
ReflectHelper.GetMethodDefinition<IEquatable<int>>(x => x.Equals(default(int))),
51+
ReflectHelper.GetMethodDefinition<IEquatable<uint>>(x => x.Equals(default(uint))),
52+
ReflectHelper.GetMethodDefinition<IEquatable<long>>(x => x.Equals(default(long))),
53+
ReflectHelper.GetMethodDefinition<IEquatable<ulong>>(x => x.Equals(default(ulong))),
54+
ReflectHelper.GetMethodDefinition<IEquatable<float>>(x => x.Equals(default(float))),
55+
ReflectHelper.GetMethodDefinition<IEquatable<double>>(x => x.Equals(default(double))),
56+
ReflectHelper.GetMethodDefinition<IEquatable<decimal>>(x => x.Equals(default(decimal))),
57+
ReflectHelper.GetMethodDefinition<IEquatable<Guid>>(x => x.Equals(default(Guid))),
58+
ReflectHelper.GetMethodDefinition<IEquatable<DateTime>>(x => x.Equals(default(DateTime))),
59+
ReflectHelper.GetMethodDefinition<IEquatable<DateTimeOffset>>(x => x.Equals(default(DateTimeOffset))),
60+
ReflectHelper.GetMethodDefinition<IEquatable<TimeSpan>>(x => x.Equals(default(TimeSpan))),
61+
ReflectHelper.GetMethodDefinition<IEquatable<bool>>(x => x.Equals(default(bool)))
62+
};
4463

45-
ReflectHelper.GetMethodDefinition<IEquatable<string>>(x => x.Equals(default(string))),
46-
ReflectHelper.GetMethodDefinition<IEquatable<char>>(x => x.Equals(default(char))),
47-
ReflectHelper.GetMethodDefinition<IEquatable<sbyte>>(x => x.Equals(default(sbyte))),
48-
ReflectHelper.GetMethodDefinition<IEquatable<byte>>(x => x.Equals(default(byte))),
49-
ReflectHelper.GetMethodDefinition<IEquatable<short>>(x => x.Equals(default(short))),
50-
ReflectHelper.GetMethodDefinition<IEquatable<ushort>>(x => x.Equals(default(ushort))),
51-
ReflectHelper.GetMethodDefinition<IEquatable<int>>(x => x.Equals(default(int))),
52-
ReflectHelper.GetMethodDefinition<IEquatable<uint>>(x => x.Equals(default(uint))),
53-
ReflectHelper.GetMethodDefinition<IEquatable<long>>(x => x.Equals(default(long))),
54-
ReflectHelper.GetMethodDefinition<IEquatable<ulong>>(x => x.Equals(default(ulong))),
55-
ReflectHelper.GetMethodDefinition<IEquatable<float>>(x => x.Equals(default(float))),
56-
ReflectHelper.GetMethodDefinition<IEquatable<double>>(x => x.Equals(default(double))),
57-
ReflectHelper.GetMethodDefinition<IEquatable<decimal>>(x => x.Equals(default(decimal))),
58-
ReflectHelper.GetMethodDefinition<IEquatable<Guid>>(x => x.Equals(default(Guid))),
59-
ReflectHelper.GetMethodDefinition<IEquatable<DateTime>>(x => x.Equals(default(DateTime))),
60-
ReflectHelper.GetMethodDefinition<IEquatable<DateTimeOffset>>(x => x.Equals(default(DateTimeOffset))),
61-
ReflectHelper.GetMethodDefinition<IEquatable<TimeSpan>>(x => x.Equals(default(TimeSpan))),
62-
ReflectHelper.GetMethodDefinition<IEquatable<bool>>(x => x.Equals(default(bool)))
63-
};
64+
public EqualsGenerator()
65+
{
66+
SupportedMethods = Methods;
6467
}
6568

6669
public override bool AllowsNullableReturnType(MethodInfo method) => false;

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Linq;
44
using System.Linq.Expressions;
55
using NHibernate.Engine;
6+
using NHibernate.Linq.Functions;
67
using NHibernate.Param;
78
using NHibernate.Persister.Collection;
89
using NHibernate.Type;
@@ -215,6 +216,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
215216
: node;
216217
}
217218

219+
if (EqualsGenerator.Methods.Contains(node.Method) || CompareGenerator.IsCompareMethod(node.Method))
220+
{
221+
node = (MethodCallExpression) base.VisitMethodCall(node);
222+
var left = Unwrap(node.Method.IsStatic ? node.Arguments[0] : node.Object);
223+
var right = Unwrap(node.Method.IsStatic ? node.Arguments[1] : node.Arguments[0]);
224+
AddRelatedExpression(node, left, right);
225+
AddRelatedExpression(node, right, left);
226+
227+
return node;
228+
}
229+
218230
return base.VisitMethodCall(node);
219231
}
220232

0 commit comments

Comments
 (0)