From 750363aee931642aa488f0bd0fb3357e210bebda Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 28 Mar 2021 16:28:16 +0200 Subject: [PATCH] Fix MappedAs when called on an UnaryExpression --- .../Async/Linq/ByMethod/MappedAsTests.cs | 52 +++++++++++++++++++ .../Linq/ByMethod/MappedAsTests.cs | 41 +++++++++++++++ .../Visitors/ExpressionParameterVisitor.cs | 4 +- .../Linq/Visitors/ParameterTypeLocator.cs | 4 +- 4 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 src/NHibernate.Test/Async/Linq/ByMethod/MappedAsTests.cs create mode 100644 src/NHibernate.Test/Linq/ByMethod/MappedAsTests.cs diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/MappedAsTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/MappedAsTests.cs new file mode 100644 index 00000000000..af79cd5f272 --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/ByMethod/MappedAsTests.cs @@ -0,0 +1,52 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using NHibernate.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.Linq.ByMethod +{ + using System.Threading.Tasks; + [TestFixture] + public class MappedAsTestsAsync : LinqTestCase + { + [Test] + public async Task WithUnaryExpressionAsync() + { + var num = 1; + await (db.Orders.Where(o => o.Freight == (-num).MappedAs(NHibernateUtil.Decimal)).ToListAsync()); + await (db.Orders.Where(o => o.Freight == ((decimal) num).MappedAs(NHibernateUtil.Decimal)).ToListAsync()); + await (db.Orders.Where(o => o.Freight == ((decimal?) (decimal) num).MappedAs(NHibernateUtil.Decimal)).ToListAsync()); + } + + [Test] + public async Task WithNewExpressionAsync() + { + var num = 1; + await (db.Orders.Where(o => o.Freight == new decimal(num).MappedAs(NHibernateUtil.Decimal)).ToListAsync()); + } + + [Test] + public async Task WithMethodCallExpressionAsync() + { + var num = 1; + await (db.Orders.Where(o => o.Freight == GetDecimal(num).MappedAs(NHibernateUtil.Decimal)).ToListAsync()); + } + + private decimal GetDecimal(int number) + { + return number; + } + } +} diff --git a/src/NHibernate.Test/Linq/ByMethod/MappedAsTests.cs b/src/NHibernate.Test/Linq/ByMethod/MappedAsTests.cs new file mode 100644 index 00000000000..f87cff8af92 --- /dev/null +++ b/src/NHibernate.Test/Linq/ByMethod/MappedAsTests.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using NHibernate.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.Linq.ByMethod +{ + [TestFixture] + public class MappedAsTests : LinqTestCase + { + [Test] + public void WithUnaryExpression() + { + var num = 1; + db.Orders.Where(o => o.Freight == (-num).MappedAs(NHibernateUtil.Decimal)).ToList(); + db.Orders.Where(o => o.Freight == ((decimal) num).MappedAs(NHibernateUtil.Decimal)).ToList(); + db.Orders.Where(o => o.Freight == ((decimal?) (decimal) num).MappedAs(NHibernateUtil.Decimal)).ToList(); + } + + [Test] + public void WithNewExpression() + { + var num = 1; + db.Orders.Where(o => o.Freight == new decimal(num).MappedAs(NHibernateUtil.Decimal)).ToList(); + } + + [Test] + public void WithMethodCallExpression() + { + var num = 1; + db.Orders.Where(o => o.Freight == GetDecimal(num).MappedAs(NHibernateUtil.Decimal)).ToList(); + } + + private decimal GetDecimal(int number) + { + return number; + } + } +} diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index c0658d4312c..76467a8035e 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -70,7 +70,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) { var rawParameter = Visit(expression.Arguments[0]); // TODO 6.0: Remove below code and return expression as this logic is now inside ConstantTypeLocator - var parameter = rawParameter as ConstantExpression; + var parameter = ParameterTypeLocator.UnwrapUnary(rawParameter) as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) throw new HibernateException( @@ -83,7 +83,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) _parameters[parameter].Type = (IType)type.Value; - return parameter; + return rawParameter; } var method = expression.Method.IsGenericMethod diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 7e7d5834272..ccc950d57f2 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -228,7 +228,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node) if (VisitorUtil.IsMappedAs(node.Method)) { var rawParameter = Visit(node.Arguments[0]); - var parameter = rawParameter as ConstantExpression; + var parameter = UnwrapUnary(rawParameter) as ConstantExpression; var type = node.Arguments[1] as ConstantExpression; if (parameter == null) throw new HibernateException( @@ -405,7 +405,7 @@ private bool IsDynamicMember(Expression expression) /// /// The expression to unwrap. /// The unwrapped expression. - private static Expression UnwrapUnary(Expression expression) + internal static Expression UnwrapUnary(Expression expression) { while (expression is UnaryExpression unaryExpression) {