Skip to content

Fix parameter detection when using custom hql functions #2964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
//------------------------------------------------------------------------------


using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
Expand Down Expand Up @@ -42,6 +42,14 @@ public async Task CanUseObjectEqualsAsync()
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
}

[Test(Description = "GH-2963")]
public async Task CanUseComparisonWithExtensionOnMappedPropertyAsync()
{
var time = DateTime.UtcNow.GetTime();
//using(new SqlLogSpy())
await (db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToListAsync());
}

[Test]
public async Task CanUseMyCustomExtensionAsync()
{
Expand Down
29 changes: 28 additions & 1 deletion src/NHibernate.Test/Linq/CustomExtensionsExample.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
Expand All @@ -23,6 +23,11 @@ public static bool IsLike(this string source, string pattern)

return Regex.IsMatch(source, pattern);
}

public static TimeSpan GetTime(this DateTime dateTime)
{
return dateTime.TimeOfDay;
}
}

public class MyLinqToHqlGeneratorsRegistry: DefaultLinqToHqlGeneratorsRegistry
Expand All @@ -32,6 +37,20 @@ public MyLinqToHqlGeneratorsRegistry():base()
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)),
new IsLikeGenerator());
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator());
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))), new GetTimeGenerator());
}
}

public class GetTimeGenerator : BaseHqlGeneratorForMethod
{
public GetTimeGenerator()
{
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.GetTime(default(DateTime))) };
}

public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
return treeBuilder.MethodCall("cast", visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Ident(NHibernateUtil.TimeAsTimeSpan.Name));
}
}

Expand Down Expand Up @@ -81,6 +100,14 @@ public void CanUseObjectEquals()
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
}

[Test(Description = "GH-2963")]
public void CanUseComparisonWithExtensionOnMappedProperty()
{
var time = DateTime.UtcNow.GetTime();
//using(new SqlLogSpy())
db.Users.Where(u => u.RegisteredAt.GetTime() > time).Select(u => u.Id).ToList();
}

[Test]
public void CanUseMyCustomExtension()
{
Expand Down
4 changes: 3 additions & 1 deletion src/NHibernate.Test/TypesTest/CharClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ public class CharClass
public int Id { get; set; }
public virtual char NormalChar { get; set; }
public virtual char? NullableChar { get; set; }
public virtual string AnsiString { get; set; }
public virtual char AnsiChar { get; set; }
}
}
}
3 changes: 3 additions & 0 deletions src/NHibernate.Test/TypesTest/CharClass.hbm.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@

<property name="NormalChar"/>
<property name="NullableChar"/>
<property name="AnsiString" type="AnsiString(15)"/>
<property name="AnsiChar" type="AnsiChar"/>

</class>
</hibernate-mapping>
31 changes: 30 additions & 1 deletion src/NHibernate.Test/TypesTest/CharClassFixture.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NUnit.Framework;
using System.Linq;

namespace NHibernate.Test.TypesTest
{
Expand Down Expand Up @@ -31,5 +32,33 @@ public void ReadWrite()
s.Flush();
}
}

[Test]
public void ParameterTypeForAnsiCharInLinq()
{
using (var logSpy = new SqlLogSpy())
using (var session = OpenSession())
{
var result = (from e in session.Query<CharClass>()
where e.AnsiChar == 'B'
select e).ToList();

Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
}
}

[Test]
public void ParameterTypeForCharInAnsiStringInLinq()
{
using (var logSpy = new SqlLogSpy())
using (var session = OpenSession())
{
var result = (from e in session.Query<CharClass>()
where e.AnsiString[0] == 'P'
select e).ToList();

Assert.That(logSpy.GetWholeLog(), Does.Contain("Type: AnsiString"));
}
}
}
}
}
19 changes: 19 additions & 0 deletions src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public partial class HqlSqlWalker

private readonly IDictionary<string, string> _tokenReplacements;
private readonly IDictionary<string, NamedParameter> _namedParameters;
private readonly IDictionary<IParameterSpecification, IType> _guessedParameterTypes = new Dictionary<IParameterSpecification, IType>();

private JoinType _impliedJoinType;

Expand Down Expand Up @@ -98,6 +99,21 @@ public override void ReportError(RecognitionException e)
_parseErrorHandler.ReportError(e);
}

internal IStatement Transform()
{
var tree = (IStatement) statement().Tree;
// Use the guessed type in case we weren't been able to detect the type
foreach (var parameter in _parameters)
{
if (parameter.ExpectedType == null && _guessedParameterTypes.TryGetValue(parameter, out var guessedType))
{
parameter.ExpectedType = guessedType;
}
}

return tree;
}

/*
protected override void Mismatch(IIntStream input, int ttype, BitSet follow)
{
Expand Down Expand Up @@ -1072,7 +1088,10 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode)
// Add the parameter type information so that we are able to calculate functions return types
// when the parameter is used as an argument.
if (namedParameter.IsGuessedType)
{
_guessedParameterTypes[paramSpec] = namedParameter.Type;
parameter.GuessedType = namedParameter.Type;
}
else
parameter.ExpectedType = namedParameter.Type;
}
Expand Down
2 changes: 1 addition & 1 deletion src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ public IStatement Translate()
try
{
// Transform the tree.
_resultAst = (IStatement) hqlSqlWalker.statement().Tree;
_resultAst = hqlSqlWalker.Transform();
}
finally
{
Expand Down
27 changes: 3 additions & 24 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,9 @@ private static IType GetParameterType(
return candidateType;
}

if (visitor.NotGuessableConstants.Contains(constantExpression) && constantExpression.Value != null)
{
tryProcessInHql = true;
}

// Leave hql logic to determine the type except when the value is a char. Hql logic detects a char as a string, which causes an exception
// when trying to set a string db parameter with a char value.
tryProcessInHql = !(constantExpression.Value is char);
// No related MemberExpressions was found, guess the type by value or its type when null.
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
Expand All @@ -174,13 +172,10 @@ private static IType GetParameterType(

private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
{
private bool _hqlGenerator;
private readonly bool _removeMappedAsCalls;
private readonly System.Type _targetType;
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
private readonly ISessionFactoryImplementor _sessionFactory;
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
public readonly HashSet<ConstantExpression> NotGuessableConstants = new HashSet<ConstantExpression>();
public readonly Dictionary<ConstantExpression, IType> ConstantExpressions =
new Dictionary<ConstantExpression, IType>();
public readonly Dictionary<NamedParameter, HashSet<ConstantExpression>> ParameterConstants =
Expand All @@ -198,7 +193,6 @@ public ConstantTypeLocatorVisitor(
_targetType = targetType;
_sessionFactory = sessionFactory;
_parameters = parameters;
_functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
}

protected override Expression VisitBinary(BinaryExpression node)
Expand Down Expand Up @@ -269,16 +263,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
return node;
}

// For hql method generators we do not want to guess the parameter type here, let hql logic figure it out.
if (_functionRegistry.TryGetGenerator(node.Method, out _))
{
var origHqlGenerator = _hqlGenerator;
_hqlGenerator = true;
var expression = base.VisitMethodCall(node);
_hqlGenerator = origHqlGenerator;
return expression;
}

return base.VisitMethodCall(node);
}

Expand All @@ -289,11 +273,6 @@ protected override Expression VisitConstant(ConstantExpression node)
return node;
}

if (_hqlGenerator)
{
NotGuessableConstants.Add(node);
}

RelatedExpressions.Add(node, new HashSet<Expression>());
ConstantExpressions.Add(node, null);
if (!ParameterConstants.TryGetValue(param, out var set))
Expand Down