Skip to content

Commit e3794ee

Browse files
committed
Fix tests
1 parent c3c335b commit e3794ee

File tree

3 files changed

+98
-20
lines changed

3 files changed

+98
-20
lines changed

src/NHibernate.Test/Async/Linq/NullComparisonTests.cs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using System.Collections.Generic;
1313
using System.Linq;
1414
using System.Text;
15+
using NHibernate.Dialect;
1516
using NHibernate.Linq;
1617
using NHibernate.DomainModel.Northwind.Entities;
1718
using NUnit.Framework;
@@ -482,13 +483,17 @@ public async Task NullEqualityAsync()
482483
await (ExpectAsync(session.Query<User>().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
483484
await (ExpectAsync(session.Query<User>().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
484485

485-
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
486-
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
487-
488-
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
489-
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
490-
await (ExpectAsync(session.Query<User>().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
491-
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast")));
486+
var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql
487+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) &&
488+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) &&
489+
shortCast != intCast;
490+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
491+
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.NullableShort), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
492+
493+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort.Value == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
494+
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.NullableShort.Value), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
495+
await (ExpectAsync(session.Query<User>().Where(o => o.Short == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
496+
await (ExpectAsync(session.Query<User>().Where(o => 3 == o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
492497
}
493498

494499
[Test]
@@ -586,6 +591,38 @@ public async Task NullInequalityAsync()
586591
await (ExpectAsync(session.Query<User>().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
587592
await (ExpectAsync(session.Query<User>().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
588593
await (ExpectAsync(session.Query<User>().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast")));
594+
595+
var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql
596+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) &&
597+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) &&
598+
shortCast != intCast;
599+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()));
600+
await (ExpectAsync(session.Query<User>().Where(o => 3 != o.NullableShort), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()));
601+
602+
await (ExpectAsync(session.Query<User>().Where(o => o.NullableShort.Value != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()));
603+
await (ExpectAsync(session.Query<User>().Where(o => 3 != o.NullableShort.Value), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast()));
604+
await (ExpectAsync(session.Query<User>().Where(o => o.Short != 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
605+
await (ExpectAsync(session.Query<User>().Where(o => 3 != o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast()));
606+
}
607+
608+
private IResolveConstraint WithIsNullAndWithoutCast()
609+
{
610+
return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase;
611+
}
612+
613+
private IResolveConstraint WithIsNullAndWithCast()
614+
{
615+
return Does.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase;
616+
}
617+
618+
private IResolveConstraint WithoutIsNullAndWithoutCast()
619+
{
620+
return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase;
621+
}
622+
623+
private IResolveConstraint WithoutIsNullAndWithCast()
624+
{
625+
return Does.Not.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase;
589626
}
590627

591628
[Test]

src/NHibernate.Test/Linq/NullComparisonTests.cs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using NHibernate.Dialect;
56
using NHibernate.Linq;
67
using NHibernate.DomainModel.Northwind.Entities;
78
using NUnit.Framework;
@@ -470,13 +471,17 @@ public void NullEquality()
470471
Expect(session.Query<User>().Where(o => o.Short == value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
471472
Expect(session.Query<User>().Where(o => value == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
472473

473-
Expect(session.Query<User>().Where(o => o.NullableShort == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
474-
Expect(session.Query<User>().Where(o => 3 == o.NullableShort), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
475-
476-
Expect(session.Query<User>().Where(o => o.NullableShort.Value == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
477-
Expect(session.Query<User>().Where(o => 3 == o.NullableShort.Value), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
478-
Expect(session.Query<User>().Where(o => o.Short == 3), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
479-
Expect(session.Query<User>().Where(o => 3 == o.Short), Does.Not.Contain("is null").IgnoreCase.And.Contain("cast"));
474+
var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql
475+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) &&
476+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) &&
477+
shortCast != intCast;
478+
Expect(session.Query<User>().Where(o => o.NullableShort == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
479+
Expect(session.Query<User>().Where(o => 3 == o.NullableShort), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
480+
481+
Expect(session.Query<User>().Where(o => o.NullableShort.Value == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
482+
Expect(session.Query<User>().Where(o => 3 == o.NullableShort.Value), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
483+
Expect(session.Query<User>().Where(o => o.Short == 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
484+
Expect(session.Query<User>().Where(o => 3 == o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
480485
}
481486

482487
[Test]
@@ -574,6 +579,38 @@ public void NullInequality()
574579
Expect(session.Query<User>().Where(o => value != o.NullableShort.Value), Does.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
575580
Expect(session.Query<User>().Where(o => o.Short != value), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
576581
Expect(session.Query<User>().Where(o => value != o.Short), Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast"));
582+
583+
var shouldCast = Sfi.Dialect is SQLiteDialect || // transparent cast is translated to sql
584+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int16.SqlType, out var shortCast) &&
585+
Sfi.Dialect.TryGetCastTypeName(NHibernateUtil.Int32.SqlType, out var intCast) &&
586+
shortCast != intCast;
587+
Expect(session.Query<User>().Where(o => o.NullableShort != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast());
588+
Expect(session.Query<User>().Where(o => 3 != o.NullableShort), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast());
589+
590+
Expect(session.Query<User>().Where(o => o.NullableShort.Value != 3), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast());
591+
Expect(session.Query<User>().Where(o => 3 != o.NullableShort.Value), shouldCast ? WithIsNullAndWithCast() : WithIsNullAndWithoutCast());
592+
Expect(session.Query<User>().Where(o => o.Short != 3), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
593+
Expect(session.Query<User>().Where(o => 3 != o.Short), shouldCast ? WithoutIsNullAndWithCast() : WithoutIsNullAndWithoutCast());
594+
}
595+
596+
private IResolveConstraint WithIsNullAndWithoutCast()
597+
{
598+
return Does.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase;
599+
}
600+
601+
private IResolveConstraint WithIsNullAndWithCast()
602+
{
603+
return Does.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase;
604+
}
605+
606+
private IResolveConstraint WithoutIsNullAndWithoutCast()
607+
{
608+
return Does.Not.Contain("is null").IgnoreCase.And.Not.Contain("cast").IgnoreCase;
609+
}
610+
611+
private IResolveConstraint WithoutIsNullAndWithCast()
612+
{
613+
return Does.Not.Contain("is null").IgnoreCase.And.Contain("cast").IgnoreCase;
577614
}
578615

579616
[Test]

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor
2424
private readonly VisitorParameters _parameters;
2525
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
2626
private readonly NullableExpressionDetector _nullableExpressionDetector;
27-
private readonly HashSet<Expression> _notCastableExpressions = new HashSet<Expression>();
27+
private readonly Dictionary<Expression, System.Type> _notCastableExpressions = new Dictionary<Expression, System.Type>();
2828

2929
public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters)
3030
{
@@ -317,8 +317,9 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
317317
((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable() ==
318318
((UnaryExpression) expression.Right).Operand.Type.UnwrapIfNullable())
319319
{
320-
_notCastableExpressions.Add(expression.Left);
321-
_notCastableExpressions.Add(expression.Right);
320+
var type = ((UnaryExpression) expression.Left).Operand.Type.UnwrapIfNullable();
321+
_notCastableExpressions.Add(expression.Left, type);
322+
_notCastableExpressions.Add(expression.Right, type);
322323
}
323324

324325
if (expression.NodeType == ExpressionType.Equal)
@@ -509,11 +510,14 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
509510
case ExpressionType.Convert:
510511
case ExpressionType.ConvertChecked:
511512
case ExpressionType.TypeAs:
512-
return IsCastRequired(expression.Operand, expression.Type, out var existType) && !_notCastableExpressions.Contains(expression)
513+
var notCastable = _notCastableExpressions.TryGetValue(expression, out var castType);
514+
castType = castType ?? expression.Type;
515+
516+
return IsCastRequired(expression.Operand, castType, out var existType) && !notCastable
513517
? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
514518
// Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader
515-
: existType && HqlIdent.SupportsType(expression.Type)
516-
? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
519+
: existType && HqlIdent.SupportsType(castType)
520+
? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), castType)
517521
: VisitExpression(expression.Operand);
518522
}
519523

0 commit comments

Comments
 (0)