Skip to content

Commit 8e98169

Browse files
authored
Fix OData string comparisons (lt, gt, ge, le) (#2386)
1 parent da13b9f commit 8e98169

File tree

7 files changed

+158
-13
lines changed

7 files changed

+158
-13
lines changed

src/NHibernate.Test/Async/Linq/ODataTests.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ public async Task BasePropertyFilterAsync(string queryString, int expectedRows)
8787
Assert.That(results, Has.Count.EqualTo(expectedRows));
8888
}
8989

90+
//GH-2362
91+
[TestCase("$filter=CustomerId le 'ANATR'", 2)]
92+
[TestCase("$filter=startswith(CustomerId, 'ANATR')", 1)]
93+
[TestCase("$filter=endswith(CustomerId, 'ANATR')", 1)]
94+
[TestCase("$filter=indexof(CustomerId, 'ANATR') eq 0", 1)]
95+
public async Task StringFilterAsync(string queryString, int expectedCount)
96+
{
97+
Assert.That(
98+
await (ApplyFilter(session.Query<Customer>(), queryString).Cast<Customer>().ToListAsync()),
99+
Has.Count.EqualTo(expectedCount));
100+
}
101+
90102
private IQueryable ApplyFilter<T>(IQueryable<T> query, string queryString)
91103
{
92104
var context = new ODataQueryContext(CreatEdmModel(), typeof(T), null) { };

src/NHibernate.Test/Async/Linq/WhereTests.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
using System.Collections.ObjectModel;
1515
using System.Linq;
1616
using System.Linq.Expressions;
17+
using log4net.Core;
1718
using NHibernate.Engine.Query;
1819
using NHibernate.Linq;
1920
using NHibernate.DomainModel.Northwind.Entities;
21+
using NHibernate.Linq.Functions;
2022
using NUnit.Framework;
2123

2224
namespace NHibernate.Test.Linq
2325
{
2426
using System.Threading.Tasks;
27+
using System.Threading;
2528
[TestFixture]
2629
public class WhereTestsAsync : LinqTestCase
2730
{
@@ -430,6 +433,34 @@ public async Task UsersWithStringContainsAndNotNullNameHQLAsync()
430433
Assert.That(users.Count, Is.EqualTo(1));
431434
}
432435

436+
[Test]
437+
public void StringComparisonParamEmitsWarningAsync()
438+
{
439+
Assert.Multiple(
440+
async () =>
441+
{
442+
await (AssertStringComparisonWarningAsync(x => string.Compare(x.CustomerId, "ANATR", StringComparison.Ordinal) <= 0, 2));
443+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.StartsWith("ANATR", StringComparison.Ordinal), 1));
444+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.EndsWith("ANATR", StringComparison.Ordinal), 1));
445+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.IndexOf("ANATR", StringComparison.Ordinal) == 0, 1));
446+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.IndexOf("ANATR", 0, StringComparison.Ordinal) == 0, 1));
447+
#if NETCOREAPP2_0
448+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.Replace("AN", "XX", StringComparison.Ordinal) == "XXATR", 1));
449+
#endif
450+
});
451+
}
452+
453+
private async Task AssertStringComparisonWarningAsync(Expression<Func<Customer, bool>> whereParam, int expected, CancellationToken cancellationToken = default(CancellationToken))
454+
{
455+
using (var log = new LogSpy(typeof(BaseHqlGeneratorForMethod)))
456+
{
457+
var customers = await (session.Query<Customer>().Where(whereParam).ToListAsync(cancellationToken));
458+
459+
Assert.That(customers, Has.Count.EqualTo(expected), whereParam.ToString);
460+
Assert.That(log.GetWholeLog(), Does.Contain($"parameter of type '{nameof(StringComparison)}' is ignored"), whereParam.ToString);
461+
}
462+
}
463+
433464
[Test]
434465
public async Task UsersWithArrayContainsAsync()
435466
{

src/NHibernate.Test/Linq/ODataTests.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ public void BasePropertyFilter(string queryString, int expectedRows)
7575
Assert.That(results, Has.Count.EqualTo(expectedRows));
7676
}
7777

78+
//GH-2362
79+
[TestCase("$filter=CustomerId le 'ANATR'", 2)]
80+
[TestCase("$filter=startswith(CustomerId, 'ANATR')", 1)]
81+
[TestCase("$filter=endswith(CustomerId, 'ANATR')", 1)]
82+
[TestCase("$filter=indexof(CustomerId, 'ANATR') eq 0", 1)]
83+
public void StringFilter(string queryString, int expectedCount)
84+
{
85+
Assert.That(
86+
ApplyFilter(session.Query<Customer>(), queryString).Cast<Customer>().ToList(),
87+
Has.Count.EqualTo(expectedCount));
88+
}
89+
7890
private IQueryable ApplyFilter<T>(IQueryable<T> query, string queryString)
7991
{
8092
var context = new ODataQueryContext(CreatEdmModel(), typeof(T), null) { };

src/NHibernate.Test/Linq/WhereTests.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
using System.Collections.ObjectModel;
55
using System.Linq;
66
using System.Linq.Expressions;
7+
using log4net.Core;
78
using NHibernate.Engine.Query;
89
using NHibernate.Linq;
910
using NHibernate.DomainModel.Northwind.Entities;
11+
using NHibernate.Linq.Functions;
1012
using NUnit.Framework;
1113

1214
namespace NHibernate.Test.Linq
@@ -419,6 +421,34 @@ public void UsersWithStringContainsAndNotNullNameHQL()
419421
Assert.That(users.Count, Is.EqualTo(1));
420422
}
421423

424+
[Test]
425+
public void StringComparisonParamEmitsWarning()
426+
{
427+
Assert.Multiple(
428+
() =>
429+
{
430+
AssertStringComparisonWarning(x => string.Compare(x.CustomerId, "ANATR", StringComparison.Ordinal) <= 0, 2);
431+
AssertStringComparisonWarning(x => x.CustomerId.StartsWith("ANATR", StringComparison.Ordinal), 1);
432+
AssertStringComparisonWarning(x => x.CustomerId.EndsWith("ANATR", StringComparison.Ordinal), 1);
433+
AssertStringComparisonWarning(x => x.CustomerId.IndexOf("ANATR", StringComparison.Ordinal) == 0, 1);
434+
AssertStringComparisonWarning(x => x.CustomerId.IndexOf("ANATR", 0, StringComparison.Ordinal) == 0, 1);
435+
#if NETCOREAPP2_0
436+
AssertStringComparisonWarning(x => x.CustomerId.Replace("AN", "XX", StringComparison.Ordinal) == "XXATR", 1);
437+
#endif
438+
});
439+
}
440+
441+
private void AssertStringComparisonWarning(Expression<Func<Customer, bool>> whereParam, int expected)
442+
{
443+
using (var log = new LogSpy(typeof(BaseHqlGeneratorForMethod)))
444+
{
445+
var customers = session.Query<Customer>().Where(whereParam).ToList();
446+
447+
Assert.That(customers, Has.Count.EqualTo(expected), whereParam.ToString);
448+
Assert.That(log.GetWholeLog(), Does.Contain($"parameter of type '{nameof(StringComparison)}' is ignored"), whereParam.ToString);
449+
}
450+
}
451+
422452
[Test]
423453
public void UsersWithArrayContains()
424454
{

src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using System.Collections.Generic;
1+
using System;
2+
using System.Collections.Generic;
23
using System.Collections.ObjectModel;
4+
using System.Linq;
35
using System.Linq.Expressions;
46
using System.Reflection;
57
using NHibernate.Hql.Ast;
@@ -9,10 +11,33 @@ namespace NHibernate.Linq.Functions
911
{
1012
public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGeneratorForMethodExtended
1113
{
14+
protected static readonly INHibernateLogger Log = NHibernateLogger.For(typeof(BaseHqlGeneratorForMethod));
15+
1216
public IEnumerable<MethodInfo> SupportedMethods { get; protected set; }
1317

1418
public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor);
1519

1620
public virtual bool AllowsNullableReturnType(MethodInfo method) => true;
21+
22+
private protected static void LogIgnoredParameter(MethodInfo method, string paramType)
23+
{
24+
if (Log.IsWarnEnabled())
25+
Log.Warn("Method parameter of type '{0}' is ignored when converting to hql the following method: {1}", paramType, method);
26+
}
27+
28+
private protected static void LogIgnoredStringComparisonParameter(MethodInfo actualMethod, MethodInfo methodWithStringComparison)
29+
{
30+
if (actualMethod == methodWithStringComparison)
31+
LogIgnoredParameter(actualMethod, nameof(StringComparison));
32+
}
33+
34+
private protected bool LogIgnoredStringComparisonParameter(MethodInfo actualMethod, params MethodInfo[] methodsWithStringComparison)
35+
{
36+
if (!methodsWithStringComparison.Contains(actualMethod))
37+
return false;
38+
39+
LogIgnoredParameter(actualMethod, nameof(StringComparison));
40+
return true;
41+
}
1742
}
1843
}

src/NHibernate/Linq/Functions/CompareGenerator.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ namespace NHibernate.Linq.Functions
1212
{
1313
internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGenerator
1414
{
15+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.FastGetMethod(string.Compare, default(string), default(string), default(StringComparison));
16+
1517
private static readonly HashSet<MethodInfo> ActingMethods = new HashSet<MethodInfo>
1618
{
1719
ReflectHelper.FastGetMethod(string.Compare, default(string), default(string)),
20+
MethodWithComparer,
1821
ReflectHelper.GetMethodDefinition<string>(s => s.CompareTo(s)),
1922
ReflectHelper.GetMethodDefinition<char>(x => x.CompareTo(x)),
2023

@@ -43,7 +46,10 @@ internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGe
4346
internal static bool IsCompareMethod(MethodInfo methodInfo)
4447
{
4548
if (ActingMethods.Contains(methodInfo))
49+
{
50+
LogIgnoredStringComparisonParameter(methodInfo, MethodWithComparer);
4651
return true;
52+
}
4753

4854
// This is .Net 4 only, and in the System.Data.Services assembly, which we don't depend directly on.
4955
return methodInfo != null && methodInfo.Name == "Compare" &&

src/NHibernate/Linq/Functions/StringGenerator.cs

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,18 @@ public override HqlTreeNode BuildHql(MemberInfo member, Expression expression, H
7676

7777
public class StartsWithGenerator : BaseHqlGeneratorForMethod
7878
{
79+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null, default(StringComparison)));
80+
7981
public StartsWithGenerator()
8082
{
81-
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null)) };
83+
SupportedMethods = new[] {ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null)), MethodWithComparer};
8284
}
8385

8486
public override bool AllowsNullableReturnType(MethodInfo method) => false;
8587

8688
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
8789
{
90+
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
8891
return treeBuilder.Like(
8992
visitor.Visit(targetObject).AsExpression(),
9093
treeBuilder.Concat(
@@ -95,15 +98,18 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
9598

9699
public class EndsWithGenerator : BaseHqlGeneratorForMethod
97100
{
101+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null, default(StringComparison)));
102+
98103
public EndsWithGenerator()
99104
{
100-
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null)) };
105+
SupportedMethods = new[] {ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null)), MethodWithComparer,};
101106
}
102107

103108
public override bool AllowsNullableReturnType(MethodInfo method) => false;
104109

105110
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
106111
{
112+
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
107113
return treeBuilder.Like(
108114
visitor.Visit(targetObject).AsExpression(),
109115
treeBuilder.Concat(
@@ -210,20 +216,32 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
210216

211217
public class IndexOfGenerator : BaseHqlGeneratorForMethod
212218
{
219+
private static readonly MethodInfo MethodWithComparer1 = ReflectHelper.GetMethodDefinition<string>(x => x.IndexOf(string.Empty, default(StringComparison)));
220+
private static readonly MethodInfo MethodWithComparer2 = ReflectHelper.GetMethodDefinition<string>(x => x.IndexOf(string.Empty, 0, default(StringComparison)));
221+
213222
public IndexOfGenerator()
214223
{
215224
SupportedMethods = new[]
216225
{
217226
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(' ')),
218227
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ")),
219228
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(' ', 0)),
220-
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ", 0))
229+
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ", 0)),
230+
MethodWithComparer1,
231+
MethodWithComparer2,
221232
};
222233
}
223234
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
224235
{
236+
var argsCount = arguments.Count;
237+
if (LogIgnoredStringComparisonParameter(method, MethodWithComparer1, MethodWithComparer2))
238+
{
239+
//StringComparison is last argument, just ignore it
240+
argsCount--;
241+
}
242+
225243
HqlMethodCall locate;
226-
if (arguments.Count == 1)
244+
if (argsCount == 1)
227245
{
228246
locate = treeBuilder.MethodCall("locate",
229247
visitor.Visit(arguments[0]).AsExpression(),
@@ -244,21 +262,32 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
244262

245263
public class ReplaceGenerator : BaseHqlGeneratorForMethod
246264
{
265+
#if NETCOREAPP2_0
266+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.Replace(string.Empty, string.Empty, default(StringComparison)));
267+
#endif
268+
247269
public ReplaceGenerator()
248270
{
249271
SupportedMethods = new[]
250-
{
251-
ReflectHelper.GetMethodDefinition<string>(s => s.Replace(' ', ' ')),
252-
ReflectHelper.GetMethodDefinition<string>(s => s.Replace("", ""))
253-
};
272+
{
273+
ReflectHelper.GetMethodDefinition<string>(s => s.Replace(' ', ' ')),
274+
ReflectHelper.GetMethodDefinition<string>(s => s.Replace("", "")),
275+
#if NETCOREAPP2_0
276+
MethodWithComparer,
277+
#endif
278+
};
254279
}
255280

256281
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
257282
{
258-
return treeBuilder.MethodCall("replace",
259-
visitor.Visit(targetObject).AsExpression(),
260-
visitor.Visit(arguments[0]).AsExpression(),
261-
visitor.Visit(arguments[1]).AsExpression());
283+
#if NETCOREAPP2_0
284+
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
285+
#endif
286+
return treeBuilder.MethodCall(
287+
"replace",
288+
visitor.Visit(targetObject).AsExpression(),
289+
visitor.Visit(arguments[0]).AsExpression(),
290+
visitor.Visit(arguments[1]).AsExpression());
262291
}
263292
}
264293

0 commit comments

Comments
 (0)