Skip to content

Commit 3f7ae00

Browse files
maca88bahusoid
authored andcommitted
Fix parameter detection for custom hql method generators (nhibernate#2793)
1 parent 44871d9 commit 3f7ae00

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
using System.Reflection;
1515
using System.Text.RegularExpressions;
1616
using NHibernate.Cfg;
17+
using NHibernate.DomainModel.Northwind.Entities;
1718
using NHibernate.Hql.Ast;
1819
using NHibernate.Linq.Functions;
1920
using NHibernate.Linq.Visitors;
@@ -33,6 +34,14 @@ protected override void Configure(NHibernate.Cfg.Configuration configuration)
3334
configuration.LinqToHqlGeneratorsRegistry<MyLinqToHqlGeneratorsRegistry>();
3435
}
3536

37+
[Test]
38+
public async Task CanUseObjectEqualsAsync()
39+
{
40+
var users = await (db.Users.Where(o => ((object) EnumStoredAsString.Medium).Equals(o.NullableEnum1)).ToListAsync());
41+
Assert.That(users.Count, Is.EqualTo(2));
42+
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
43+
}
44+
3645
[Test]
3746
public async Task CanUseMyCustomExtensionAsync()
3847
{

src/NHibernate.Test/Linq/CustomExtensionsExample.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Reflection;
55
using System.Text.RegularExpressions;
66
using NHibernate.Cfg;
7+
using NHibernate.DomainModel.Northwind.Entities;
78
using NHibernate.Hql.Ast;
89
using NHibernate.Linq.Functions;
910
using NHibernate.Linq.Visitors;
@@ -30,6 +31,7 @@ public MyLinqToHqlGeneratorsRegistry():base()
3031
{
3132
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)),
3233
new IsLikeGenerator());
34+
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator());
3335
}
3436
}
3537

@@ -48,6 +50,21 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
4850
}
4951
}
5052

53+
public class ObjectEqualsGenerator : BaseHqlGeneratorForMethod
54+
{
55+
public ObjectEqualsGenerator()
56+
{
57+
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => new object().Equals(null)) };
58+
}
59+
60+
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
61+
ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
62+
{
63+
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(),
64+
visitor.Visit(arguments[0]).AsExpression());
65+
}
66+
}
67+
5168
[TestFixture]
5269
public class CustomExtensionsExample : LinqTestCase
5370
{
@@ -56,6 +73,14 @@ protected override void Configure(NHibernate.Cfg.Configuration configuration)
5673
configuration.LinqToHqlGeneratorsRegistry<MyLinqToHqlGeneratorsRegistry>();
5774
}
5875

76+
[Test]
77+
public void CanUseObjectEquals()
78+
{
79+
var users = db.Users.Where(o => ((object) EnumStoredAsString.Medium).Equals(o.NullableEnum1)).ToList();
80+
Assert.That(users.Count, Is.EqualTo(2));
81+
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
82+
}
83+
5984
[Test]
6085
public void CanUseMyCustomExtension()
6186
{

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ private static IType GetParameterType(
156156
return candidateType;
157157
}
158158

159+
if (visitor.NotGuessableConstants.Contains(constantExpression))
160+
{
161+
return null;
162+
}
163+
159164
// No related MemberExpressions was found, guess the type by value or its type when null.
160165
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
161166
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
@@ -166,10 +171,13 @@ private static IType GetParameterType(
166171

167172
private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
168173
{
174+
private bool _hqlGenerator;
169175
private readonly bool _removeMappedAsCalls;
170176
private readonly System.Type _targetType;
171177
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
172178
private readonly ISessionFactoryImplementor _sessionFactory;
179+
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
180+
public readonly HashSet<ConstantExpression> NotGuessableConstants = new HashSet<ConstantExpression>();
173181
public readonly Dictionary<ConstantExpression, IType> ConstantExpressions =
174182
new Dictionary<ConstantExpression, IType>();
175183
public readonly Dictionary<NamedParameter, HashSet<ConstantExpression>> ParameterConstants =
@@ -187,6 +195,7 @@ public ConstantTypeLocatorVisitor(
187195
_targetType = targetType;
188196
_sessionFactory = sessionFactory;
189197
_parameters = parameters;
198+
_functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
190199
}
191200

192201
protected override Expression VisitBinary(BinaryExpression node)
@@ -257,6 +266,16 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
257266
return node;
258267
}
259268

269+
// For hql method generators we do not want to guess the parameter type here, let hql logic figure it out.
270+
if (_functionRegistry.TryGetGenerator(node.Method, out _))
271+
{
272+
var origHqlGenerator = _hqlGenerator;
273+
_hqlGenerator = true;
274+
var expression = base.VisitMethodCall(node);
275+
_hqlGenerator = origHqlGenerator;
276+
return expression;
277+
}
278+
260279
return base.VisitMethodCall(node);
261280
}
262281

@@ -267,6 +286,11 @@ protected override Expression VisitConstant(ConstantExpression node)
267286
return node;
268287
}
269288

289+
if (_hqlGenerator)
290+
{
291+
NotGuessableConstants.Add(node);
292+
}
293+
270294
RelatedExpressions.Add(node, new HashSet<Expression>());
271295
ConstantExpressions.Add(node, null);
272296
if (!ParameterConstants.TryGetValue(param, out var set))

0 commit comments

Comments
 (0)