Skip to content

Commit 545d05a

Browse files
authored
Fix casting properties from object type in LINQ (#3015)
Fixes #3005
1 parent b809f77 commit 545d05a

File tree

4 files changed

+157
-10
lines changed

4 files changed

+157
-10
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System;
12+
using System.Linq;
13+
using NHibernate.Cfg.MappingSchema;
14+
using NHibernate.Mapping.ByCode;
15+
using NUnit.Framework;
16+
using NHibernate.Linq;
17+
18+
namespace NHibernate.Test.NHSpecificTest.GH3005
19+
{
20+
using System.Threading.Tasks;
21+
[TestFixture]
22+
public class ByCodeFixtureAsync : TestCaseMappingByCode
23+
{
24+
protected override HbmMapping GetMappings()
25+
{
26+
var mapper = new ModelMapper();
27+
mapper.Class<Entity>(rc =>
28+
{
29+
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
30+
rc.Property(x => x.Name);
31+
rc.Property(x => x.Duration);
32+
});
33+
34+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
35+
}
36+
37+
protected override void OnSetUp()
38+
{
39+
using (var session = OpenSession())
40+
using (var transaction = session.BeginTransaction())
41+
{
42+
var e1 = new Entity { Name = "Bob", Duration = TimeSpan.FromMinutes(1) };
43+
session.Save(e1);
44+
45+
transaction.Commit();
46+
}
47+
}
48+
49+
protected override void OnTearDown()
50+
{
51+
using (var session = OpenSession())
52+
using (var transaction = session.BeginTransaction())
53+
{
54+
session.CreateQuery("delete from System.Object").ExecuteUpdate();
55+
56+
transaction.Commit();
57+
}
58+
}
59+
60+
[Test]
61+
public async Task CanCastFromObjectAsync()
62+
{
63+
using (var session = OpenSession())
64+
{
65+
var result = await (session.Query<Entity>().Select(x => (TimeSpan)(object)x.Duration).FirstOrDefaultAsync());
66+
67+
Assert.That(result, Is.EqualTo(TimeSpan.FromMinutes(1)));
68+
}
69+
}
70+
}
71+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
3+
namespace NHibernate.Test.NHSpecificTest.GH3005
4+
{
5+
class Entity
6+
{
7+
public virtual Guid Id { get; set; }
8+
public virtual string Name { get; set; }
9+
public virtual TimeSpan Duration { get; set; }
10+
}
11+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using System;
2+
using System.Linq;
3+
using NHibernate.Cfg.MappingSchema;
4+
using NHibernate.Mapping.ByCode;
5+
using NUnit.Framework;
6+
7+
namespace NHibernate.Test.NHSpecificTest.GH3005
8+
{
9+
[TestFixture]
10+
public class ByCodeFixture : TestCaseMappingByCode
11+
{
12+
protected override HbmMapping GetMappings()
13+
{
14+
var mapper = new ModelMapper();
15+
mapper.Class<Entity>(rc =>
16+
{
17+
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
18+
rc.Property(x => x.Name);
19+
rc.Property(x => x.Duration);
20+
});
21+
22+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
23+
}
24+
25+
protected override void OnSetUp()
26+
{
27+
using (var session = OpenSession())
28+
using (var transaction = session.BeginTransaction())
29+
{
30+
var e1 = new Entity { Name = "Bob", Duration = TimeSpan.FromMinutes(1) };
31+
session.Save(e1);
32+
33+
transaction.Commit();
34+
}
35+
}
36+
37+
protected override void OnTearDown()
38+
{
39+
using (var session = OpenSession())
40+
using (var transaction = session.BeginTransaction())
41+
{
42+
session.CreateQuery("delete from System.Object").ExecuteUpdate();
43+
44+
transaction.Commit();
45+
}
46+
}
47+
48+
[Test]
49+
public void CanCastFromObject()
50+
{
51+
using (var session = OpenSession())
52+
{
53+
var result = session.Query<Entity>().Select(x => (TimeSpan)(object)x.Duration).FirstOrDefault();
54+
55+
Assert.That(result, Is.EqualTo(TimeSpan.FromMinutes(1)));
56+
}
57+
}
58+
}
59+
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,11 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
245245
// otherwise the result may be incorrect. In SQL Server avg always returns int
246246
// when the argument is int.
247247
var hqlExpression = VisitExpression(expression.Expression).AsExpression();
248-
hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _)
249-
? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
250-
: _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type);
248+
hqlExpression = IsCastRequired(expression.Expression, expression.Type, out var needTransparentCast)
249+
? _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
250+
: needTransparentCast
251+
? _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type)
252+
: hqlExpression;
251253

252254
// In Oracle the avg function can return a number with up to 40 digits which cannot be retrieved from the data reader due to the lack of such
253255
// numeric type in .NET. In order to avoid that we have to add a cast to trim the number so that it can be converted into a .NET numeric type.
@@ -532,10 +534,10 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
532534
castType = expression.Type;
533535
}
534536

535-
return IsCastRequired(expression.Operand, castType, out var existType) && castable
537+
return IsCastRequired(expression.Operand, castType, out var needTransparentCast) && castable
536538
? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), castType)
537539
// Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader
538-
: existType && HqlIdent.SupportsType(castType)
540+
: needTransparentCast
539541
? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), castType)
540542
: VisitExpression(expression.Operand);
541543
}
@@ -643,12 +645,16 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
643645
return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree);
644646
}
645647

646-
private bool IsCastRequired(Expression expression, System.Type toType, out bool existType)
648+
private bool IsCastRequired(Expression expression, System.Type toType, out bool needTransparentCast)
647649
{
648-
existType = false;
649-
return toType != typeof(object) &&
650-
expression.Type.UnwrapIfNullable() != toType.UnwrapIfNullable() &&
651-
IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
650+
needTransparentCast =
651+
toType != typeof(object)
652+
&& expression.Type != typeof(object)
653+
&& expression.Type != toType
654+
&& HqlIdent.SupportsType(toType)
655+
&& expression.Type.UnwrapIfNullable() != toType.UnwrapIfNullable();
656+
657+
return needTransparentCast && IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out needTransparentCast);
652658
}
653659

654660
private bool IsCastRequired(IType type, IType toType, out bool existType)

0 commit comments

Comments
 (0)