diff --git a/src/NHibernate.Test/NHSpecificTest/GH0831/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH0831/Entity.cs new file mode 100644 index 00000000000..9ffba39ddd0 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH0831/Entity.cs @@ -0,0 +1,27 @@ +using System; + +namespace NHibernate.Test.NHSpecificTest.GH0831 +{ + class Entity + { + public virtual Guid Id { get; set; } + public virtual decimal EntityValue { get; set; } + + public override int GetHashCode() + { + return Id.GetHashCode(); + } + + public override bool Equals(object obj) + { + var that = obj as Entity; + + return (that != null) && Id.Equals(that.Id); + } + + public override string ToString() + { + return EntityValue.ToString(); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH0831/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH0831/FixtureByCode.cs new file mode 100644 index 00000000000..b0dced8ea2f --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH0831/FixtureByCode.cs @@ -0,0 +1,249 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; + +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH0831 +{ + public class ByCodeFixture : TestCaseMappingByCode + { + private readonly IList entities = new List + { + new Entity { EntityValue = 0.5m }, + new Entity { EntityValue = 1.0m }, + new Entity { EntityValue = 1.5m }, + new Entity { EntityValue = 2.0m }, + new Entity { EntityValue = 2.5m }, + new Entity { EntityValue = 3.0m } + }; + + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.EntityValue); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (ISession session = OpenSession()) + using (ITransaction transaction = session.BeginTransaction()) + { + foreach (Entity entity in entities) + { + session.Save(entity); + } + + session.Flush(); + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (ISession session = OpenSession()) + using (ITransaction transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + + session.Flush(); + transaction.Commit(); + } + } + + [Test] + public void CanHandleAdd() + { + Assert.Multiple(() => + { + CanFilter(e => decimal.Add(e.EntityValue, 2) > 3.0m); + CanFilter(e => decimal.Add(2, e.EntityValue) > 3.0m); + + CanSelect(e => decimal.Add(e.EntityValue, 2)); + CanSelect(e => decimal.Add(2, e.EntityValue)); + }); + } + + [Test] + public void CanHandleCeiling() + { + AssumeFunctionSupported("ceiling"); + + Assert.Multiple(() => + { + CanFilter(e => decimal.Ceiling(e.EntityValue) > 1.0m); + CanSelect(e => decimal.Ceiling(e.EntityValue)); + }); + } + + [Test] + public void CanHandleCompare() + { + AssumeFunctionSupported("sign"); + + Assert.Multiple(() => + { + CanFilter(e => decimal.Compare(e.EntityValue, 1.5m) < 1); + CanFilter(e => decimal.Compare(1.0m, e.EntityValue) < 1); + + CanSelect(e => decimal.Compare(e.EntityValue, 1.5m)); + CanSelect(e => decimal.Compare(1.0m, e.EntityValue)); + }); + } + + [Test] + public void CanHandleDivide() + { + Assert.Multiple(() => + { + CanFilter(e => decimal.Divide(e.EntityValue, 1.25m) < 1); + CanFilter(e => decimal.Divide(1.25m, e.EntityValue) < 1); + + CanSelect(e => decimal.Divide(e.EntityValue, 1.25m)); + CanSelect(e => decimal.Divide(1.25m, e.EntityValue)); + }); + } + + [Test] + public void CanHandleEquals() + { + Assert.Multiple(() => + { + CanFilter(e => decimal.Equals(e.EntityValue, 1.0m)); + CanFilter(e => decimal.Equals(1.0m, e.EntityValue)); + }); + } + + [Test] + public void CanHandleFloor() + { + AssumeFunctionSupported("floor"); + + Assert.Multiple(() => + { + CanFilter(e => decimal.Floor(e.EntityValue) > 1.0m); + CanSelect(e => decimal.Floor(e.EntityValue)); + }); + } + + [Test] + public void CanHandleMultiply() + { + Assert.Multiple(() => + { + CanFilter(e => decimal.Multiply(e.EntityValue, 10m) > 10m); + CanFilter(e => decimal.Multiply(10m, e.EntityValue) > 10m); + + CanSelect(e => decimal.Multiply(e.EntityValue, 10m)); + CanSelect(e => decimal.Multiply(10m, e.EntityValue)); + }); + } + + [Test] + public void CanHandleNegate() + { + Assert.Multiple(() => + { + CanFilter(e => decimal.Negate(e.EntityValue) > -1.0m); + CanSelect(e => decimal.Negate(e.EntityValue)); + }); + } + + [Test] + public void CanHandleRemainder() + { + Assume.That(TestDialect.SupportsModuloOnDecimal, Is.True); + + Assert.Multiple(() => + { + CanFilter(e => decimal.Remainder(e.EntityValue, 2m) == 0); + CanFilter(e => decimal.Remainder(2m, e.EntityValue) < 1); + + CanSelect(e => decimal.Remainder(e.EntityValue, 2m)); + CanSelect(e => decimal.Remainder(2m, e.EntityValue)); + }); + } + + [Test] + public void CanHandleRound() + { + AssumeFunctionSupported("round"); + + Assert.Multiple(() => + { + CanFilter(e => decimal.Round(e.EntityValue) >= 2.0m); + CanFilter(e => decimal.Round(e.EntityValue, 1) >= 1.5m); + + // SQL round() always rounds up. + CanSelect(e => decimal.Round(e.EntityValue), entities.Select(e => decimal.Round(e.EntityValue, MidpointRounding.AwayFromZero))); + CanSelect(e => decimal.Round(e.EntityValue, 1), entities.Select(e => decimal.Round(e.EntityValue, 1, MidpointRounding.AwayFromZero))); + }); + } + + [Test] + public void CanHandleSubtract() + { + Assert.Multiple(() => + { + CanFilter(e => decimal.Subtract(e.EntityValue, 1m) > 1m); + CanFilter(e => decimal.Subtract(2m, e.EntityValue) > 1m); + + CanSelect(e => decimal.Subtract(e.EntityValue, 1m)); + CanSelect(e => decimal.Subtract(2m, e.EntityValue)); + }); + } + + [Test] + public void CanHandleTruncate() + { + AssumeFunctionSupported("truncate"); + + Assert.Multiple(() => + { + CanFilter(e => decimal.Truncate(e.EntityValue) > 1m); + CanSelect(e => decimal.Truncate(e.EntityValue)); + }); + } + + private void CanFilter(Expression> predicate) + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + IEnumerable inMemory = entities.Where(predicate.Compile()).ToList(); + IEnumerable inSession = session.Query().Where(predicate).ToList(); + + CollectionAssert.AreEquivalent(inMemory, inSession); + } + } + + private void CanSelect(Expression> predicate) + { + IEnumerable inMemory = entities.Select(predicate.Compile()).ToList(); + + CanSelect(predicate, inMemory); + } + + private void CanSelect(Expression> predicate, IEnumerable expected) + { + using (ISession session = OpenSession()) + using (session.BeginTransaction()) + { + IEnumerable inSession = null; + Assert.That(() => inSession = session.Query().Select(predicate).ToList(), Throws.Nothing); + + Assert.That(inSession, Is.EquivalentTo(expected).Using((decimal a, decimal b) => Math.Abs(a - b) < 0.0001m)); + } + } + } +} diff --git a/src/NHibernate.Test/TestDialect.cs b/src/NHibernate.Test/TestDialect.cs index 9d6c321ad15..68efde5d7f3 100644 --- a/src/NHibernate.Test/TestDialect.cs +++ b/src/NHibernate.Test/TestDialect.cs @@ -75,5 +75,10 @@ public bool SupportsSqlType(SqlType sqlType) return false; } } + + /// + /// Supports the modulo operator on decimal types + /// + public virtual bool SupportsModuloOnDecimal => true; } } diff --git a/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs b/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs index 96e4181b1cd..09b259a98f5 100644 --- a/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs +++ b/src/NHibernate.Test/TestDialects/FirebirdTestDialect.cs @@ -8,5 +8,9 @@ public FirebirdTestDialect(Dialect.Dialect dialect) : base(dialect) public override bool SupportsComplexExpressionInGroupBy => false; public override bool SupportsNonDataBoundCondition => false; + /// + /// Non-integer arguments are rounded before the division takes place. So, “7.5 mod 2.5” gives 2 (8 mod 3), not 0. + /// + public override bool SupportsModuloOnDecimal => false; } } diff --git a/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs b/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs index 8f384459d76..89156042f21 100644 --- a/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs +++ b/src/NHibernate.Test/TestDialects/MsSqlCe40TestDialect.cs @@ -25,5 +25,10 @@ public MsSqlCe40TestDialect(Dialect.Dialect dialect) : base(dialect) public override bool SupportsDuplicatedColumnAliases => false; public override bool SupportsEmptyInserts => false; + + /// + /// Modulo is not supported on real, float, money, and numeric data types. [ Data type = numeric ] + /// + public override bool SupportsModuloOnDecimal => false; } } diff --git a/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs b/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs index 95c77873d2f..da04e00f2f0 100644 --- a/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs +++ b/src/NHibernate.Test/TestDialects/SQLiteTestDialect.cs @@ -44,5 +44,7 @@ public override bool SupportsHavingWithoutGroupBy { get { return false; } } + + public override bool SupportsModuloOnDecimal => false; } } diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs index 5fa709e15c9..5883dfe2344 100644 --- a/src/NHibernate/Dialect/FirebirdDialect.cs +++ b/src/NHibernate/Dialect/FirebirdDialect.cs @@ -464,7 +464,8 @@ private void RegisterMathematicalFunctions() RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); RegisterFunction("sqtr", new StandardSQLFunction("sqtr", NHibernateUtil.Double)); - RegisterFunction("truncate", new StandardSQLFunction("truncate")); + RegisterFunction("trunc", new StandardSQLFunction("trunc")); + RegisterFunction("truncate", new StandardSQLFunction("trunc")); RegisterFunction("floor", new StandardSQLFunction("floor")); RegisterFunction("round", new StandardSQLFunction("round")); } diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index 050d65e142d..3553d5923a6 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -287,6 +287,7 @@ protected virtual void RegisterFunctions() RegisterFunction("ceil", new StandardSQLFunction("ceiling")); RegisterFunction("floor", new StandardSQLFunction("floor")); RegisterFunction("round", new RoundEmulatingSingleParameterFunction()); + RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)")); RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/MsSqlCeDialect.cs b/src/NHibernate/Dialect/MsSqlCeDialect.cs index 7d7e81b1a0e..3aebc47e5e3 100644 --- a/src/NHibernate/Dialect/MsSqlCeDialect.cs +++ b/src/NHibernate/Dialect/MsSqlCeDialect.cs @@ -195,6 +195,7 @@ protected virtual void RegisterFunctions() RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); RegisterFunction("round", new RoundEmulatingSingleParameterFunction()); + RegisterFunction("truncate", new SQLFunctionTemplate(null, "round(?1, ?2, 1)")); RegisterFunction("bit_length", new SQLFunctionTemplate(NHibernateUtil.Int32, "datalength(?1) * 8")); RegisterFunction("extract", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(?1, ?3)")); diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index f38759e7964..3292dabaffd 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -261,8 +261,8 @@ protected virtual void RegisterFunctions() RegisterFunction("ceiling", new StandardSQLFunction("ceiling")); RegisterFunction("floor", new StandardSQLFunction("floor")); RegisterFunction("round", new StandardSQLFunction("round")); - RegisterFunction("truncate", new StandardSQLFunction("truncate")); - + RegisterFunction("truncate", new StandardSafeSQLFunction("truncate", 2)); + RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index 321237c3a05..7120cd13624 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -229,6 +229,7 @@ protected virtual void RegisterFunctions() RegisterFunction("round", new StandardSQLFunction("round")); RegisterFunction("trunc", new StandardSQLFunction("trunc")); + RegisterFunction("truncate", new StandardSQLFunction("trunc")); RegisterFunction("ceil", new StandardSQLFunction("ceil")); RegisterFunction("ceiling", new StandardSQLFunction("ceil")); RegisterFunction("floor", new StandardSQLFunction("floor")); diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs index 5758cd5da31..09f8c6491b6 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs @@ -343,6 +343,8 @@ protected virtual void RegisterMiscellaneousFunctions() RegisterFunction("transactsql", new StandardSQLFunction("transactsql", NHibernateUtil.String)); RegisterFunction("varexists", new StandardSQLFunction("varexists", NHibernateUtil.Int32)); RegisterFunction("watcomsql", new StandardSQLFunction("watcomsql", NHibernateUtil.String)); + RegisterFunction("truncnum", new StandardSafeSQLFunction("truncnum", 2)); + RegisterFunction("truncate", new StandardSafeSQLFunction("truncnum", 2)); } #region private static readonly string[] DialectKeywords = { ... } diff --git a/src/NHibernate/Linq/Functions/CompareGenerator.cs b/src/NHibernate/Linq/Functions/CompareGenerator.cs index 173faf7c0c5..7e88f9c8e18 100644 --- a/src/NHibernate/Linq/Functions/CompareGenerator.cs +++ b/src/NHibernate/Linq/Functions/CompareGenerator.cs @@ -32,6 +32,8 @@ internal class CompareGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGe ReflectHelper.GetMethodDefinition(x => x.CompareTo(x)), ReflectHelper.GetMethodDefinition(x => x.CompareTo(x)), + + ReflectHelper.GetMethodDefinition(() => decimal.Compare(default(decimal), default(decimal))), ReflectHelper.GetMethodDefinition(x => x.CompareTo(x)), ReflectHelper.GetMethodDefinition(x => x.CompareTo(x)), diff --git a/src/NHibernate/Linq/Functions/DecimalGenerator.cs b/src/NHibernate/Linq/Functions/DecimalGenerator.cs new file mode 100644 index 00000000000..f79a3bdcb12 --- /dev/null +++ b/src/NHibernate/Linq/Functions/DecimalGenerator.cs @@ -0,0 +1,106 @@ +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; + +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; + +namespace NHibernate.Linq.Functions +{ + internal class DecimalAddGenerator : BaseHqlGeneratorForMethod + { + public DecimalAddGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => decimal.Add(default(decimal), default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.TransparentCast(treeBuilder.Add(visitor.Visit(arguments[0]).AsExpression(), visitor.Visit(arguments[1]).AsExpression()), typeof(decimal)); + } + } + + internal class DecimalDivideGenerator : BaseHqlGeneratorForMethod + { + public DecimalDivideGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => decimal.Divide(default(decimal), default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Cast(treeBuilder.Divide(visitor.Visit(arguments[0]).AsExpression(), visitor.Visit(arguments[1]).AsExpression()), typeof(decimal)); + } + } + + internal class DecimalMultiplyGenerator : BaseHqlGeneratorForMethod + { + public DecimalMultiplyGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => decimal.Multiply(default(decimal), default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.TransparentCast(treeBuilder.Multiply(visitor.Visit(arguments[0]).AsExpression(), visitor.Visit(arguments[1]).AsExpression()), typeof(decimal)); + } + } + + internal class DecimalSubtractGenerator : BaseHqlGeneratorForMethod + { + public DecimalSubtractGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => decimal.Subtract(default(decimal), default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.TransparentCast(treeBuilder.Subtract(visitor.Visit(arguments[0]).AsExpression(), visitor.Visit(arguments[1]).AsExpression()), typeof(decimal)); + } + } + + internal class DecimalRemainderGenerator : BaseHqlGeneratorForMethod + { + public DecimalRemainderGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => decimal.Remainder(default(decimal), default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.TransparentCast(treeBuilder.MethodCall("mod", visitor.Visit(arguments[0]).AsExpression(), visitor.Visit(arguments[1]).AsExpression()), typeof(decimal)); + } + } + + internal class DecimalNegateGenerator : BaseHqlGeneratorForMethod + { + public DecimalNegateGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => decimal.Negate(default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.TransparentCast(treeBuilder.Negate(visitor.Visit(arguments[0]).AsExpression()), typeof(decimal)); + } + } +} diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index 75e20952052..e2cf2574cac 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -51,6 +51,15 @@ public DefaultLinqToHqlGeneratorsRegistry() this.Merge(new CollectionContainsGenerator()); this.Merge(new DateTimePropertiesHqlGenerator()); + + this.Merge(new DecimalAddGenerator()); + this.Merge(new DecimalDivideGenerator()); + this.Merge(new DecimalMultiplyGenerator()); + this.Merge(new DecimalSubtractGenerator()); + this.Merge(new DecimalRemainderGenerator()); + this.Merge(new DecimalNegateGenerator()); + this.Merge(new RoundGenerator()); + this.Merge(new TruncateGenerator()); } protected bool GetRuntimeMethodGenerator(MethodInfo method, out IHqlGeneratorForMethod methodGenerator) @@ -100,4 +109,4 @@ public void RegisterGenerator(IRuntimeMethodHqlGenerator generator) runtimeMethodHqlGenerators.Add(generator); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Functions/EqualsGenerator.cs b/src/NHibernate/Linq/Functions/EqualsGenerator.cs index a7c8b52102c..1593b8d3d6f 100644 --- a/src/NHibernate/Linq/Functions/EqualsGenerator.cs +++ b/src/NHibernate/Linq/Functions/EqualsGenerator.cs @@ -32,6 +32,8 @@ public EqualsGenerator() ReflectHelper.GetMethodDefinition(x => x.Equals(x)), ReflectHelper.GetMethodDefinition(x => x.Equals(x)), + + ReflectHelper.GetMethodDefinition(() => decimal.Equals(default(decimal), default(decimal))), ReflectHelper.GetMethodDefinition(x => x.Equals(x)), ReflectHelper.GetMethodDefinition(x => x.Equals(x)), diff --git a/src/NHibernate/Linq/Functions/MathGenerator.cs b/src/NHibernate/Linq/Functions/MathGenerator.cs index aff4807f73b..f7f42ef5c13 100644 --- a/src/NHibernate/Linq/Functions/MathGenerator.cs +++ b/src/NHibernate/Linq/Functions/MathGenerator.cs @@ -45,17 +45,14 @@ public MathGenerator() ReflectHelper.GetMethodDefinition(() => Math.Sign(default(short))), ReflectHelper.GetMethodDefinition(() => Math.Sign(default(sbyte))), - ReflectHelper.GetMethodDefinition(() => Math.Round(default(decimal))), - ReflectHelper.GetMethodDefinition(() => Math.Round(default(decimal), default(int))), - ReflectHelper.GetMethodDefinition(() => Math.Round(default(double))), - ReflectHelper.GetMethodDefinition(() => Math.Round(default(double), default(int))), ReflectHelper.GetMethodDefinition(() => Math.Floor(default(decimal))), ReflectHelper.GetMethodDefinition(() => Math.Floor(default(double))), + ReflectHelper.GetMethodDefinition(() => decimal.Floor(default(decimal))), + ReflectHelper.GetMethodDefinition(() => Math.Ceiling(default(decimal))), ReflectHelper.GetMethodDefinition(() => Math.Ceiling(default(double))), - ReflectHelper.GetMethodDefinition(() => Math.Truncate(default(decimal))), - ReflectHelper.GetMethodDefinition(() => Math.Truncate(default(double))), - + ReflectHelper.GetMethodDefinition(() => decimal.Ceiling(default(decimal))), + ReflectHelper.GetMethodDefinition(() => Math.Pow(default(double), default(double))), }; } diff --git a/src/NHibernate/Linq/Functions/RoundGenerator.cs b/src/NHibernate/Linq/Functions/RoundGenerator.cs new file mode 100644 index 00000000000..a297cdd9299 --- /dev/null +++ b/src/NHibernate/Linq/Functions/RoundGenerator.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; + +namespace NHibernate.Linq.Functions +{ + internal class RoundGenerator : BaseHqlGeneratorForMethod + { + public RoundGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => Math.Round(default(double))), + ReflectHelper.GetMethodDefinition(() => Math.Round(default(double), default(int))), + ReflectHelper.GetMethodDefinition(() => Math.Round(default(decimal))), + ReflectHelper.GetMethodDefinition(() => Math.Round(default(decimal), default(int))), + ReflectHelper.GetMethodDefinition(() => decimal.Round(default(decimal))), + ReflectHelper.GetMethodDefinition(() => decimal.Round(default(decimal), default(int))), + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + var numberOfDecimals = arguments.Count == 2 + ? visitor.Visit(arguments[1]).AsExpression() + : treeBuilder.Constant(0); + return treeBuilder.TransparentCast( + treeBuilder.MethodCall("round", visitor.Visit(arguments[0]).AsExpression(), numberOfDecimals), + method.ReturnType); + } + } +} diff --git a/src/NHibernate/Linq/Functions/TruncateGenerator.cs b/src/NHibernate/Linq/Functions/TruncateGenerator.cs new file mode 100644 index 00000000000..c1f1babae25 --- /dev/null +++ b/src/NHibernate/Linq/Functions/TruncateGenerator.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; + +namespace NHibernate.Linq.Functions +{ + internal class TruncateGenerator : BaseHqlGeneratorForMethod + { + public TruncateGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethodDefinition(() => Math.Truncate(default(decimal))), + ReflectHelper.GetMethodDefinition(() => Math.Truncate(default(double))), + ReflectHelper.GetMethodDefinition(() => decimal.Truncate(default(decimal))) + }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression expression, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.MethodCall("truncate", visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Constant(0)); + } + } +}