Skip to content

Commit 5c36581

Browse files
committed
Fix IEnumerable parameters outside Contains
1 parent 279151f commit 5c36581

File tree

7 files changed

+112
-54
lines changed

7 files changed

+112
-54
lines changed

src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22

33
namespace NHibernate.DomainModel.Northwind.Entities
44
{
5-
public class DynamicUser
5+
public class DynamicUser : IEnumerable
66
{
77
public virtual int Id { get; set; }
88

99
public virtual dynamic Properties { get; set; }
1010

1111
public virtual IDictionary Settings { get; set; }
12+
13+
public virtual IEnumerator GetEnumerator()
14+
{
15+
throw new System.NotImplementedException();
16+
}
1217
}
1318
}

src/NHibernate.Test/Async/Linq/ParameterTests.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,34 @@ public async Task UsingTwoEntityParametersAsync()
8888
2));
8989
}
9090

91+
[Test]
92+
public async Task UsingEntityEnumerableParameterTwiceAsync()
93+
{
94+
if (!Dialect.SupportsSubSelects)
95+
{
96+
Assert.Ignore();
97+
}
98+
99+
var enumerable = await (db.DynamicUsers.FirstAsync());
100+
await (AssertTotalParametersAsync(
101+
db.DynamicUsers.Where(o => o == enumerable && o != enumerable),
102+
1));
103+
}
104+
105+
[Test]
106+
public async Task UsingEntityEnumerableListParameterTwiceAsync()
107+
{
108+
if (!Dialect.SupportsSubSelects)
109+
{
110+
Assert.Ignore();
111+
}
112+
113+
var enumerable = new[] {await (db.DynamicUsers.FirstAsync())};
114+
await (AssertTotalParametersAsync(
115+
db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)),
116+
1));
117+
}
118+
91119
[Test]
92120
public async Task UsingValueTypeParameterTwiceAsync()
93121
{

src/NHibernate.Test/Linq/ParameterTests.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,34 @@ public void UsingTwoEntityParameters()
7676
2);
7777
}
7878

79+
[Test]
80+
public void UsingEntityEnumerableParameterTwice()
81+
{
82+
if (!Dialect.SupportsSubSelects)
83+
{
84+
Assert.Ignore();
85+
}
86+
87+
var enumerable = db.DynamicUsers.First();
88+
AssertTotalParameters(
89+
db.DynamicUsers.Where(o => o == enumerable && o != enumerable),
90+
1);
91+
}
92+
93+
[Test]
94+
public void UsingEntityEnumerableListParameterTwice()
95+
{
96+
if (!Dialect.SupportsSubSelects)
97+
{
98+
Assert.Ignore();
99+
}
100+
101+
var enumerable = new[] {db.DynamicUsers.First()};
102+
AssertTotalParameters(
103+
db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)),
104+
1);
105+
}
106+
79107
[Test]
80108
public void UsingValueTypeParameterTwice()
81109
{

src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,23 @@ public void EqualStringTest()
9797
);
9898
}
9999

100+
[Test]
101+
public void EqualEntityTest()
102+
{
103+
var order = new Order();
104+
AssertResults(
105+
new Dictionary<string, Predicate<IType>>
106+
{
107+
{
108+
$"value({typeof(Order).FullName})",
109+
o => o is ManyToOneType manyToOne && manyToOne.Name == typeof(Order).FullName
110+
}
111+
},
112+
db.Orders.Where(o => o == order),
113+
db.Orders.Where(o => order == o)
114+
);
115+
}
116+
100117
[Test]
101118
public void DoubleEqualTest()
102119
{

src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Linq.Expressions;
66
using System.Reflection;
77
using NHibernate.Engine;
8+
using NHibernate.Linq.Functions;
89
using NHibernate.Param;
910
using NHibernate.Type;
1011
using NHibernate.Util;
@@ -19,8 +20,10 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor
1920
{
2021
private readonly Dictionary<ConstantExpression, NamedParameter> _parameters = new Dictionary<ConstantExpression, NamedParameter>();
2122
private readonly Dictionary<QueryVariable, NamedParameter> _variableParameters = new Dictionary<QueryVariable, NamedParameter>();
23+
private readonly HashSet<ConstantExpression> _collectionParameters = new HashSet<ConstantExpression>();
2224
private readonly IDictionary<ConstantExpression, QueryVariable> _queryVariables;
2325
private readonly ISessionFactoryImplementor _sessionFactory;
26+
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
2427

2528
private static readonly ISet<MethodBase> PagingMethods = new HashSet<MethodBase>
2629
{
@@ -41,6 +44,7 @@ public ExpressionParameterVisitor(PreTransformationResult preTransformationResul
4144
{
4245
_sessionFactory = preTransformationResult.SessionFactory;
4346
_queryVariables = preTransformationResult.QueryVariables;
47+
_functionRegistry = _sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
4448
}
4549

4650
// Since v5.3
@@ -98,6 +102,17 @@ protected override Expression VisitMethodCall(MethodCallExpression expression)
98102
return Expression.Call(null, expression.Method, query, arg);
99103
}
100104

105+
if (_functionRegistry != null &&
106+
_functionRegistry.TryGetGenerator(method, out var generator) &&
107+
generator is CollectionContainsGenerator)
108+
{
109+
var argument = method.IsStatic ? expression.Arguments[0] : expression.Object;
110+
if (argument is ConstantExpression constantExpression)
111+
{
112+
_collectionParameters.Add(constantExpression);
113+
}
114+
}
115+
101116
if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory))
102117
{
103118
return expression;
@@ -172,7 +187,7 @@ protected override Expression VisitConstant(ConstantExpression expression)
172187
private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type)
173188
{
174189
var parameterName = "p" + (_parameters.Count + 1);
175-
return IsCollectionType(expression)
190+
return _collectionParameters.Contains(expression)
176191
? new NamedListParameter(parameterName, value, type)
177192
: new NamedParameter(parameterName, value, type);
178193
}
@@ -181,15 +196,5 @@ private static bool IsNullObject(ConstantExpression expression)
181196
{
182197
return expression.Type == typeof(Object) && expression.Value == null;
183198
}
184-
185-
private static bool IsCollectionType(ConstantExpression expression)
186-
{
187-
if (expression.Value != null)
188-
{
189-
return expression.Value is IEnumerable && !(expression.Value is string);
190-
}
191-
192-
return expression.Type.IsCollectionType();
193-
}
194199
}
195200
}

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ internal static void SetParameterTypes(
116116
if (type == null)
117117
{
118118
type = constantExpression.Value != null
119-
? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, out _)
120-
: ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, out _);
119+
? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection)
120+
: ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection);
121121
}
122122

123123
namedParameter.Type = type;
@@ -251,7 +251,9 @@ private void VisitAssign(Expression leftNode, Expression rightNode)
251251

252252
private void AddRelatedExpression(Expression node, Expression left, Expression right)
253253
{
254-
if (left.NodeType == ExpressionType.MemberAccess || IsDynamicMember(left))
254+
if (left.NodeType == ExpressionType.MemberAccess ||
255+
IsDynamicMember(left) ||
256+
left is QuerySourceReferenceExpression)
255257
{
256258
AddRelatedExpression(right, left);
257259
if (NonVoidOperators.Contains(node.NodeType))

src/NHibernate/Util/ParameterHelper.cs

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,27 @@ internal static class ParameterHelper
1414
/// </summary>
1515
/// <param name="param">The object to guess the <see cref="IType"/> of.</param>
1616
/// <param name="sessionFactory">The session factory to search for entity persister.</param>
17-
/// <param name="isCollection">The output parameter that represents whether the <paramref name="param"/> is a collection.</param>
17+
/// <param name="isCollection">Whether <paramref name="param"/> is a collection.</param>
1818
/// <returns>An <see cref="IType"/> for the object.</returns>
1919
/// <exception cref="ArgumentNullException">
2020
/// Thrown when the <c>param</c> is null because the <see cref="IType"/>
2121
/// can't be guess from a null value.
2222
/// </exception>
23-
public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, out bool isCollection)
23+
public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, bool isCollection)
2424
{
2525
if (param == null)
2626
{
27-
throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value.");
27+
return null;
2828
}
2929

30-
if (param is IEnumerable enumerable && !(param is string))
30+
if (param is IEnumerable enumerable && isCollection)
3131
{
3232
var firstValue = enumerable.Cast<object>().FirstOrDefault();
33-
isCollection = true;
3433
return firstValue == null
3534
? TryGuessType(enumerable.GetCollectionElementType(), sessionFactory)
36-
: TryGuessType(firstValue, sessionFactory, out _);
35+
: TryGuessType(firstValue, sessionFactory, false);
3736
}
3837

39-
isCollection = false;
4038
var clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param);
4139
return TryGuessType(clazz, sessionFactory);
4240
}
@@ -67,26 +65,24 @@ public static IType GuessType(object param, ISessionFactoryImplementor sessionFa
6765
/// </summary>
6866
/// <param name="clazz">The <see cref="System.Type"/> to guess the <see cref="IType"/> of.</param>
6967
/// <param name="sessionFactory">The session factory to search for entity persister.</param>
70-
/// <param name="isCollection">The output parameter that represents whether the <paramref name="clazz"/> is a collection.</param>
68+
/// <param name="isCollection">Whether <paramref name="clazz"/> is a collection.</param>
7169
/// <returns>An <see cref="IType"/> for the <see cref="System.Type"/>.</returns>
7270
/// <exception cref="ArgumentNullException">
7371
/// Thrown when the <c>clazz</c> is null because the <see cref="IType"/>
7472
/// can't be guess from a null type.
7573
/// </exception>
76-
public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, out bool isCollection)
74+
public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, bool isCollection)
7775
{
7876
if (clazz == null)
7977
{
80-
throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value.");
78+
return null;
8179
}
8280

83-
if (clazz.IsCollectionType())
81+
if (isCollection)
8482
{
85-
isCollection = true;
86-
return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, out _);
83+
return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, false);
8784
}
8885

89-
isCollection = false;
9086
return TryGuessType(clazz, sessionFactory);
9187
}
9288

@@ -95,41 +91,18 @@ public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor s
9591
/// </summary>
9692
/// <param name="clazz">The <see cref="System.Type"/> to guess the <see cref="IType"/> of.</param>
9793
/// <param name="sessionFactory">The session factory to search for entity persister.</param>
98-
/// <param name="isCollection">The output parameter that represents whether the <paramref name="clazz"/> is a collection.</param>
9994
/// <returns>An <see cref="IType"/> for the <see cref="System.Type"/>.</returns>
10095
/// <exception cref="ArgumentNullException">
10196
/// Thrown when the <c>clazz</c> is null because the <see cref="IType"/>
10297
/// can't be guess from a null type.
10398
/// </exception>
104-
public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, out bool isCollection)
99+
public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory)
105100
{
106101
if (clazz == null)
107102
{
108103
throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value.");
109104
}
110105

111-
if (typeof(IEnumerable).IsAssignableFrom(clazz) && typeof(string) != clazz)
112-
{
113-
isCollection = true;
114-
return GuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory);
115-
}
116-
117-
isCollection = false;
118-
return GuessType(clazz, sessionFactory);
119-
}
120-
121-
/// <summary>
122-
/// Guesses the <see cref="IType"/> from the <see cref="System.Type"/>.
123-
/// </summary>
124-
/// <param name="clazz">The <see cref="System.Type"/> to guess the <see cref="IType"/> of.</param>
125-
/// <param name="sessionFactory">The session factory to search for entity persister.</param>
126-
/// <returns>An <see cref="IType"/> for the <see cref="System.Type"/>.</returns>
127-
/// <exception cref="ArgumentNullException">
128-
/// Thrown when the <c>clazz</c> is null because the <see cref="IType"/>
129-
/// can't be guess from a null type.
130-
/// </exception>
131-
public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory)
132-
{
133106
return TryGuessType(clazz, sessionFactory) ??
134107
throw new HibernateException("Could not determine a type for class: " + clazz.AssemblyQualifiedName);
135108
}
@@ -148,7 +121,7 @@ public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor s
148121
{
149122
if (clazz == null)
150123
{
151-
throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value.");
124+
return null;
152125
}
153126

154127
var type = TypeFactory.HeuristicType(clazz);

0 commit comments

Comments
 (0)