Skip to content

Commit f98059f

Browse files
committed
Merge branch 'NH-3092-3.3.x' into 3.3.x
2 parents 5b5cd30 + 8e29b9b commit f98059f

File tree

7 files changed

+221
-3
lines changed

7 files changed

+221
-3
lines changed

src/NHibernate.Test/Linq/MathTests.cs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
using System;
2+
using System.Linq;
3+
using System.Linq.Expressions;
4+
using NHibernate.Dialect;
5+
using NHibernate.DomainModel.Northwind.Entities;
6+
using NUnit.Framework;
7+
8+
namespace NHibernate.Test.Linq
9+
{
10+
[TestFixture]
11+
public class MathTests : LinqTestCase
12+
{
13+
private IQueryable<OrderLine> _orderLines;
14+
15+
private void IgnoreIfNotSupported(string function)
16+
{
17+
if (!Dialect.Functions.ContainsKey(function))
18+
Assert.Ignore("Dialect {0} does not support '{1}' function", Dialect.GetType(), function);
19+
}
20+
21+
protected override void OnSetUp()
22+
{
23+
base.OnSetUp();
24+
_orderLines = db.OrderLines
25+
.OrderBy(ol => ol.Id)
26+
.Take(10).ToList().AsQueryable();
27+
}
28+
29+
[Test]
30+
public void SignAllPositiveTest()
31+
{
32+
IgnoreIfNotSupported("sign");
33+
var signs = (from o in db.OrderLines
34+
select Math.Sign(o.UnitPrice)).ToList();
35+
36+
Assert.True(signs.All(x => x == 1));
37+
}
38+
39+
[Test]
40+
public void SignAllNegativeTest()
41+
{
42+
IgnoreIfNotSupported("sign");
43+
var signs = (from o in db.OrderLines
44+
select Math.Sign(0m - o.UnitPrice)).ToList();
45+
46+
Assert.True(signs.All(x => x == -1));
47+
}
48+
49+
[Test]
50+
public void SinTest()
51+
{
52+
IgnoreIfNotSupported("sin");
53+
Test(o => Math.Round(Math.Sin((double) o.UnitPrice), 5));
54+
}
55+
56+
[Test]
57+
public void CosTest()
58+
{
59+
IgnoreIfNotSupported("cos");
60+
Test(o => Math.Round(Math.Cos((double)o.UnitPrice), 5));
61+
}
62+
63+
[Test]
64+
public void TanTest()
65+
{
66+
IgnoreIfNotSupported("tan");
67+
Test(o => Math.Round(Math.Tan((double)o.Discount), 5));
68+
}
69+
70+
[Test]
71+
public void SinhTest()
72+
{
73+
IgnoreIfNotSupported("sinh");
74+
Test(o => Math.Round(Math.Sinh((double)o.Discount), 5));
75+
}
76+
77+
[Test]
78+
public void CoshTest()
79+
{
80+
IgnoreIfNotSupported("cosh");
81+
Test(o => Math.Round(Math.Cosh((double)o.Discount), 5));
82+
}
83+
84+
[Test]
85+
public void TanhTest()
86+
{
87+
IgnoreIfNotSupported("tanh");
88+
Test(o => Math.Round(Math.Tanh((double)o.Discount), 5));
89+
}
90+
91+
[Test]
92+
public void AsinTest()
93+
{
94+
IgnoreIfNotSupported("asin");
95+
Test(o => Math.Round(Math.Asin((double)o.Discount), 5));
96+
}
97+
98+
[Test]
99+
public void AcosTest()
100+
{
101+
IgnoreIfNotSupported("acos");
102+
Test(o => Math.Round(Math.Acos((double)o.Discount), 5));
103+
}
104+
105+
[Test]
106+
public void AtanTest()
107+
{
108+
IgnoreIfNotSupported("atan");
109+
Test(o => Math.Round(Math.Atan((double)o.UnitPrice), 5));
110+
}
111+
112+
[Test]
113+
public void Atan2Test()
114+
{
115+
IgnoreIfNotSupported("atan2");
116+
if (Dialect is Oracle8iDialect)
117+
Assert.Ignore("Fails on Oracle due to NH-3381.");
118+
119+
Test(o => Math.Round(Math.Atan2((double)o.Discount, 0.5d), 5));
120+
}
121+
122+
private void Test(Expression<Func<OrderLine, double>> selector)
123+
{
124+
var expected = _orderLines
125+
.Select(selector)
126+
.ToList();
127+
128+
var actual = db.OrderLines
129+
.OrderBy(ol => ol.Id)
130+
.Select(selector)
131+
.Take(10)
132+
.ToList();
133+
134+
Assert.AreEqual(expected.Count, actual.Count);
135+
for (var i = 0; i < expected.Count; i++)
136+
Assert.AreEqual(expected[i], actual[i], 0.000001);
137+
}
138+
}
139+
}

src/NHibernate.Test/NHibernate.Test.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,9 +506,10 @@
506506
<Compile Include="Linq\CollectionAssert.cs" />
507507
<Compile Include="Linq\LoggingTests.cs" />
508508
<Compile Include="Linq\QueryTimeoutTests.cs" />
509+
<Compile Include="Linq\DateTimeTests.cs" />
509510
<Compile Include="Linq\JoinTests.cs" />
510511
<Compile Include="Linq\CustomExtensionsExample.cs" />
511-
<Compile Include="Linq\DateTimeTests.cs" />
512+
<Compile Include="Linq\MathTests.cs" />
512513
<Compile Include="Linq\DynamicQueryTests.cs" />
513514
<Compile Include="Linq\EagerLoadTests.cs" />
514515
<Compile Include="Linq\EnumTests.cs" />

src/NHibernate/Dialect/FirebirdDialect.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public FirebirdDialect()
9393
RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double));
9494
RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double));
9595
RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double));
96-
RegisterFunction("sing", new StandardSQLFunction("sing", NHibernateUtil.Double));
96+
RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32));
9797
RegisterFunction("sqtr", new StandardSQLFunction("sqtr", NHibernateUtil.Double));
9898
RegisterFunction("truncate", new StandardSQLFunction("truncate"));
9999
RegisterFunction("floor", new StandardSafeSQLFunction("floor", NHibernateUtil.Double, 1));

src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public DefaultLinqToHqlGeneratorsRegistry()
3636
this.Merge(new ReplaceGenerator());
3737
this.Merge(new LengthGenerator());
3838
this.Merge(new TrimGenerator());
39+
this.Merge(new MathGenerator());
3940

4041
this.Merge(new AnyHqlGenerator());
4142
this.Merge(new AllHqlGenerator());
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using System;
2+
using System.Collections.ObjectModel;
3+
using System.Linq.Expressions;
4+
using System.Reflection;
5+
using NHibernate.Hql.Ast;
6+
using NHibernate.Linq.Visitors;
7+
8+
namespace NHibernate.Linq.Functions
9+
{
10+
public class MathGenerator : BaseHqlGeneratorForMethod
11+
{
12+
public MathGenerator()
13+
{
14+
SupportedMethods = new[]
15+
{
16+
ReflectionHelper.GetMethodDefinition(() => Math.Sin(default(double))),
17+
ReflectionHelper.GetMethodDefinition(() => Math.Cos(default(double))),
18+
ReflectionHelper.GetMethodDefinition(() => Math.Tan(default(double))),
19+
20+
ReflectionHelper.GetMethodDefinition(() => Math.Sinh(default(double))),
21+
ReflectionHelper.GetMethodDefinition(() => Math.Cosh(default(double))),
22+
ReflectionHelper.GetMethodDefinition(() => Math.Tanh(default(double))),
23+
24+
ReflectionHelper.GetMethodDefinition(() => Math.Asin(default(double))),
25+
ReflectionHelper.GetMethodDefinition(() => Math.Acos(default(double))),
26+
ReflectionHelper.GetMethodDefinition(() => Math.Atan(default(double))),
27+
ReflectionHelper.GetMethodDefinition(() => Math.Atan2(default(double), default(double))),
28+
29+
ReflectionHelper.GetMethodDefinition(() => Math.Sqrt(default(double))),
30+
31+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(decimal))),
32+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(double))),
33+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(float))),
34+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(long))),
35+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(int))),
36+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(short))),
37+
ReflectionHelper.GetMethodDefinition(() => Math.Abs(default(sbyte))),
38+
39+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(decimal))),
40+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(double))),
41+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(float))),
42+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(long))),
43+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(int))),
44+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(short))),
45+
ReflectionHelper.GetMethodDefinition(() => Math.Sign(default(sbyte))),
46+
47+
ReflectionHelper.GetMethodDefinition(() => Math.Round(default(decimal))),
48+
ReflectionHelper.GetMethodDefinition(() => Math.Round(default(decimal), default(int))),
49+
ReflectionHelper.GetMethodDefinition(() => Math.Round(default(double))),
50+
ReflectionHelper.GetMethodDefinition(() => Math.Round(default(double), default(int))),
51+
ReflectionHelper.GetMethodDefinition(() => Math.Floor(default(decimal))),
52+
ReflectionHelper.GetMethodDefinition(() => Math.Floor(default(double))),
53+
ReflectionHelper.GetMethodDefinition(() => Math.Ceiling(default(decimal))),
54+
ReflectionHelper.GetMethodDefinition(() => Math.Ceiling(default(double))),
55+
ReflectionHelper.GetMethodDefinition(() => Math.Truncate(default(decimal))),
56+
ReflectionHelper.GetMethodDefinition(() => Math.Truncate(default(double))),
57+
};
58+
}
59+
60+
public override HqlTreeNode BuildHql(MethodInfo method, Expression expression, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
61+
{
62+
string function = method.Name.ToLowerInvariant();
63+
HqlExpression firstArgument = visitor.Visit(arguments[0]).AsExpression();
64+
65+
if (arguments.Count == 2)
66+
{
67+
return treeBuilder.MethodCall(function, firstArgument, visitor.Visit(arguments[1]).AsExpression());
68+
}
69+
70+
return treeBuilder.MethodCall(function, firstArgument);
71+
}
72+
}
73+
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,11 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
341341
case ExpressionType.Convert:
342342
case ExpressionType.ConvertChecked:
343343
case ExpressionType.TypeAs:
344-
if (expression.Operand.Type.IsPrimitive && expression.Type.IsPrimitive)
344+
if ((expression.Operand.Type.IsPrimitive || expression.Operand.Type == typeof(Decimal)) &&
345+
(expression.Type.IsPrimitive || expression.Type == typeof(Decimal)))
346+
{
345347
return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type);
348+
}
346349

347350
return VisitExpression(expression.Operand);
348351
}

src/NHibernate/NHibernate.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@
292292
<Compile Include="Linq\ExpressionTransformers\RemoveRedundantCast.cs" />
293293
<Compile Include="Linq\ExpressionTransformers\SimplifyCompareTransformer.cs" />
294294
<Compile Include="Linq\Functions\CompareGenerator.cs" />
295+
<Compile Include="Linq\Functions\MathGenerator.cs" />
295296
<Compile Include="Linq\Functions\DictionaryGenerator.cs" />
296297
<Compile Include="Linq\Functions\EqualsGenerator.cs" />
297298
<Compile Include="Linq\GroupBy\PagingRewriter.cs" />

0 commit comments

Comments
 (0)