Skip to content

Commit c3c335b

Browse files
committed
Avoid unnecessary casting for Linq provider
1 parent 1df24c6 commit c3c335b

File tree

9 files changed

+109
-3
lines changed

9 files changed

+109
-3
lines changed

src/NHibernate.DomainModel/Northwind/Entities/User.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ public class User : IUser, IEntity
5050

5151
public virtual User NotMappedUser => this;
5252

53+
public virtual short Short { get; set; }
54+
55+
public virtual short? NullableShort { get; set; }
56+
5357
public virtual EnumStoredAsString Enum1 { get; set; }
5458

5559
public virtual EnumStoredAsString? NullableEnum1 { get; set; }

src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
<column name="ModifiedById" />
2121
</many-to-one>
2222

23+
<property name="Short" column="Enum2" insert="false" update="false" not-null="true" />
24+
<property name="NullableShort" formula="(case when Enum2 = 0 then null else Enum2 end)" insert="false" update="false" />
25+
2326
<property name="Enum1" type="NHibernate.DomainModel.Northwind.Entities.EnumStoredAsStringType, NHibernate.DomainModel">
2427
<column name="Enum1" length="12" />
2528
</property>

src/NHibernate.Test/Async/Linq/NullComparisonTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,23 @@ public async Task NullEqualityAsync()
472472

473473
await (ExpectAsync(session.Query<User>().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase));
474474
await (ExpectAsync(session.Query<User>().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase));
475+
476+
short value = 3;
477+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
478+
await (ExpectAsync(session.Query<User>().Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
479+
480+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
481+
await (ExpectAsync(session.Query<User>().Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
482+
await (ExpectAsync(session.Query<User>().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
483+
await (ExpectAsync(session.Query<User>().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
484+
485+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
486+
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
487+
488+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
489+
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
490+
await (ExpectAsync(session.Query<User>().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
491+
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
475492
}
476493

477494
[Test]
@@ -560,6 +577,15 @@ public async Task NullInequalityAsync()
560577

561578
await (ExpectAsync(session.Query<User>().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase));
562579
await (ExpectAsync(session.Query<User>().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase));
580+
581+
short value = 3;
582+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
583+
await (ExpectAsync(session.Query<User>().Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
584+
585+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
586+
await (ExpectAsync(session.Query<User>().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
587+
await (ExpectAsync(session.Query<User>().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
588+
await (ExpectAsync(session.Query<User>().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
563589
}
564590

565591
[Test]

src/NHibernate.Test/Async/Linq/ParameterTests.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ public async Task UsingValueTypeParameterTwiceAsync()
125125
1));
126126
}
127127

128+
[Test]
129+
public async Task UsingValueTypeParameterTwiceOnNullablePropertyAsync()
130+
{
131+
short value = 1;
132+
await (AssertTotalParametersAsync(
133+
db.Users.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value),
134+
1));
135+
}
136+
128137
[Test]
129138
public async Task UsingParameterInEvaluatableExpressionAsync()
130139
{

src/NHibernate.Test/Linq/NullComparisonTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,23 @@ public void NullEquality()
460460

461461
Expect(session.Query<User>().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase);
462462
Expect(session.Query<User>().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase);
463+
464+
short value = 3;
465+
Expect(session.Query<User>().Where(o => o.NullableShort == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
466+
Expect(session.Query<User>().Where(o => value == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
467+
468+
Expect(session.Query<User>().Where(o => o.NullableShort.Value == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
469+
Expect(session.Query<User>().Where(o => value == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
470+
Expect(session.Query<User>().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
471+
Expect(session.Query<User>().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
472+
473+
Expect(session.Query<User>().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
474+
Expect(session.Query<User>().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
475+
476+
Expect(session.Query<User>().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
477+
Expect(session.Query<User>().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
478+
Expect(session.Query<User>().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
479+
Expect(session.Query<User>().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
463480
}
464481

465482
[Test]
@@ -548,6 +565,15 @@ public void NullInequality()
548565

549566
Expect(session.Query<User>().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase);
550567
Expect(session.Query<User>().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase);
568+
569+
short value = 3;
570+
Expect(session.Query<User>().Where(o => o.NullableShort != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
571+
Expect(session.Query<User>().Where(o => value != o.NullableShort), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
572+
573+
Expect(session.Query<User>().Where(o => o.NullableShort.Value != value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
574+
Expect(session.Query<User>().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
575+
Expect(session.Query<User>().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
576+
Expect(session.Query<User>().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
551577
}
552578

553579
[Test]

src/NHibernate.Test/Linq/ParameterTests.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ public void UsingValueTypeParameterTwice()
113113
1);
114114
}
115115

116+
[Test]
117+
public void UsingValueTypeParameterTwiceOnNullableProperty()
118+
{
119+
short value = 1;
120+
AssertTotalParameters(
121+
db.Users.Where(o => o.NullableShort == value && o.NullableShort != value && o.Short == value),
122+
1);
123+
}
124+
116125
[Test]
117126
public void UsingParameterInEvaluatableExpression()
118127
{

src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Linq.Expressions;
2+
using NHibernate.Util;
23
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
34

45
namespace NHibernate.Linq.ExpressionTransformers
@@ -26,6 +27,13 @@ public Expression Transform(UnaryExpression expression)
2627
return expression.Operand;
2728
}
2829

30+
// Reduce double casting (e.g. (long?)(long)3 => (long?)3)
31+
if (expression.Operand.NodeType == ExpressionType.Convert &&
32+
expression.Type.UnwrapIfNullable() == expression.Operand.Type)
33+
{
34+
return Expression.Convert(((UnaryExpression) expression.Operand).Operand, expression.Type);
35+
}
36+
2937
return expression;
3038
}
3139

@@ -34,4 +42,4 @@ public ExpressionType[] SupportedExpressionTypes
3442
get { return _supportedExpressionTypes; }
3543
}
3644
}
37-
}
45+
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Data;
34
using System.Dynamic;
45
using System.Linq;
@@ -23,6 +24,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor
2324
private readonly VisitorParameters _parameters;
2425
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
2526
private readonly NullableExpressionDetector _nullableExpressionDetector;
27+
private readonly HashSet<Expression> _notCastableExpressions = new HashSet<Expression>();
2628

2729
public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters)
2830
{
@@ -308,6 +310,17 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi
308310

309311
protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
310312
{
313+
// In .NET some numeric types do not have thier own operators (e.g. short == short is converted to (int) short == (int) short),
314+
// in such case we dont want to add a sql cast
315+
if (expression.Left.NodeType == ExpressionType.Convert &&
316+
expression.Right.NodeType == ExpressionType.Convert &&
317+
((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable() ==
318+
((UnaryExpression) expression.Right).Operand.Type.UnwrapIfNullable())
319+
{
320+
_notCastableExpressions.Add(expression.Left);
321+
_notCastableExpressions.Add(expression.Right);
322+
}
323+
311324
if (expression.NodeType == ExpressionType.Equal)
312325
{
313326
return TranslateEqualityComparison(expression);
@@ -496,7 +509,7 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
496509
case ExpressionType.Convert:
497510
case ExpressionType.ConvertChecked:
498511
case ExpressionType.TypeAs:
499-
return IsCastRequired(expression.Operand, expression.Type, out var existType)
512+
return IsCastRequired(expression.Operand, expression.Type, out var existType) && !_notCastableExpressions.Contains(expression)
500513
? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
501514
// Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader
502515
: existType && HqlIdent.SupportsType(expression.Type)

src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,15 @@ public override Expression Visit(Expression expression)
8181
#region NH additions
8282
// Variables should be evaluated only when they are part of an evaluatable expression (e.g. o => string.Format("...", variable))
8383
expression is UnaryExpression unaryExpression &&
84-
ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _))
84+
(
85+
ExpressionsHelper.IsVariable(unaryExpression.Operand, out _, out _) ||
86+
// Check whether the variable is casted due to comparison with a nullable expression
87+
// (e.g. o.NullableShort == shortVariable)
88+
unaryExpression.Operand is UnaryExpression subUnaryExpression &&
89+
unaryExpression.Type.UnwrapIfNullable() == subUnaryExpression.Type &&
90+
ExpressionsHelper.IsVariable(subUnaryExpression.Operand, out _, out _)
91+
)
92+
)
8593
#endregion
8694
return base.Visit(expression);
8795

0 commit comments

Comments
 (0)