Skip to content

Commit 4f7313e

Browse files
committed
Fix edge case scenarios
1 parent a199f0a commit 4f7313e

File tree

7 files changed

+198
-38
lines changed

7 files changed

+198
-38
lines changed

src/NHibernate.Test/Async/Linq/NullComparisonTests.cs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ public async Task NullInequalityWithNotNullAsync()
143143

144144
q = session.Query<AnotherEntityRequired>().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput);
145145
await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent));
146+
147+
await (ExpectAsync(session.Query<Customer>().Where(o => o.CustomerId != null), Does.Not.Contain("is null").IgnoreCase));
148+
await (ExpectAsync(session.Query<Customer>().Where(o => null != o.CustomerId), Does.Not.Contain("is null").IgnoreCase));
149+
150+
await (ExpectAsync(session.Query<Customer>().Where(o => o.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase));
151+
await (ExpectAsync(session.Query<Customer>().Where(o => "test" != o.CustomerId), Does.Not.Contain("is null").IgnoreCase));
152+
153+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase));
154+
await (ExpectAsync(session.Query<OrderLine>().Where(o => "test" != o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase));
155+
156+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase));
157+
await (ExpectAsync(session.Query<OrderLine>().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase));
146158
}
147159

148160
[Test]
@@ -173,10 +185,10 @@ public async Task NullInequalityWithNotNullSubSelectAsync()
173185
public async Task NullEqualityWithNotNullAsync()
174186
{
175187
var q = session.Query<AnotherEntityRequired>().Where(o => o.Input == null);
176-
await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull));
188+
await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull));
177189

178190
q = session.Query<AnotherEntityRequired>().Where(o => null == o.Input);
179-
await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull));
191+
await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull));
180192

181193
q = session.Query<AnotherEntityRequired>().Where(o => o.InputNullability == AnotherEntityNullability.True);
182194
await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet));
@@ -291,6 +303,15 @@ public async Task NullEqualityWithNotNullAsync()
291303

292304
q = session.Query<AnotherEntityRequired>().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput);
293305
await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame));
306+
307+
await (ExpectAsync(session.Query<Customer>().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase));
308+
await (ExpectAsync(session.Query<Customer>().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase));
309+
await (ExpectAsync(session.Query<Customer>().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase));
310+
await (ExpectAsync(session.Query<Customer>().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase));
311+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase));
312+
await (ExpectAsync(session.Query<OrderLine>().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase));
313+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase));
314+
await (ExpectAsync(session.Query<OrderLine>().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase));
294315
}
295316

296317
[Test]
@@ -377,6 +398,24 @@ public async Task NullEqualityAsync()
377398
// Columns against columns
378399
q = from x in session.Query<AnotherEntity>() where x.Input == x.Output select x;
379400
await (ExpectAsync(q, BothSame, BothNull));
401+
402+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName == null), Does.Contain("is null").IgnoreCase));
403+
await (ExpectAsync(session.Query<OrderLine>().Where(o => null == o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase));
404+
405+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName == "test"), Does.Not.Contain("is null").IgnoreCase));
406+
await (ExpectAsync(session.Query<OrderLine>().Where(o => "test" == o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase));
407+
408+
await (ExpectAsync(session.Query<User>().Where(o => null == o.Component.Property1), Does.Contain("is null").IgnoreCase));
409+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.Property1 == null), Does.Contain("is null").IgnoreCase));
410+
411+
await (ExpectAsync(session.Query<User>().Where(o => "test" == o.Component.Property1), Does.Not.Contain("is null").IgnoreCase));
412+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.Property1 == "test"), Does.Not.Contain("is null").IgnoreCase));
413+
414+
await (ExpectAsync(session.Query<User>().Where(o => null == o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase));
415+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 == null), Does.Contain("is null").IgnoreCase));
416+
417+
await (ExpectAsync(session.Query<User>().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase));
418+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase));
380419
}
381420

382421
[Test]
@@ -435,6 +474,24 @@ public async Task NullInequalityAsync()
435474
// Columns against columns
436475
q = from x in session.Query<AnotherEntity>() where x.Input != x.Output select x;
437476
await (ExpectAsync(q, BothDifferent, InputSet, OutputSet));
477+
478+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName != null), Does.Not.Contain("is null").IgnoreCase));
479+
await (ExpectAsync(session.Query<OrderLine>().Where(o => null != o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase));
480+
481+
await (ExpectAsync(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName != "test"), Does.Contain("is null").IgnoreCase));
482+
await (ExpectAsync(session.Query<OrderLine>().Where(o => "test" != o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase));
483+
484+
await (ExpectAsync(session.Query<User>().Where(o => null != o.Component.Property1), Does.Not.Contain("is null").IgnoreCase));
485+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.Property1 != null), Does.Not.Contain("is null").IgnoreCase));
486+
487+
await (ExpectAsync(session.Query<User>().Where(o => "test" != o.Component.Property1), Does.Contain("is null").IgnoreCase));
488+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.Property1 != "test"), Does.Contain("is null").IgnoreCase));
489+
490+
await (ExpectAsync(session.Query<User>().Where(o => null != o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase));
491+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 != null), Does.Not.Contain("is null").IgnoreCase));
492+
493+
await (ExpectAsync(session.Query<User>().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase));
494+
await (ExpectAsync(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase));
438495
}
439496

440497
[Test]
@@ -633,6 +690,15 @@ private async Task ExpectAsync(IQueryable<AnotherEntityRequired> q, IResolveCons
633690
return (await (q.ToListAsync(cancellationToken))).OrderBy(Key).ToList();
634691
}
635692

693+
private static async Task ExpectAsync<T>(IQueryable<T> query, IResolveConstraint sqlConstraint, CancellationToken cancellationToken = default(CancellationToken))
694+
{
695+
using (var sqlLog = new SqlLogSpy())
696+
{
697+
var list = await (query.ToListAsync(cancellationToken));
698+
Assert.That(sqlLog.GetWholeLog(), sqlConstraint);
699+
}
700+
}
701+
636702
private static string Key(AnotherEntityRequired e)
637703
{
638704
return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL");

src/NHibernate.Test/Linq/NullComparisonTests.cs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ public void NullInequalityWithNotNull()
131131

132132
q = session.Query<AnotherEntityRequired>().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput);
133133
Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent);
134+
135+
Expect(session.Query<Customer>().Where(o => o.CustomerId != null), Does.Not.Contain("is null").IgnoreCase);
136+
Expect(session.Query<Customer>().Where(o => null != o.CustomerId), Does.Not.Contain("is null").IgnoreCase);
137+
138+
Expect(session.Query<Customer>().Where(o => o.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase);
139+
Expect(session.Query<Customer>().Where(o => "test" != o.CustomerId), Does.Not.Contain("is null").IgnoreCase);
140+
141+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase);
142+
Expect(session.Query<OrderLine>().Where(o => "test" != o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase);
143+
144+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase);
145+
Expect(session.Query<OrderLine>().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase);
134146
}
135147

136148
[Test]
@@ -161,10 +173,10 @@ public void NullInequalityWithNotNullSubSelect()
161173
public void NullEqualityWithNotNull()
162174
{
163175
var q = session.Query<AnotherEntityRequired>().Where(o => o.Input == null);
164-
Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull);
176+
Expect(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull);
165177

166178
q = session.Query<AnotherEntityRequired>().Where(o => null == o.Input);
167-
Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull);
179+
Expect(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull);
168180

169181
q = session.Query<AnotherEntityRequired>().Where(o => o.InputNullability == AnotherEntityNullability.True);
170182
Expect(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet);
@@ -279,6 +291,15 @@ public void NullEqualityWithNotNull()
279291

280292
q = session.Query<AnotherEntityRequired>().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput);
281293
Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame);
294+
295+
Expect(session.Query<Customer>().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase);
296+
Expect(session.Query<Customer>().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase);
297+
Expect(session.Query<Customer>().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase);
298+
Expect(session.Query<Customer>().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase);
299+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase);
300+
Expect(session.Query<OrderLine>().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase);
301+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase);
302+
Expect(session.Query<OrderLine>().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase);
282303
}
283304

284305
[Test]
@@ -365,6 +386,24 @@ public void NullEquality()
365386
// Columns against columns
366387
q = from x in session.Query<AnotherEntity>() where x.Input == x.Output select x;
367388
Expect(q, BothSame, BothNull);
389+
390+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName == null), Does.Contain("is null").IgnoreCase);
391+
Expect(session.Query<OrderLine>().Where(o => null == o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase);
392+
393+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName == "test"), Does.Not.Contain("is null").IgnoreCase);
394+
Expect(session.Query<OrderLine>().Where(o => "test" == o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase);
395+
396+
Expect(session.Query<User>().Where(o => null == o.Component.Property1), Does.Contain("is null").IgnoreCase);
397+
Expect(session.Query<User>().Where(o => o.Component.Property1 == null), Does.Contain("is null").IgnoreCase);
398+
399+
Expect(session.Query<User>().Where(o => "test" == o.Component.Property1), Does.Not.Contain("is null").IgnoreCase);
400+
Expect(session.Query<User>().Where(o => o.Component.Property1 == "test"), Does.Not.Contain("is null").IgnoreCase);
401+
402+
Expect(session.Query<User>().Where(o => null == o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase);
403+
Expect(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 == null), Does.Contain("is null").IgnoreCase);
404+
405+
Expect(session.Query<User>().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase);
406+
Expect(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase);
368407
}
369408

370409
[Test]
@@ -423,6 +462,24 @@ public void NullInequality()
423462
// Columns against columns
424463
q = from x in session.Query<AnotherEntity>() where x.Input != x.Output select x;
425464
Expect(q, BothDifferent, InputSet, OutputSet);
465+
466+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName != null), Does.Not.Contain("is null").IgnoreCase);
467+
Expect(session.Query<OrderLine>().Where(o => null != o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase);
468+
469+
Expect(session.Query<OrderLine>().Where(o => o.Order.Customer.ContactName != "test"), Does.Contain("is null").IgnoreCase);
470+
Expect(session.Query<OrderLine>().Where(o => "test" != o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase);
471+
472+
Expect(session.Query<User>().Where(o => null != o.Component.Property1), Does.Not.Contain("is null").IgnoreCase);
473+
Expect(session.Query<User>().Where(o => o.Component.Property1 != null), Does.Not.Contain("is null").IgnoreCase);
474+
475+
Expect(session.Query<User>().Where(o => "test" != o.Component.Property1), Does.Contain("is null").IgnoreCase);
476+
Expect(session.Query<User>().Where(o => o.Component.Property1 != "test"), Does.Contain("is null").IgnoreCase);
477+
478+
Expect(session.Query<User>().Where(o => null != o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase);
479+
Expect(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 != null), Does.Not.Contain("is null").IgnoreCase);
480+
481+
Expect(session.Query<User>().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase);
482+
Expect(session.Query<User>().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase);
426483
}
427484

428485
[Test]
@@ -621,6 +678,15 @@ private IList<AnotherEntityRequired> GetResults(IQueryable<AnotherEntityRequired
621678
return q.ToList().OrderBy(Key).ToList();
622679
}
623680

681+
private static void Expect<T>(IQueryable<T> query, IResolveConstraint sqlConstraint)
682+
{
683+
using (var sqlLog = new SqlLogSpy())
684+
{
685+
var list = query.ToList();
686+
Assert.That(sqlLog.GetWholeLog(), sqlConstraint);
687+
}
688+
}
689+
624690
private static string Key(AnotherEntityRequired e)
625691
{
626692
return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL");

src/NHibernate/Async/Persister/Entity/IEntityPersister.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using System.Collections.Generic;
2121
using NHibernate.Intercept;
2222
using NHibernate.Util;
23+
using System.Linq;
2324

2425
namespace NHibernate.Persister.Entity
2526
{

src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,5 @@ public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGe
1414
public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor);
1515

1616
public virtual bool AllowsNullableReturnType(MethodInfo method) => true;
17-
18-
bool IHqlGeneratorForMethodExtended.AllowsNullableReturnType(MethodInfo method)
19-
{
20-
return AllowsNullableReturnType(method);
21-
}
2217
}
2318
}

src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using NHibernate.Linq.Clauses;
66
using NHibernate.Linq.Expressions;
77
using NHibernate.Linq.Functions;
8+
using NHibernate.Persister.Entity;
89
using NHibernate.Util;
910
using Remotion.Linq.Clauses;
1011
using Remotion.Linq.Clauses.Expressions;
@@ -162,17 +163,22 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa
162163
}
163164

164165
// We have to check the member mapping to determine if is nullable
165-
var entityName = TryGetEntityName(memberExpression);
166+
var entityName = ExpressionsHelper.TryGetEntityName(_sessionFactory, memberExpression, out var memberPath);
166167
if (entityName == null)
167168
{
168-
return true; // not mapped
169+
return true; // Not mapped
169170
}
170171

171172
var persister = _sessionFactory.GetEntityPersister(entityName);
172-
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name);
173+
if (persister.IsIdentifierMember(memberPath))
174+
{
175+
return false; // Identifier is always not null
176+
}
177+
178+
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath);
173179
if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value])
174180
{
175-
return true; // not mapped or nullable
181+
return true; // Not mapped or nullable
176182
}
177183

178184
return IsNullable(memberExpression.Expression, equalityExpression);
@@ -210,23 +216,6 @@ private bool IsNullableExtension(Expression extensionExpression, BinaryExpressio
210216
}
211217
}
212218

213-
private string TryGetEntityName(MemberExpression memberExpression)
214-
{
215-
System.Type entityType;
216-
// Try to get the actual entity type from the query source if possbile as member can be declared
217-
// in a base type
218-
if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression)
219-
{
220-
entityType = querySourceReferenceExpression.Type;
221-
}
222-
else
223-
{
224-
entityType = memberExpression.Member.ReflectedType;
225-
}
226-
227-
return _sessionFactory.TryGetGuessEntityName(entityType);
228-
}
229-
230219
private static bool IsMemberAccess(Expression expression)
231220
{
232221
if (expression.NodeType == ExpressionType.MemberAccess)

src/NHibernate/Tuple/Entity/EntityMetamodel.cs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,20 @@ private bool HasPartialUpdateComponentGeneration(Mapping.Component component)
411411

412412
private void MapPropertyToIndex(Mapping.Property prop, int i)
413413
{
414-
propertyIndexes[prop.Name] = i;
415-
Mapping.Component comp = prop.Value as Mapping.Component;
416-
if (comp != null)
414+
MapPropertyToIndex(null, prop, i);
415+
}
416+
417+
private void MapPropertyToIndex(string path, Mapping.Property prop, int i)
418+
{
419+
propertyIndexes[!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name] = i;
420+
if (!(prop.Value is Mapping.Component comp))
417421
{
418-
foreach (Mapping.Property subprop in comp.PropertyIterator)
419-
{
420-
propertyIndexes[prop.Name + '.' + subprop.Name] = i;
421-
}
422+
return;
423+
}
424+
425+
foreach (var subprop in comp.PropertyIterator)
426+
{
427+
MapPropertyToIndex(!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name, subprop, i);
422428
}
423429
}
424430

0 commit comments

Comments
 (0)