Skip to content

Commit 56e542e

Browse files
weelinkhazzik
authored andcommitted
Add query support for the static methods of System.Decimal (#1533)
Fixes #831
1 parent f81e6cc commit 56e542e

19 files changed

+489
-11
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
3+
namespace NHibernate.Test.NHSpecificTest.GH0831
4+
{
5+
class Entity
6+
{
7+
public virtual Guid Id { get; set; }
8+
public virtual decimal EntityValue { get; set; }
9+
10+
public override int GetHashCode()
11+
{
12+
return Id.GetHashCode();
13+
}
14+
15+
public override bool Equals(object obj)
16+
{
17+
var that = obj as Entity;
18+
19+
return (that != null) && Id.Equals(that.Id);
20+
}
21+
22+
public override string ToString()
23+
{
24+
return EntityValue.ToString();
25+
}
26+
}
27+
}
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
6+
using NHibernate.Cfg.MappingSchema;
7+
using NHibernate.Mapping.ByCode;
8+
9+
using NUnit.Framework;
10+
11+
namespace NHibernate.Test.NHSpecificTest.GH0831
12+
{
13+
public class ByCodeFixture : TestCaseMappingByCode
14+
{
15+
private readonly IList<Entity> entities = new List<Entity>
16+
{
17+
new Entity { EntityValue = 0.5m },
18+
new Entity { EntityValue = 1.0m },
19+
new Entity { EntityValue = 1.5m },
20+
new Entity { EntityValue = 2.0m },
21+
new Entity { EntityValue = 2.5m },
22+
new Entity { EntityValue = 3.0m }
23+
};
24+
25+
protected override HbmMapping GetMappings()
26+
{
27+
var mapper = new ModelMapper();
28+
mapper.Class<Entity>(rc =>
29+
{
30+
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
31+
rc.Property(x => x.EntityValue);
32+
});
33+
34+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
35+
}
36+
37+
protected override void OnSetUp()
38+
{
39+
using (ISession session = OpenSession())
40+
using (ITransaction transaction = session.BeginTransaction())
41+
{
42+
foreach (Entity entity in entities)
43+
{
44+
session.Save(entity);
45+
}
46+
47+
session.Flush();
48+
transaction.Commit();
49+
}
50+
}
51+
52+
protected override void OnTearDown()
53+
{
54+
using (ISession session = OpenSession())
55+
using (ITransaction transaction = session.BeginTransaction())
56+
{
57+
session.Delete("from System.Object");
58+
59+
session.Flush();
60+
transaction.Commit();
61+
}
62+
}
63+
64+
[Test]
65+
public void CanHandleAdd()
66+
{
67+
Assert.Multiple(() =>
68+
{
69+
CanFilter(e => decimal.Add(e.EntityValue, 2) > 3.0m);
70+
CanFilter(e => decimal.Add(2, e.EntityValue) > 3.0m);
71+
72+
CanSelect(e => decimal.Add(e.EntityValue, 2));
73+
CanSelect(e => decimal.Add(2, e.EntityValue));
74+
});
75+
}
76+
77+
[Test]
78+
public void CanHandleCeiling()
79+
{
80+
AssumeFunctionSupported("ceiling");
81+
82+
Assert.Multiple(() =>
83+
{
84+
CanFilter(e => decimal.Ceiling(e.EntityValue) > 1.0m);
85+
CanSelect(e => decimal.Ceiling(e.EntityValue));
86+
});
87+
}
88+
89+
[Test]
90+
public void CanHandleCompare()
91+
{
92+
AssumeFunctionSupported("sign");
93+
94+
Assert.Multiple(() =>
95+
{
96+
CanFilter(e => decimal.Compare(e.EntityValue, 1.5m) < 1);
97+
CanFilter(e => decimal.Compare(1.0m, e.EntityValue) < 1);
98+
99+
CanSelect(e => decimal.Compare(e.EntityValue, 1.5m));
100+
CanSelect(e => decimal.Compare(1.0m, e.EntityValue));
101+
});
102+
}
103+
104+
[Test]
105+
public void CanHandleDivide()
106+
{
107+
Assert.Multiple(() =>
108+
{
109+
CanFilter(e => decimal.Divide(e.EntityValue, 1.25m) < 1);
110+
CanFilter(e => decimal.Divide(1.25m, e.EntityValue) < 1);
111+
112+
CanSelect(e => decimal.Divide(e.EntityValue, 1.25m));
113+
CanSelect(e => decimal.Divide(1.25m, e.EntityValue));
114+
});
115+
}
116+
117+
[Test]
118+
public void CanHandleEquals()
119+
{
120+
Assert.Multiple(() =>
121+
{
122+
CanFilter(e => decimal.Equals(e.EntityValue, 1.0m));
123+
CanFilter(e => decimal.Equals(1.0m, e.EntityValue));
124+
});
125+
}
126+
127+
[Test]
128+
public void CanHandleFloor()
129+
{
130+
AssumeFunctionSupported("floor");
131+
132+
Assert.Multiple(() =>
133+
{
134+
CanFilter(e => decimal.Floor(e.EntityValue) > 1.0m);
135+
CanSelect(e => decimal.Floor(e.EntityValue));
136+
});
137+
}
138+
139+
[Test]
140+
public void CanHandleMultiply()
141+
{
142+
Assert.Multiple(() =>
143+
{
144+
CanFilter(e => decimal.Multiply(e.EntityValue, 10m) > 10m);
145+
CanFilter(e => decimal.Multiply(10m, e.EntityValue) > 10m);
146+
147+
CanSelect(e => decimal.Multiply(e.EntityValue, 10m));
148+
CanSelect(e => decimal.Multiply(10m, e.EntityValue));
149+
});
150+
}
151+
152+
[Test]
153+
public void CanHandleNegate()
154+
{
155+
Assert.Multiple(() =>
156+
{
157+
CanFilter(e => decimal.Negate(e.EntityValue) > -1.0m);
158+
CanSelect(e => decimal.Negate(e.EntityValue));
159+
});
160+
}
161+
162+
[Test]
163+
public void CanHandleRemainder()
164+
{
165+
Assume.That(TestDialect.SupportsModuloOnDecimal, Is.True);
166+
167+
Assert.Multiple(() =>
168+
{
169+
CanFilter(e => decimal.Remainder(e.EntityValue, 2m) == 0);
170+
CanFilter(e => decimal.Remainder(2m, e.EntityValue) < 1);
171+
172+
CanSelect(e => decimal.Remainder(e.EntityValue, 2m));
173+
CanSelect(e => decimal.Remainder(2m, e.EntityValue));
174+
});
175+
}
176+
177+
[Test]
178+
public void CanHandleRound()
179+
{
180+
AssumeFunctionSupported("round");
181+
182+
Assert.Multiple(() =>
183+
{
184+
CanFilter(e => decimal.Round(e.EntityValue) >= 2.0m);
185+
CanFilter(e => decimal.Round(e.EntityValue, 1) >= 1.5m);
186+
187+
// SQL round() always rounds up.
188+
CanSelect(e => decimal.Round(e.EntityValue), entities.Select(e => decimal.Round(e.EntityValue, MidpointRounding.AwayFromZero)));
189+
CanSelect(e => decimal.Round(e.EntityValue, 1), entities.Select(e => decimal.Round(e.EntityValue, 1, MidpointRounding.AwayFromZero)));
190+
});
191+
}
192+
193+
[Test]
194+
public void CanHandleSubtract()
195+
{
196+
Assert.Multiple(() =>
197+
{
198+
CanFilter(e => decimal.Subtract(e.EntityValue, 1m) > 1m);
199+
CanFilter(e => decimal.Subtract(2m, e.EntityValue) > 1m);
200+
201+
CanSelect(e => decimal.Subtract(e.EntityValue, 1m));
202+
CanSelect(e => decimal.Subtract(2m, e.EntityValue));
203+
});
204+
}
205+
206+
[Test]
207+
public void CanHandleTruncate()
208+
{
209+
AssumeFunctionSupported("truncate");
210+
211+
Assert.Multiple(() =>
212+
{
213+
CanFilter(e => decimal.Truncate(e.EntityValue) > 1m);
214+
CanSelect(e => decimal.Truncate(e.EntityValue));
215+
});
216+
}
217+
218+
private void CanFilter(Expression<Func<Entity, bool>> predicate)
219+
{
220+
using (ISession session = OpenSession())
221+
using (session.BeginTransaction())
222+
{
223+
IEnumerable<Entity> inMemory = entities.Where(predicate.Compile()).ToList();
224+
IEnumerable<Entity> inSession = session.Query<Entity>().Where(predicate).ToList();
225+
226+
CollectionAssert.AreEquivalent(inMemory, inSession);
227+
}
228+
}
229+
230+
private void CanSelect(Expression<Func<Entity, decimal>> predicate)
231+
{
232+
IEnumerable<decimal> inMemory = entities.Select(predicate.Compile()).ToList();
233+
234+
CanSelect(predicate, inMemory);
235+
}
236+
237+
private void CanSelect(Expression<Func<Entity, decimal>> predicate, IEnumerable<decimal> expected)
238+
{
239+
using (ISession session = OpenSession())
240+
using (session.BeginTransaction())
241+
{
242+
IEnumerable<decimal> inSession = null;
243+
Assert.That(() => inSession = session.Query<Entity>().Select(predicate).ToList(), Throws.Nothing);
244+
245+
Assert.That(inSession, Is.EquivalentTo(expected).Using((decimal a, decimal b) => Math.Abs(a - b) < 0.0001m));
246+
}
247+
}
248+
}
249+
}

src/NHibernate.Test/TestDialect.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,10 @@ public bool SupportsSqlType(SqlType sqlType)
7575
return false;
7676
}
7777
}
78+
79+
/// <summary>
80+
/// Supports the modulo operator on decimal types
81+
/// </summary>
82+
public virtual bool SupportsModuloOnDecimal => true;
7883
}
7984
}

src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,9 @@ public FirebirdTestDialect(Dialect.Dialect dialect) : base(dialect)
88

99
public override bool SupportsComplexExpressionInGroupBy => false;
1010
public override bool SupportsNonDataBoundCondition => false;
11+
/// <summary>
12+
/// Non-integer arguments are rounded before the division takes place. So, “7.5 mod 2.5” gives 2 (8 mod 3), not 0.
13+
/// </summary>
14+
public override bool SupportsModuloOnDecimal => false;
1115
}
1216
}

src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,10 @@ public MsSqlCe40TestDialect(Dialect.Dialect dialect) : base(dialect)
2525
public override bool SupportsDuplicatedColumnAliases => false;
2626

2727
public override bool SupportsEmptyInserts => false;
28+
29+
/// <summary>
30+
/// Modulo is not supported on real, float, money, and numeric data types. [ Data type = numeric ]
31+
/// </summary>
32+
public override bool SupportsModuloOnDecimal => false;
2833
}
2934
}

src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,7 @@ public override bool SupportsHavingWithoutGroupBy
4444
{
4545
get { return false; }
4646
}
47+
48+
public override bool SupportsModuloOnDecimal => false;
4749
}
4850
}

src/NHibernate/Dialect/FirebirdDialect.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ private void RegisterMathematicalFunctions()
464464
RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double));
465465
RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32));
466466
RegisterFunction("sqtr", new StandardSQLFunction("sqtr", NHibernateUtil.Double));
467-
RegisterFunction("truncate", new StandardSQLFunction("truncate"));
467+
RegisterFunction("trunc", new StandardSQLFunction("trunc"));
468+
RegisterFunction("truncate", new StandardSQLFunction("trunc"));
468469
RegisterFunction("floor", new StandardSQLFunction("floor"));
469470
RegisterFunction("round", new StandardSQLFunction("round"));
470471
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ protected virtual void RegisterFunctions()
287287
RegisterFunction("ceil", new StandardSQLFunction("ceiling"));
288288
RegisterFunction("floor", new StandardSQLFunction("floor"));
289289
RegisterFunction("round", new RoundEmulatingSingleParameterFunction());
290+
RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)"));
290291

291292
RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double));
292293

src/NHibernate/Dialect/MsSqlCeDialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ protected virtual void RegisterFunctions()
195195
RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
196196

197197
RegisterFunction("round", new RoundEmulatingSingleParameterFunction());
198+
RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)"));
198199

199200
RegisterFunction("bit_length", new SQLFunctionTemplate(NHibernateUtil.Int32, "datalength(?1) * 8"));
200201
RegisterFunction("extract", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(?1, ?3)"));

src/NHibernate/Dialect/MySQLDialect.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ protected virtual void RegisterFunctions()
261261
RegisterFunction("ceiling", new StandardSQLFunction("ceiling"));
262262
RegisterFunction("floor", new StandardSQLFunction("floor"));
263263
RegisterFunction("round", new StandardSQLFunction("round"));
264-
RegisterFunction("truncate", new StandardSQLFunction("truncate"));
265-
264+
RegisterFunction("truncate", new StandardSafeSQLFunction("truncate", 2));
265+
266266
RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double));
267267

268268
RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double));

src/NHibernate/Dialect/Oracle8iDialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ protected virtual void RegisterFunctions()
229229

230230
RegisterFunction("round", new StandardSQLFunction("round"));
231231
RegisterFunction("trunc", new StandardSQLFunction("trunc"));
232+
RegisterFunction("truncate", new StandardSQLFunction("trunc"));
232233
RegisterFunction("ceil", new StandardSQLFunction("ceil"));
233234
RegisterFunction("ceiling", new StandardSQLFunction("ceil"));
234235
RegisterFunction("floor", new StandardSQLFunction("floor"));

src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ protected virtual void RegisterMiscellaneousFunctions()
343343
RegisterFunction("transactsql", new StandardSQLFunction("transactsql", NHibernateUtil.String));
344344
RegisterFunction("varexists", new StandardSQLFunction("varexists", NHibernateUtil.Int32));
345345
RegisterFunction("watcomsql", new StandardSQLFunction("watcomsql", NHibernateUtil.String));
346+
RegisterFunction("truncnum", new StandardSafeSQLFunction("truncnum", 2));
347+
RegisterFunction("truncate", new StandardSafeSQLFunction("truncnum", 2));
346348
}
347349

348350
#region private static readonly string[] DialectKeywords = { ... }

src/NHibernate/Linq/Functions/CompareGenerator.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGe
3232

3333
ReflectHelper.GetMethodDefinition<float>(x => x.CompareTo(x)),
3434
ReflectHelper.GetMethodDefinition<double>(x => x.CompareTo(x)),
35+
36+
ReflectHelper.GetMethodDefinition(() => decimal.Compare(default(decimal), default(decimal))),
3537
ReflectHelper.GetMethodDefinition<decimal>(x => x.CompareTo(x)),
3638

3739
ReflectHelper.GetMethodDefinition<DateTime>(x => x.CompareTo(x)),

0 commit comments

Comments
 (0)