Skip to content

Commit a199f0a

Browse files
committed
Moved nullable check code into a separated class
1 parent f3822b8 commit a199f0a

File tree

2 files changed

+316
-293
lines changed

2 files changed

+316
-293
lines changed

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 7 additions & 293 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
using System;
2-
using System.Collections.Generic;
32
using System.Dynamic;
43
using System.Linq;
54
using System.Linq.Expressions;
65
using System.Runtime.CompilerServices;
76
using NHibernate.Engine.Query;
87
using NHibernate.Hql.Ast;
9-
using NHibernate.Linq.Clauses;
108
using NHibernate.Linq.Expressions;
119
using NHibernate.Linq.Functions;
12-
using NHibernate.Mapping.ByCode;
1310
using NHibernate.Param;
1411
using NHibernate.Util;
15-
using Remotion.Linq.Clauses;
1612
using Remotion.Linq.Clauses.Expressions;
17-
using Remotion.Linq.Clauses.ResultOperators;
1813

1914
namespace NHibernate.Linq.Visitors
2015
{
@@ -23,17 +18,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor
2318
private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder();
2419
private readonly VisitorParameters _parameters;
2520
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
26-
private readonly Dictionary<BinaryExpression, List<MemberExpression>> _equalityNotNullMembers =
27-
new Dictionary<BinaryExpression, List<MemberExpression>>();
28-
29-
private static readonly HashSet<System.Type> NotNullOperators = new HashSet<System.Type>()
30-
{
31-
typeof(AllResultOperator),
32-
typeof(AnyResultOperator),
33-
typeof(ContainsResultOperator),
34-
typeof(CountResultOperator),
35-
typeof(LongCountResultOperator)
36-
};
21+
private readonly NullableExpressionDetector _nullableExpressionDetector;
3722

3823
public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters)
3924
{
@@ -44,6 +29,7 @@ public HqlGeneratorExpressionVisitor(VisitorParameters parameters)
4429
{
4530
_functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry;
4631
_parameters = parameters;
32+
_nullableExpressionDetector = new NullableExpressionDetector(_parameters.SessionFactory, _functionRegistry);
4733
}
4834

4935

@@ -299,94 +285,6 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi
299285
return VisitExpression(expression.Comparison);
300286
}
301287

302-
private void SearchForNotNullMembersCheck(BinaryExpression expression)
303-
{
304-
// Check for a member not null check that has a not equals expression
305-
// Example: o.Status != null && o.Status != "New"
306-
// Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus)
307-
// Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus)
308-
if (expression.NodeType != ExpressionType.AndAlso ||
309-
expression.Right.NodeType != ExpressionType.NotEqual &&
310-
expression.Right.NodeType != ExpressionType.Equal ||
311-
expression.Left.NodeType != ExpressionType.AndAlso &&
312-
expression.Left.NodeType != ExpressionType.NotEqual)
313-
{
314-
return;
315-
}
316-
317-
// Skip if there are no member access expressions on the right side
318-
var notEqualExpression = (BinaryExpression) expression.Right;
319-
if (!IsMemberAccess(notEqualExpression.Left) && !IsMemberAccess(notEqualExpression.Right))
320-
{
321-
return;
322-
}
323-
324-
var notNullMembers = new List<MemberExpression>();
325-
// We may have multiple conditions
326-
// Example: o.Status != null && o.OldStatus != null
327-
if (expression.Left.NodeType == ExpressionType.AndAlso)
328-
{
329-
FindAllNotNullMembers((BinaryExpression) expression.Left, notNullMembers);
330-
}
331-
else
332-
{
333-
FindNotNullMember((BinaryExpression) expression.Left, notNullMembers);
334-
}
335-
336-
if (notNullMembers.Count > 0)
337-
{
338-
_equalityNotNullMembers[notEqualExpression] = notNullMembers;
339-
}
340-
}
341-
342-
private static bool IsMemberAccess(Expression expression)
343-
{
344-
if (expression.NodeType == ExpressionType.MemberAccess)
345-
{
346-
return true;
347-
}
348-
349-
// Nullable members can be wrapped in a convert expression
350-
return expression is UnaryExpression unaryExpression && unaryExpression.Operand.NodeType == ExpressionType.MemberAccess;
351-
}
352-
353-
private static void FindAllNotNullMembers(BinaryExpression andAlsoExpression, List<MemberExpression> notNullMembers)
354-
{
355-
if (andAlsoExpression.Right.NodeType == ExpressionType.NotEqual)
356-
{
357-
FindNotNullMember((BinaryExpression) andAlsoExpression.Right, notNullMembers);
358-
}
359-
else if (andAlsoExpression.Right.NodeType == ExpressionType.AndAlso)
360-
{
361-
FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Right, notNullMembers);
362-
}
363-
else
364-
{
365-
return;
366-
}
367-
368-
if (andAlsoExpression.Left.NodeType == ExpressionType.NotEqual)
369-
{
370-
FindNotNullMember((BinaryExpression) andAlsoExpression.Left, notNullMembers);
371-
}
372-
else if (andAlsoExpression.Left.NodeType == ExpressionType.AndAlso)
373-
{
374-
FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Left, notNullMembers);
375-
}
376-
}
377-
378-
private static void FindNotNullMember(BinaryExpression notEqualExpression, List<MemberExpression> notNullMembers)
379-
{
380-
if (notEqualExpression.Left.NodeType == ExpressionType.MemberAccess && VisitorUtil.IsNullConstant(notEqualExpression.Right))
381-
{
382-
notNullMembers.Add((MemberExpression) notEqualExpression.Left);
383-
}
384-
else if (VisitorUtil.IsNullConstant(notEqualExpression.Left) && notEqualExpression.Right.NodeType == ExpressionType.MemberAccess)
385-
{
386-
notNullMembers.Add((MemberExpression) notEqualExpression.Right);
387-
}
388-
}
389-
390288
protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
391289
{
392290
if (expression.NodeType == ExpressionType.Equal)
@@ -398,7 +296,7 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
398296
return TranslateInequalityComparison(expression);
399297
}
400298

401-
SearchForNotNullMembersCheck(expression);
299+
_nullableExpressionDetector.SearchForNotNullMemberChecks(expression);
402300

403301
var lhs = VisitExpression(expression.Left).AsExpression();
404302
var rhs = VisitExpression(expression.Right).AsExpression();
@@ -481,8 +379,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression)
481379
return _hqlTreeBuilder.IsNotNull(lhs);
482380
}
483381

484-
var lhsNullable = IsNullable(expression.Left, expression);
485-
var rhsNullable = IsNullable(expression.Right, expression);
382+
var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression);
383+
var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression);
486384

487385
var inequality = _hqlTreeBuilder.Inequality(lhs, rhs);
488386

@@ -544,8 +442,8 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression)
544442
return _hqlTreeBuilder.IsNull((lhs));
545443
}
546444

547-
var lhsNullable = IsNullable(expression.Left, expression);
548-
var rhsNullable = IsNullable(expression.Right, expression);
445+
var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression);
446+
var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression);
549447

550448
var equality = _hqlTreeBuilder.Equality(lhs, rhs);
551449

@@ -564,190 +462,6 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression)
564462
_hqlTreeBuilder.IsNull(rhs2)));
565463
}
566464

567-
private bool IsNullable(Expression expression, BinaryExpression equalityExpression)
568-
{
569-
var currentExpression = expression;
570-
while (true)
571-
{
572-
switch (currentExpression.NodeType)
573-
{
574-
case ExpressionType.Convert:
575-
case ExpressionType.ConvertChecked:
576-
case ExpressionType.TypeAs:
577-
var unaryExpression = (UnaryExpression) currentExpression;
578-
return IsNullable(unaryExpression.Operand, equalityExpression); // a cast will not return null if the operand is not null
579-
case ExpressionType.Not:
580-
case ExpressionType.And:
581-
case ExpressionType.Or:
582-
case ExpressionType.ExclusiveOr:
583-
case ExpressionType.LeftShift:
584-
case ExpressionType.RightShift:
585-
case ExpressionType.AndAlso:
586-
case ExpressionType.OrElse:
587-
case ExpressionType.Equal:
588-
case ExpressionType.NotEqual:
589-
case ExpressionType.GreaterThanOrEqual:
590-
case ExpressionType.GreaterThan:
591-
case ExpressionType.LessThan:
592-
case ExpressionType.LessThanOrEqual:
593-
return false;
594-
case ExpressionType.Add:
595-
case ExpressionType.AddChecked:
596-
case ExpressionType.Divide:
597-
case ExpressionType.Modulo:
598-
case ExpressionType.Multiply:
599-
case ExpressionType.MultiplyChecked:
600-
case ExpressionType.Power:
601-
case ExpressionType.Subtract:
602-
case ExpressionType.SubtractChecked:
603-
var binaryExpression = (BinaryExpression) currentExpression;
604-
return IsNullable(binaryExpression.Left, equalityExpression) || IsNullable(binaryExpression.Right, equalityExpression);
605-
case ExpressionType.ArrayIndex:
606-
return true; // for indexed lists we cannot determine whether the item will be null or not
607-
case ExpressionType.Coalesce:
608-
return IsNullable(((BinaryExpression) currentExpression).Right, equalityExpression);
609-
case ExpressionType.Conditional:
610-
var conditionalExpression = (ConditionalExpression) currentExpression;
611-
return IsNullable(conditionalExpression.IfTrue, equalityExpression) ||
612-
IsNullable(conditionalExpression.IfFalse, equalityExpression);
613-
case ExpressionType.Call:
614-
var methodInfo = ((MethodCallExpression) currentExpression).Method;
615-
return !_functionRegistry.TryGetGenerator(methodInfo, out var method) || method.AllowsNullableReturnType(methodInfo);
616-
case ExpressionType.MemberAccess:
617-
var memberExpression = (MemberExpression) currentExpression;
618-
619-
if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _))
620-
{
621-
// We have to skip the property as it will be converted to a function that can return null
622-
// if the argument is null
623-
currentExpression = memberExpression.Expression;
624-
continue;
625-
}
626-
627-
var memberType = ReflectHelper.GetPropertyOrFieldType(memberExpression.Member);
628-
if (memberType?.IsValueType == true && !memberType.IsNullable())
629-
{
630-
currentExpression = memberExpression.Expression;
631-
continue;
632-
}
633-
634-
// Check if there was a not null check prior the equality expression
635-
if ((
636-
equalityExpression.NodeType == ExpressionType.NotEqual ||
637-
equalityExpression.NodeType == ExpressionType.Equal
638-
) &&
639-
_equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) &&
640-
notNullMembers.Any(o => AreEqual(o, memberExpression)))
641-
{
642-
return false;
643-
}
644-
645-
// We have to check the member mapping to determine if is nullable
646-
var entityName = TryGetEntityName(memberExpression);
647-
if (entityName == null)
648-
{
649-
return true; // not mapped
650-
}
651-
652-
var persister = _parameters.SessionFactory.GetEntityPersister(entityName);
653-
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name);
654-
if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value])
655-
{
656-
return true; // not mapped or nullable
657-
}
658-
659-
currentExpression = memberExpression.Expression;
660-
continue;
661-
case ExpressionType.Extension:
662-
switch (currentExpression)
663-
{
664-
case QuerySourceReferenceExpression querySourceReferenceExpression:
665-
switch (querySourceReferenceExpression.ReferencedQuerySource)
666-
{
667-
case MainFromClause _:
668-
return false; // we reached to the root expression, there were no nullable expressions
669-
case NhJoinClause joinClause:
670-
return IsNullable(joinClause.FromExpression, equalityExpression);
671-
default:
672-
return true; // unknown query source
673-
}
674-
case SubQueryExpression subQuery:
675-
if (subQuery.QueryModel.SelectClause.Selector is NhAggregatedExpression subQueryAggregatedExpression)
676-
{
677-
return subQueryAggregatedExpression.AllowsNullableReturnType;
678-
}
679-
else if (subQuery.QueryModel.ResultOperators.Any(o => NotNullOperators.Contains(o.GetType())))
680-
{
681-
return false;
682-
}
683-
684-
return true;
685-
case NhAggregatedExpression aggregatedExpression:
686-
return aggregatedExpression.AllowsNullableReturnType;
687-
default:
688-
return true; // a query can return null and we cannot calculate it as it is not yet executed
689-
}
690-
case ExpressionType.TypeIs: // an equal or in operator will be generated and those cannot return null
691-
case ExpressionType.NewArrayInit:
692-
return false;
693-
case ExpressionType.Constant:
694-
return VisitorUtil.IsNullConstant(currentExpression);
695-
case ExpressionType.Parameter:
696-
return !currentExpression.Type.IsValueType;
697-
default:
698-
return true;
699-
}
700-
}
701-
}
702-
703-
private bool AreEqual(MemberExpression memberExpression, MemberExpression otherMemberExpression)
704-
{
705-
if (memberExpression.Member != otherMemberExpression.Member ||
706-
memberExpression.Expression.NodeType != otherMemberExpression.Expression.NodeType)
707-
{
708-
return false;
709-
}
710-
711-
switch (memberExpression.Expression)
712-
{
713-
case QuerySourceReferenceExpression querySourceReferenceExpression:
714-
if (otherMemberExpression.Expression is QuerySourceReferenceExpression otherQuerySourceReferenceExpression)
715-
{
716-
return querySourceReferenceExpression.ReferencedQuerySource ==
717-
otherQuerySourceReferenceExpression.ReferencedQuerySource;
718-
}
719-
720-
return false;
721-
// Components have a nested member expression
722-
case MemberExpression nestedMemberExpression:
723-
if (otherMemberExpression.Expression is MemberExpression otherNestedMemberExpression)
724-
{
725-
return AreEqual(nestedMemberExpression, otherNestedMemberExpression);
726-
}
727-
728-
return false;
729-
default:
730-
return memberExpression.Expression == otherMemberExpression.Expression;
731-
}
732-
}
733-
734-
private string TryGetEntityName(MemberExpression memberExpression)
735-
{
736-
System.Type entityType;
737-
// Try to get the actual entity type from the query source if possbile as member can be declared
738-
// in a base type
739-
if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression)
740-
{
741-
entityType = querySourceReferenceExpression.Type;
742-
}
743-
else
744-
{
745-
entityType = memberExpression.Member.ReflectedType;
746-
}
747-
748-
return _parameters.SessionFactory.TryGetGuessEntityName(entityType);
749-
}
750-
751465
protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
752466
{
753467
switch (expression.NodeType)

0 commit comments

Comments
 (0)