Skip to content

Fix OData string comparisons (lt, gt, ge, le) #2386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/NHibernate.Test/Async/Linq/ODataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ public async Task BasePropertyFilterAsync(string queryString, int expectedRows)
Assert.That(results, Has.Count.EqualTo(expectedRows));
}

//GH-2362
[TestCase("$filter=CustomerId le 'ANATR'", 2)]
[TestCase("$filter=startswith(CustomerId, 'ANATR')", 1)]
[TestCase("$filter=endswith(CustomerId, 'ANATR')", 1)]
[TestCase("$filter=indexof(CustomerId, 'ANATR') eq 0", 1)]
public async Task StringFilterAsync(string queryString, int expectedCount)
{
Assert.That(
await (ApplyFilter(session.Query<Customer>(), queryString).Cast<Customer>().ToListAsync()),
Has.Count.EqualTo(expectedCount));
}

private IQueryable ApplyFilter<T>(IQueryable<T> query, string queryString)
{
var context = new ODataQueryContext(CreatEdmModel(), typeof(T), null) { };
Expand Down
31 changes: 31 additions & 0 deletions src/NHibernate.Test/Async/Linq/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using log4net.Core;
using NHibernate.Engine.Query;
using NHibernate.Linq;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Linq.Functions;
using NUnit.Framework;

namespace NHibernate.Test.Linq
{
using System.Threading.Tasks;
using System.Threading;
[TestFixture]
public class WhereTestsAsync : LinqTestCase
{
Expand Down Expand Up @@ -430,6 +433,34 @@ public async Task UsersWithStringContainsAndNotNullNameHQLAsync()
Assert.That(users.Count, Is.EqualTo(1));
}

[Test]
public void StringComparisonParamEmitsWarningAsync()
{
Assert.Multiple(
async () =>
{
await (AssertStringComparisonWarningAsync(x => string.Compare(x.CustomerId, "ANATR", StringComparison.Ordinal) <= 0, 2));
await (AssertStringComparisonWarningAsync(x => x.CustomerId.StartsWith("ANATR", StringComparison.Ordinal), 1));
await (AssertStringComparisonWarningAsync(x => x.CustomerId.EndsWith("ANATR", StringComparison.Ordinal), 1));
await (AssertStringComparisonWarningAsync(x => x.CustomerId.IndexOf("ANATR", StringComparison.Ordinal) == 0, 1));
await (AssertStringComparisonWarningAsync(x => x.CustomerId.IndexOf("ANATR", 0, StringComparison.Ordinal) == 0, 1));
#if NETCOREAPP2_0
await (AssertStringComparisonWarningAsync(x => x.CustomerId.Replace("AN", "XX", StringComparison.Ordinal) == "XXATR", 1));
#endif
});
}

private async Task AssertStringComparisonWarningAsync(Expression<Func<Customer, bool>> whereParam, int expected, CancellationToken cancellationToken = default(CancellationToken))
{
using (var log = new LogSpy(typeof(BaseHqlGeneratorForMethod)))
{
var customers = await (session.Query<Customer>().Where(whereParam).ToListAsync(cancellationToken));

Assert.That(customers, Has.Count.EqualTo(expected), whereParam.ToString);
Assert.That(log.GetWholeLog(), Does.Contain($"parameter of type '{nameof(StringComparison)}' is ignored"), whereParam.ToString);
}
}

[Test]
public async Task UsersWithArrayContainsAsync()
{
Expand Down
12 changes: 12 additions & 0 deletions src/NHibernate.Test/Linq/ODataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ public void BasePropertyFilter(string queryString, int expectedRows)
Assert.That(results, Has.Count.EqualTo(expectedRows));
}

//GH-2362
[TestCase("$filter=CustomerId le 'ANATR'", 2)]
[TestCase("$filter=startswith(CustomerId, 'ANATR')", 1)]
[TestCase("$filter=endswith(CustomerId, 'ANATR')", 1)]
[TestCase("$filter=indexof(CustomerId, 'ANATR') eq 0", 1)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we do not register EndsWith, StartsWith or IndexOf overloads with StringComparison, does this means there is an inconsistency in OData implementation here? It would emit Compare calls with that argument, but it would omit it for these three other cases?

So I see three options:

  • only register what is currently required by OData, as does this PR currently.
  • also register other cases overloads in case OData changes on this, or in case another Linq generator library requires them.
  • reject this as an external issue, to be fixed on OData side.

@maca88, @hazzik

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this means there is an inconsistency in OData implementation here?

Yeap. Seems to be.

also register other cases overloads in case OData changes on this, or in case another Linq generator library requires them.

I would choose this one as I see no difference with call to default string.Compare. It's not corelated in any way to with how DB is configured to make string comparisons. In fact StringComparison might be used to make LINQ to objects queries behave the same as DB (as I already mentioned in case of default SQL Server case insensitive collation)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By just adding the missing overload we introduce another issue where we generate the same query when using StringComparison.XXX and StringComparison.XXXIgnoreCase, which would cause an unexpected result on a case sensitive collation. IMO we should at least try to mimic .NET behavior when using XXXIgnoreCase by using lower function on each side.

does this means there is an inconsistency in OData implementation here?

Yes, from the OData source code only StringComparison.Ordinal option is used for string.Compare, where for EndsWith, StartsWith and other methods the StringComparison overload is not used.

Copy link
Member

@fredericDelaporte fredericDelaporte May 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By just adding the missing overload we introduce another issue where we generate the same query when using StringComparison.XXX and StringComparison.XXXIgnoreCase, which would cause an unexpected result on a case sensitive collation.

The thing is, we do not try to support different StringComparison values, we just ignore them. The trouble is the same conversely, with case insensitive collations, when using a non ignoring case StringComparison. And for this converse case, there is no ".Net" side workaround.

The current state is, NHibernate bluntly states these overloads are not supported. That is fine for user generated expression. But when the user cannot control the generated expression, it is a blocker.

So I think accepting these overloads, while not actually supporting the StringComparison argument, is an acceptable trade-of. Maybe we should emit a NHibernate warning or information about the argument being ignored, though.

Copy link
Contributor

@maca88 maca88 May 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And for this converse case, there is no ".Net" side workaround

Some databases support the COLLATE clause (SqlLite, SqlServer, Oracle, PostgreSql, MySql) within the query, which could be used to do case sensitive/insensitive comparison. But that's a matter for a separate PR.

So I think accepting these overloads, while not actually supporting the StringComparison argument, is an acceptable trade-of

It is fine by me.

public void StringFilter(string queryString, int expectedCount)
{
Assert.That(
ApplyFilter(session.Query<Customer>(), queryString).Cast<Customer>().ToList(),
Has.Count.EqualTo(expectedCount));
}

private IQueryable ApplyFilter<T>(IQueryable<T> query, string queryString)
{
var context = new ODataQueryContext(CreatEdmModel(), typeof(T), null) { };
Expand Down
30 changes: 30 additions & 0 deletions src/NHibernate.Test/Linq/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using log4net.Core;
using NHibernate.Engine.Query;
using NHibernate.Linq;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Linq.Functions;
using NUnit.Framework;

namespace NHibernate.Test.Linq
Expand Down Expand Up @@ -419,6 +421,34 @@ public void UsersWithStringContainsAndNotNullNameHQL()
Assert.That(users.Count, Is.EqualTo(1));
}

[Test]
public void StringComparisonParamEmitsWarning()
{
Assert.Multiple(
() =>
{
AssertStringComparisonWarning(x => string.Compare(x.CustomerId, "ANATR", StringComparison.Ordinal) <= 0, 2);
AssertStringComparisonWarning(x => x.CustomerId.StartsWith("ANATR", StringComparison.Ordinal), 1);
AssertStringComparisonWarning(x => x.CustomerId.EndsWith("ANATR", StringComparison.Ordinal), 1);
AssertStringComparisonWarning(x => x.CustomerId.IndexOf("ANATR", StringComparison.Ordinal) == 0, 1);
AssertStringComparisonWarning(x => x.CustomerId.IndexOf("ANATR", 0, StringComparison.Ordinal) == 0, 1);
#if NETCOREAPP2_0
AssertStringComparisonWarning(x => x.CustomerId.Replace("AN", "XX", StringComparison.Ordinal) == "XXATR", 1);
#endif
});
}

private void AssertStringComparisonWarning(Expression<Func<Customer, bool>> whereParam, int expected)
{
using (var log = new LogSpy(typeof(BaseHqlGeneratorForMethod)))
{
var customers = session.Query<Customer>().Where(whereParam).ToList();

Assert.That(customers, Has.Count.EqualTo(expected), whereParam.ToString);
Assert.That(log.GetWholeLog(), Does.Contain($"parameter of type '{nameof(StringComparison)}' is ignored"), whereParam.ToString);
}
}

[Test]
public void UsersWithArrayContains()
{
Expand Down
27 changes: 26 additions & 1 deletion src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Hql.Ast;
Expand All @@ -9,10 +11,33 @@ namespace NHibernate.Linq.Functions
{
public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGeneratorForMethodExtended
{
protected static readonly INHibernateLogger Log = NHibernateLogger.For(typeof(BaseHqlGeneratorForMethod));

public IEnumerable<MethodInfo> SupportedMethods { get; protected set; }

public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor);

public virtual bool AllowsNullableReturnType(MethodInfo method) => true;

private protected static void LogIgnoredParameter(MethodInfo method, string paramType)
{
if (Log.IsWarnEnabled())
Log.Warn("Method parameter of type '{0}' is ignored when converting to hql the following method: {1}", paramType, method);
}

private protected static void LogIgnoredStringComparisonParameter(MethodInfo actualMethod, MethodInfo methodWithStringComparison)
{
if (actualMethod == methodWithStringComparison)
LogIgnoredParameter(actualMethod, nameof(StringComparison));
}

private protected bool LogIgnoredStringComparisonParameter(MethodInfo actualMethod, params MethodInfo[] methodsWithStringComparison)
{
if (!methodsWithStringComparison.Contains(actualMethod))
return false;

LogIgnoredParameter(actualMethod, nameof(StringComparison));
return true;
}
}
}
6 changes: 6 additions & 0 deletions src/NHibernate/Linq/Functions/CompareGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ namespace NHibernate.Linq.Functions
{
internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGenerator
{
private static readonly MethodInfo MethodWithComparer = ReflectHelper.FastGetMethod(string.Compare, default(string), default(string), default(StringComparison));

private static readonly HashSet<MethodInfo> ActingMethods = new HashSet<MethodInfo>
{
ReflectHelper.FastGetMethod(string.Compare, default(string), default(string)),
MethodWithComparer,
ReflectHelper.GetMethodDefinition<string>(s => s.CompareTo(s)),
ReflectHelper.GetMethodDefinition<char>(x => x.CompareTo(x)),

Expand Down Expand Up @@ -43,7 +46,10 @@ internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGe
internal static bool IsCompareMethod(MethodInfo methodInfo)
{
if (ActingMethods.Contains(methodInfo))
{
LogIgnoredStringComparisonParameter(methodInfo, MethodWithComparer);
return true;
}

// This is .Net 4 only, and in the System.Data.Services assembly, which we don't depend directly on.
return methodInfo != null && methodInfo.Name == "Compare" &&
Expand Down
53 changes: 41 additions & 12 deletions src/NHibernate/Linq/Functions/StringGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,18 @@ public override HqlTreeNode BuildHql(MemberInfo member, Expression expression, H

public class StartsWithGenerator : BaseHqlGeneratorForMethod
{
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null, default(StringComparison)));

public StartsWithGenerator()
{
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null)) };
SupportedMethods = new[] {ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null)), MethodWithComparer};
}

public override bool AllowsNullableReturnType(MethodInfo method) => false;

public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
return treeBuilder.Like(
visitor.Visit(targetObject).AsExpression(),
treeBuilder.Concat(
Expand All @@ -95,15 +98,18 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,

public class EndsWithGenerator : BaseHqlGeneratorForMethod
{
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null, default(StringComparison)));

public EndsWithGenerator()
{
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null)) };
SupportedMethods = new[] {ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null)), MethodWithComparer,};
}

public override bool AllowsNullableReturnType(MethodInfo method) => false;

public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
return treeBuilder.Like(
visitor.Visit(targetObject).AsExpression(),
treeBuilder.Concat(
Expand Down Expand Up @@ -210,20 +216,32 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,

public class IndexOfGenerator : BaseHqlGeneratorForMethod
{
private static readonly MethodInfo MethodWithComparer1 = ReflectHelper.GetMethodDefinition<string>(x => x.IndexOf(string.Empty, default(StringComparison)));
private static readonly MethodInfo MethodWithComparer2 = ReflectHelper.GetMethodDefinition<string>(x => x.IndexOf(string.Empty, 0, default(StringComparison)));

public IndexOfGenerator()
{
SupportedMethods = new[]
{
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(' ')),
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ")),
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(' ', 0)),
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ", 0))
ReflectHelper.GetMethodDefinition<string>(s => s.IndexOf(" ", 0)),
MethodWithComparer1,
MethodWithComparer2,
};
}
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
var argsCount = arguments.Count;
if (LogIgnoredStringComparisonParameter(method, MethodWithComparer1, MethodWithComparer2))
{
//StringComparison is last argument, just ignore it
argsCount--;
}

HqlMethodCall locate;
if (arguments.Count == 1)
if (argsCount == 1)
{
locate = treeBuilder.MethodCall("locate",
visitor.Visit(arguments[0]).AsExpression(),
Expand All @@ -244,21 +262,32 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,

public class ReplaceGenerator : BaseHqlGeneratorForMethod
{
#if NETCOREAPP2_0
private static readonly MethodInfo MethodWithComparer = ReflectHelper.GetMethodDefinition<string>(x => x.Replace(string.Empty, string.Empty, default(StringComparison)));
#endif

public ReplaceGenerator()
{
SupportedMethods = new[]
{
ReflectHelper.GetMethodDefinition<string>(s => s.Replace(' ', ' ')),
ReflectHelper.GetMethodDefinition<string>(s => s.Replace("", ""))
};
{
ReflectHelper.GetMethodDefinition<string>(s => s.Replace(' ', ' ')),
ReflectHelper.GetMethodDefinition<string>(s => s.Replace("", "")),
#if NETCOREAPP2_0
MethodWithComparer,
#endif
};
}

public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
return treeBuilder.MethodCall("replace",
visitor.Visit(targetObject).AsExpression(),
visitor.Visit(arguments[0]).AsExpression(),
visitor.Visit(arguments[1]).AsExpression());
#if NETCOREAPP2_0
LogIgnoredStringComparisonParameter(method, MethodWithComparer);
#endif
return treeBuilder.MethodCall(
"replace",
visitor.Visit(targetObject).AsExpression(),
visitor.Visit(arguments[0]).AsExpression(),
visitor.Visit(arguments[1]).AsExpression());
}
}

Expand Down