Skip to content

Commit 5639b0b

Browse files
committed
Support with warn StringComparison in StartsWith, EndsWith, IndexOf, Replace
1 parent 6a15c46 commit 5639b0b

File tree

7 files changed

+121
-10
lines changed

7 files changed

+121
-10
lines changed

src/NHibernate.Test/Async/Linq/ODataTests.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ public async Task BasePropertyFilterAsync(string queryString, int expectedRows)
8787
Assert.That(results, Has.Count.EqualTo(expectedRows));
8888
}
8989

90+
//GH-2362
9091
[TestCase("$filter=CustomerId le 'ANATR'",2 )]
9192
[TestCase("$filter=startswith(CustomerId, 'ANATR')", 1)]
9293
[TestCase("$filter=endswith(CustomerId, 'ANATR')", 1)]

src/NHibernate.Test/Async/Linq/WhereTests.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
using System.Collections.ObjectModel;
1515
using System.Linq;
1616
using System.Linq.Expressions;
17+
using log4net.Core;
1718
using NHibernate.Engine.Query;
1819
using NHibernate.Linq;
1920
using NHibernate.DomainModel.Northwind.Entities;
21+
using NHibernate.Linq.Functions;
2022
using NUnit.Framework;
2123

2224
namespace NHibernate.Test.Linq
2325
{
2426
using System.Threading.Tasks;
27+
using System.Threading;
2528
[TestFixture]
2629
public class WhereTestsAsync : LinqTestCase
2730
{
@@ -430,6 +433,32 @@ public async Task UsersWithStringContainsAndNotNullNameHQLAsync()
430433
Assert.That(users.Count, Is.EqualTo(1));
431434
}
432435

436+
[Test()]
437+
public void StringComparisonParamEmitsWarningAsync()
438+
{
439+
Assert.Multiple(
440+
async () =>
441+
{
442+
await (AssertStringComparisonWarningAsync(x => string.Compare(x.CustomerId, "ANATR", StringComparison.Ordinal) <= 0, 2));
443+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.StartsWith("ANATR", StringComparison.Ordinal), 1));
444+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.EndsWith("ANATR", StringComparison.Ordinal), 1));
445+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.IndexOf("ANATR", StringComparison.Ordinal) == 0, 1));
446+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.IndexOf("ANATR", 0, StringComparison.Ordinal) == 0, 1));
447+
await (AssertStringComparisonWarningAsync(x => x.CustomerId.Replace("AN", "XX", StringComparison.Ordinal) == "XXATR", 1));
448+
});
449+
}
450+
451+
private async Task AssertStringComparisonWarningAsync(Expression<Func<Customer,bool>> whereParam, int expected, CancellationToken cancellationToken = default(CancellationToken))
452+
{
453+
using (var log = new LogSpy(typeof(BaseHqlGeneratorForMethod)))
454+
{
455+
var customers = await (session.Query<Customer>().Where(whereParam).ToListAsync(cancellationToken));
456+
457+
Assert.That(customers, Has.Count.EqualTo(expected), whereParam.ToString);
458+
Assert.That(log.GetWholeLog(), Does.Contain($"parameter of type '{nameof(StringComparison)}' is ignored"), whereParam.ToString);
459+
}
460+
}
461+
433462
[Test]
434463
public async Task UsersWithArrayContainsAsync()
435464
{

src/NHibernate.Test/Linq/ODataTests.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ public void BasePropertyFilter(string queryString, int expectedRows)
7575
Assert.That(results, Has.Count.EqualTo(expectedRows));
7676
}
7777

78+
//GH-2362
7879
[TestCase("$filter=CustomerId le 'ANATR'",2 )]
7980
[TestCase("$filter=startswith(CustomerId, 'ANATR')", 1)]
8081
[TestCase("$filter=endswith(CustomerId, 'ANATR')", 1)]

src/NHibernate.Test/Linq/WhereTests.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
using System.Collections.ObjectModel;
55
using System.Linq;
66
using System.Linq.Expressions;
7+
using log4net.Core;
78
using NHibernate.Engine.Query;
89
using NHibernate.Linq;
910
using NHibernate.DomainModel.Northwind.Entities;
11+
using NHibernate.Linq.Functions;
1012
using NUnit.Framework;
1113

1214
namespace NHibernate.Test.Linq
@@ -419,6 +421,32 @@ public void UsersWithStringContainsAndNotNullNameHQL()
419421
Assert.That(users.Count, Is.EqualTo(1));
420422
}
421423

424+
[Test()]
425+
public void StringComparisonParamEmitsWarning()
426+
{
427+
Assert.Multiple(
428+
() =>
429+
{
430+
AssertStringComparisonWarning(x => string.Compare(x.CustomerId, "ANATR", StringComparison.Ordinal) <= 0, 2);
431+
AssertStringComparisonWarning(x => x.CustomerId.StartsWith("ANATR", StringComparison.Ordinal), 1);
432+
AssertStringComparisonWarning(x => x.CustomerId.EndsWith("ANATR", StringComparison.Ordinal), 1);
433+
AssertStringComparisonWarning(x => x.CustomerId.IndexOf("ANATR", StringComparison.Ordinal) == 0, 1);
434+
AssertStringComparisonWarning(x => x.CustomerId.IndexOf("ANATR", 0, StringComparison.Ordinal) == 0, 1);
435+
AssertStringComparisonWarning(x => x.CustomerId.Replace("AN", "XX", StringComparison.Ordinal) == "XXATR", 1);
436+
});
437+
}
438+
439+
private void AssertStringComparisonWarning(Expression<Func<Customer,bool>> whereParam, int expected)
440+
{
441+
using (var log = new LogSpy(typeof(BaseHqlGeneratorForMethod)))
442+
{
443+
var customers = session.Query<Customer>().Where(whereParam).ToList();
444+
445+
Assert.That(customers, Has.Count.EqualTo(expected), whereParam.ToString);
446+
Assert.That(log.GetWholeLog(), Does.Contain($"parameter of type '{nameof(StringComparison)}' is ignored"), whereParam.ToString);
447+
}
448+
}
449+
422450
[Test]
423451
public void UsersWithArrayContains()
424452
{

src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using System.Collections.Generic;
1+
using System;
2+
using System.Collections.Generic;
23
using System.Collections.ObjectModel;
4+
using System.Linq;
35
using System.Linq.Expressions;
46
using System.Reflection;
57
using NHibernate.Hql.Ast;
@@ -9,10 +11,33 @@ namespace NHibernate.Linq.Functions
911
{
1012
public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGeneratorForMethodExtended
1113
{
14+
protected static readonly INHibernateLogger Log = NHibernateLogger.For(typeof(BaseHqlGeneratorForMethod));
15+
1216
public IEnumerable<MethodInfo> SupportedMethods { get; protected set; }
1317

1418
public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor);
1519

1620
public virtual bool AllowsNullableReturnType(MethodInfo method) => true;
21+
22+
private protected static void LogIgnoredParameter(MethodInfo method, string paramType)
23+
{
24+
if(Log.IsWarnEnabled())
25+
Log.Warn("Method parameter of type '{0}' is ignored when converting to hql the following method: {1}", paramType, method);
26+
}
27+
28+
private protected static void LogIgnoredStringComparisonParameter(MethodInfo actualMethod, MethodInfo methodWithStringComparison)
29+
{
30+
if(actualMethod == methodWithStringComparison)
31+
LogIgnoredParameter(actualMethod, nameof(StringComparison));
32+
}
33+
34+
private protected bool LogIgnoredStringComparisonParameter(MethodInfo actualMethod, params MethodInfo[] methodsWithStringComparison)
35+
{
36+
if (!methodsWithStringComparison.Contains(actualMethod))
37+
return false;
38+
39+
LogIgnoredParameter(actualMethod, nameof(StringComparison));
40+
return true;
41+
}
1742
}
1843
}

src/NHibernate/Linq/Functions/CompareGenerator.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ namespace NHibernate.Linq.Functions
1212
{
1313
internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGenerator
1414
{
15+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.FastGetMethod(string.Compare, default(string), default(string), default(StringComparison));
16+
1517
private static readonly HashSet<MethodInfo> ActingMethods = new HashSet<MethodInfo>
1618
{
1719
ReflectHelper.FastGetMethod(string.Compare, default(string), default(string)),
18-
ReflectHelper.FastGetMethod(string.Compare, default(string), default(string), default(StringComparison)),
20+
MethodWithComparer,
1921
ReflectHelper.GetMethodDefinition<string>(s => s.CompareTo(s)),
2022
ReflectHelper.GetMethodDefinition<char>(x => x.CompareTo(x)),
2123

@@ -44,7 +46,10 @@ internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGe
4446
internal static bool IsCompareMethod(MethodInfo methodInfo)
4547
{
4648
if (ActingMethods.Contains(methodInfo))
49+
{
50+
LogIgnoredStringComparisonParameter(methodInfo, MethodWithComparer);
4751
return true;
52+
}
4853

4954
// This is .Net 4 only, and in the System.Data.Services assembly, which we don't depend directly on.
5055
return methodInfo != null && methodInfo.Name == "Compare" &&

src/NHibernate/Linq/Functions/StringGenerator.cs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Collections.Immutable;
34
using System.Collections.ObjectModel;
45
using System.Linq;
56
using System.Linq.Expressions;
@@ -76,15 +77,18 @@ public override HqlTreeNode BuildHql(MemberInfo member, Expression expression, H
7677

7778
public class StartsWithGenerator : BaseHqlGeneratorForMethod
7879
{
80+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null, default(StringComparison)));
81+
7982
public StartsWithGenerator()
8083
{
81-
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null)) };
84+
SupportedMethods = new[] {ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null)), MethodWithComparer};
8285
}
8386

8487
public override bool AllowsNullableReturnType(MethodInfo method) => false;
8588

8689
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
8790
{
91+
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
8892
return treeBuilder.Like(
8993
visitor.Visit(targetObject).AsExpression(),
9094
treeBuilder.Concat(
@@ -95,15 +99,18 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
9599

96100
public class EndsWithGenerator : BaseHqlGeneratorForMethod
97101
{
102+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null, default(StringComparison)));
103+
98104
public EndsWithGenerator()
99105
{
100-
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null)) };
106+
SupportedMethods = new[] {ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null)), MethodWithComparer,};
101107
}
102108

103109
public override bool AllowsNullableReturnType(MethodInfo method) => false;
104110

105111
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
106112
{
113+
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
107114
return treeBuilder.Like(
108115
visitor.Visit(targetObject).AsExpression(),
109116
treeBuilder.Concat(
@@ -210,18 +217,28 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
210217

211218
public class IndexOfGenerator : BaseHqlGeneratorForMethod
212219
{
220+
private static readonly MethodInfo MethodWithComparer1 = ReflectHelper.GetMethodDefinition<string>(x => x.IndexOf(string.Empty, default(StringComparison)));
221+
private static readonly MethodInfo MethodWithComparer2 = ReflectHelper.GetMethodDefinition<string>(x => x.IndexOf(string.Empty, 0, default(StringComparison)));
222+
213223
public IndexOfGenerator()
214224
{
215225
SupportedMethods = new[]
216226
{
217227
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(' ')),
218228
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ")),
219229
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(' ', 0)),
220-
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ", 0))
230+
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ", 0)),
231+
MethodWithComparer1,
232+
MethodWithComparer2,
221233
};
222234
}
223235
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
224236
{
237+
if (LogIgnoredStringComparisonParameter(method, MethodWithComparer1, MethodWithComparer2))
238+
{
239+
arguments = arguments.Where(a => a.Type != typeof(StringComparison)).ToList().AsReadOnly();
240+
}
241+
225242
HqlMethodCall locate;
226243
if (arguments.Count == 1)
227244
{
@@ -244,21 +261,26 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
244261

245262
public class ReplaceGenerator : BaseHqlGeneratorForMethod
246263
{
264+
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.Replace(string.Empty, string.Empty, default(StringComparison)));
265+
247266
public ReplaceGenerator()
248267
{
249268
SupportedMethods = new[]
250269
{
251270
ReflectHelper.GetMethodDefinition<string>(s => s.Replace(' ', ' ')),
252-
ReflectHelper.GetMethodDefinition<string>(s => s.Replace("", ""))
271+
ReflectHelper.GetMethodDefinition<string>(s => s.Replace("", "")),
272+
MethodWithComparer,
253273
};
254274
}
255275

256276
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
257277
{
258-
return treeBuilder.MethodCall("replace",
259-
visitor.Visit(targetObject).AsExpression(),
260-
visitor.Visit(arguments[0]).AsExpression(),
261-
visitor.Visit(arguments[1]).AsExpression());
278+
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
279+
return treeBuilder.MethodCall(
280+
"replace",
281+
visitor.Visit(targetObject).AsExpression(),
282+
visitor.Visit(arguments[0]).AsExpression(),
283+
visitor.Visit(arguments[1]).AsExpression());
262284
}
263285
}
264286

0 commit comments

Comments
 (0)