Skip to content

Commit 293c2eb

Browse files
NH-3787 - Decimal truncation in Linq ternary expression (#707)
* Fixes #1335
1 parent 433921a commit 293c2eb

File tree

12 files changed

+349
-4
lines changed

12 files changed

+349
-4
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Linq;
12+
using NHibernate.Criterion;
13+
using NHibernate.Linq;
14+
using NHibernate.Transform;
15+
using NHibernate.Type;
16+
using NUnit.Framework;
17+
18+
namespace NHibernate.Test.NHSpecificTest.NH3787
19+
{
20+
using System.Threading.Tasks;
21+
[TestFixture]
22+
public class TestFixtureAsync : BugTestCase
23+
{
24+
private const decimal _testRate = 12345.1234567890123M;
25+
26+
protected override bool AppliesTo(Dialect.Dialect dialect)
27+
{
28+
return !TestDialect.HasBrokenDecimalType;
29+
}
30+
31+
protected override void OnSetUp()
32+
{
33+
base.OnSetUp();
34+
35+
using (var s = OpenSession())
36+
using (var t = s.BeginTransaction())
37+
{
38+
var testEntity = new TestEntity
39+
{
40+
UsePreviousRate = true,
41+
PreviousRate = _testRate,
42+
Rate = 54321.1234567890123M
43+
};
44+
s.Save(testEntity);
45+
t.Commit();
46+
}
47+
}
48+
49+
protected override void OnTearDown()
50+
{
51+
using (var s = OpenSession())
52+
using (var t = s.BeginTransaction())
53+
{
54+
s.CreateQuery("delete from TestEntity").ExecuteUpdate();
55+
t.Commit();
56+
}
57+
}
58+
59+
[Test]
60+
public async Task TestLinqQueryAsync()
61+
{
62+
using (var s = OpenSession())
63+
using (var t = s.BeginTransaction())
64+
{
65+
var queryResult = await (s
66+
.Query<TestEntity>()
67+
.Where(e => e.PreviousRate == _testRate)
68+
.ToListAsync());
69+
70+
Assert.That(queryResult.Count, Is.EqualTo(1));
71+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
72+
await (t.CommitAsync());
73+
}
74+
}
75+
76+
[Test]
77+
public async Task TestLinqProjectionAsync()
78+
{
79+
using (var s = OpenSession())
80+
using (var t = s.BeginTransaction())
81+
{
82+
var queryResult = await ((from test in s.Query<TestEntity>()
83+
select new RateDto { Rate = test.UsePreviousRate ? test.PreviousRate : test.Rate }).ToListAsync());
84+
85+
// Check it has not been truncated to the default scale (10) of NHibernate.
86+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
87+
await (t.CommitAsync());
88+
}
89+
}
90+
91+
[Test]
92+
public async Task TestLinqQueryOnExpressionAsync()
93+
{
94+
using (var s = OpenSession())
95+
using (var t = s.BeginTransaction())
96+
{
97+
var queryResult = await (s
98+
.Query<TestEntity>()
99+
.Where(
100+
// Without MappedAs, the test fails for SQL Server because it would restrict its parameter to the dialect's default scale.
101+
e => (e.UsePreviousRate ? e.PreviousRate : e.Rate) == _testRate.MappedAs(TypeFactory.Basic("decimal(18,13)")))
102+
.ToListAsync());
103+
104+
Assert.That(queryResult.Count, Is.EqualTo(1));
105+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
106+
await (t.CommitAsync());
107+
}
108+
}
109+
110+
[Test]
111+
public async Task TestQueryOverProjectionAsync()
112+
{
113+
using (var s = OpenSession())
114+
using (var t = s.BeginTransaction())
115+
{
116+
TestEntity testEntity = null;
117+
118+
var rateDto = new RateDto();
119+
//Generated sql
120+
//exec sp_executesql N'SELECT (case when this_.UsePreviousRate = @p0 then this_.PreviousRate else this_.Rate end) as y0_ FROM [TestEntity] this_',N'@p0 bit',@p0=1
121+
var query = s
122+
.QueryOver(() => testEntity)
123+
.Select(
124+
Projections
125+
.Alias(
126+
Projections.Conditional(
127+
Restrictions.Eq(Projections.Property(() => testEntity.UsePreviousRate), true),
128+
Projections.Property(() => testEntity.PreviousRate),
129+
Projections.Property(() => testEntity.Rate)),
130+
"Rate")
131+
.WithAlias(() => rateDto.Rate));
132+
133+
var queryResult = await (query.TransformUsing(Transformers.AliasToBean<RateDto>()).ListAsync<RateDto>());
134+
135+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
136+
await (t.CommitAsync());
137+
}
138+
}
139+
}
140+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
2+
<hibernate-mapping xmlns="urn:nhibernate-mapping-2.2" assembly="NHibernate.Test"
3+
namespace="NHibernate.Test.NHSpecificTest.NH3787">
4+
<class name="TestEntity" table="TestEntity">
5+
<id name="Id">
6+
<generator class="native"/>
7+
</id>
8+
<property name="UsePreviousRate" type="boolean" not-null="true"/>
9+
<property name="PreviousRate" type="decimal(18,13)" not-null="true"/>
10+
<property name="Rate" type="decimal(18,13)" not-null="true"/>
11+
</class>
12+
</hibernate-mapping>
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace NHibernate.Test.NHSpecificTest.NH3787
2+
{
3+
public class RateDto
4+
{
5+
public decimal Rate { get; set; }
6+
}
7+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace NHibernate.Test.NHSpecificTest.NH3787
2+
{
3+
public class TestEntity
4+
{
5+
public virtual int Id { get; set; }
6+
public virtual bool UsePreviousRate { get; set; }
7+
public virtual decimal Rate { get; set; }
8+
public virtual decimal PreviousRate { get; set; }
9+
}
10+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
using System.Linq;
2+
using NHibernate.Criterion;
3+
using NHibernate.Linq;
4+
using NHibernate.Transform;
5+
using NHibernate.Type;
6+
using NUnit.Framework;
7+
8+
namespace NHibernate.Test.NHSpecificTest.NH3787
9+
{
10+
[TestFixture]
11+
public class TestFixture : BugTestCase
12+
{
13+
private const decimal _testRate = 12345.1234567890123M;
14+
15+
protected override bool AppliesTo(Dialect.Dialect dialect)
16+
{
17+
return !TestDialect.HasBrokenDecimalType;
18+
}
19+
20+
protected override void OnSetUp()
21+
{
22+
base.OnSetUp();
23+
24+
using (var s = OpenSession())
25+
using (var t = s.BeginTransaction())
26+
{
27+
var testEntity = new TestEntity
28+
{
29+
UsePreviousRate = true,
30+
PreviousRate = _testRate,
31+
Rate = 54321.1234567890123M
32+
};
33+
s.Save(testEntity);
34+
t.Commit();
35+
}
36+
}
37+
38+
protected override void OnTearDown()
39+
{
40+
using (var s = OpenSession())
41+
using (var t = s.BeginTransaction())
42+
{
43+
s.CreateQuery("delete from TestEntity").ExecuteUpdate();
44+
t.Commit();
45+
}
46+
}
47+
48+
[Test]
49+
public void TestLinqQuery()
50+
{
51+
using (var s = OpenSession())
52+
using (var t = s.BeginTransaction())
53+
{
54+
var queryResult = s
55+
.Query<TestEntity>()
56+
.Where(e => e.PreviousRate == _testRate)
57+
.ToList();
58+
59+
Assert.That(queryResult.Count, Is.EqualTo(1));
60+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
61+
t.Commit();
62+
}
63+
}
64+
65+
[Test]
66+
public void TestLinqProjection()
67+
{
68+
using (var s = OpenSession())
69+
using (var t = s.BeginTransaction())
70+
{
71+
var queryResult = (from test in s.Query<TestEntity>()
72+
select new RateDto { Rate = test.UsePreviousRate ? test.PreviousRate : test.Rate }).ToList();
73+
74+
// Check it has not been truncated to the default scale (10) of NHibernate.
75+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
76+
t.Commit();
77+
}
78+
}
79+
80+
[Test]
81+
public void TestLinqQueryOnExpression()
82+
{
83+
using (var s = OpenSession())
84+
using (var t = s.BeginTransaction())
85+
{
86+
var queryResult = s
87+
.Query<TestEntity>()
88+
.Where(
89+
// Without MappedAs, the test fails for SQL Server because it would restrict its parameter to the dialect's default scale.
90+
e => (e.UsePreviousRate ? e.PreviousRate : e.Rate) == _testRate.MappedAs(TypeFactory.Basic("decimal(18,13)")))
91+
.ToList();
92+
93+
Assert.That(queryResult.Count, Is.EqualTo(1));
94+
Assert.That(queryResult[0].PreviousRate, Is.EqualTo(_testRate));
95+
t.Commit();
96+
}
97+
}
98+
99+
[Test]
100+
public void TestQueryOverProjection()
101+
{
102+
using (var s = OpenSession())
103+
using (var t = s.BeginTransaction())
104+
{
105+
TestEntity testEntity = null;
106+
107+
var rateDto = new RateDto();
108+
//Generated sql
109+
//exec sp_executesql N'SELECT (case when this_.UsePreviousRate = @p0 then this_.PreviousRate else this_.Rate end) as y0_ FROM [TestEntity] this_',N'@p0 bit',@p0=1
110+
var query = s
111+
.QueryOver(() => testEntity)
112+
.Select(
113+
Projections
114+
.Alias(
115+
Projections.Conditional(
116+
Restrictions.Eq(Projections.Property(() => testEntity.UsePreviousRate), true),
117+
Projections.Property(() => testEntity.PreviousRate),
118+
Projections.Property(() => testEntity.Rate)),
119+
"Rate")
120+
.WithAlias(() => rateDto.Rate));
121+
122+
var queryResult = query.TransformUsing(Transformers.AliasToBean<RateDto>()).List<RateDto>();
123+
124+
Assert.That(queryResult[0].Rate, Is.EqualTo(_testRate));
125+
t.Commit();
126+
}
127+
}
128+
}
129+
}

src/NHibernate/Dialect/Dialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ protected Dialect()
9898
RegisterFunction("upper", new StandardSQLFunction("upper"));
9999
RegisterFunction("lower", new StandardSQLFunction("lower"));
100100
RegisterFunction("cast", new CastFunction());
101+
RegisterFunction("transparentcast", new TransparentCastFunction());
101102
RegisterFunction("extract", new AnsiExtractFunction());
102103
RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "||", ")"));
103104

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
3+
namespace NHibernate.Dialect.Function
4+
{
5+
/// <summary>
6+
/// A HQL only cast for helping HQL knowing the type. Does not generates any actual cast in SQL code.
7+
/// </summary>
8+
[Serializable]
9+
public class TransparentCastFunction : CastFunction
10+
{
11+
protected override bool CastingIsRequired(string sqlType)
12+
{
13+
return false;
14+
}
15+
}
16+
}

src/NHibernate/Dialect/SQLiteDialect.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ protected virtual void RegisterFunctions()
8686
RegisterFunction("cast", new SQLiteCastFunction());
8787

8888
RegisterFunction("round", new StandardSQLFunction("round"));
89+
90+
// NH-3787: SQLite requires the cast in SQL too for not defaulting to string.
91+
RegisterFunction("transparentcast", new CastFunction());
8992
}
9093

9194
#region private static readonly string[] DialectKeywords = { ... }

src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public IType FindFunctionReturnType(String functionName, IASTNode first)
7777

7878
if (first != null)
7979
{
80-
if (functionName == "cast")
80+
if (sqlFunction is CastFunction)
8181
{
8282
argumentType = TypeFactory.HeuristicType(first.NextSibling.Text);
8383
}

src/NHibernate/Hql/Ast/HqlTreeBuilder.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,17 @@ public HqlCast Cast(HqlExpression expression, System.Type type)
301301
return new HqlCast(_factory, expression, type);
302302
}
303303

304+
/// <summary>
305+
/// Generate a cast node intended solely to hint HQL at the resulting type, without issuing an actual SQL cast.
306+
/// </summary>
307+
/// <param name="expression">The expression to cast.</param>
308+
/// <param name="type">The resulting type.</param>
309+
/// <returns>A <see cref="HqlTransparentCast"/> node.</returns>
310+
public HqlTransparentCast TransparentCast(HqlExpression expression, System.Type type)
311+
{
312+
return new HqlTransparentCast(_factory, expression, type);
313+
}
314+
304315
public HqlBitwiseNot BitwiseNot()
305316
{
306317
return new HqlBitwiseNot(_factory);

src/NHibernate/Hql/Ast/HqlTreeNode.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,19 @@ public HqlCast(IASTFactory factory, HqlExpression expression, System.Type type)
701701
}
702702
}
703703

704+
/// <summary>
705+
/// Cast node intended solely to hint HQL at the resulting type, without issuing an actual SQL cast.
706+
/// </summary>
707+
public class HqlTransparentCast : HqlExpression
708+
{
709+
public HqlTransparentCast(IASTFactory factory, HqlExpression expression, System.Type type)
710+
: base(HqlSqlWalker.METHOD_CALL, "method", factory)
711+
{
712+
AddChild(new HqlIdent(factory, "transparentcast"));
713+
AddChild(new HqlExpressionList(factory, expression, new HqlIdent(factory, type)));
714+
}
715+
}
716+
704717
public class HqlCoalesce : HqlExpression
705718
{
706719
public HqlCoalesce(IASTFactory factory, HqlExpression lhs, HqlExpression rhs)

0 commit comments

Comments
 (0)