Skip to content

Commit 011f678

Browse files
committed
Support basic arithmetic operations (+, -, *, /) in QueryOver
1 parent b3cdf05 commit 011f678

File tree

3 files changed

+129
-19
lines changed

3 files changed

+129
-19
lines changed

src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,5 +510,31 @@ public async Task StatelessSessionAsync()
510510
Assert.That(statelessPerson2.Id, Is.EqualTo(personId));
511511
}
512512
}
513+
514+
[Test]
515+
public async Task QueryOverArithmeticAsync()
516+
{
517+
using (ISession s = OpenSession())
518+
using (ITransaction t = s.BeginTransaction())
519+
{
520+
await (s.SaveAsync(new Person() {Name = "test person 1", Age = 20}));
521+
await (s.SaveAsync(new Person() {Name = "test person 2", Age = 50}));
522+
await (t.CommitAsync());
523+
}
524+
525+
using (var s = OpenSession())
526+
{
527+
var persons1 = await (s.QueryOver<Person>().Where(p => (p.Age * 2) / 2 + 20 - 20 == 20).ListAsync());
528+
var persons2 = await (s.QueryOver<Person>().Where(p => (-(-p.Age)) > 20).ListAsync());
529+
var persons3 = await (s.QueryOver<Person>().WhereRestrictionOn(p => (p.Age * 2) / 2 + 20 - 20).IsBetween(19).And(21).ListAsync());
530+
var persons4 = await (s.QueryOver<Person>().WhereRestrictionOn(p => -(-p.Age)).IsBetween(19).And(21).ListAsync());
531+
var persons5 = await (s.QueryOver<Person>().WhereRestrictionOn(p => (p.Age * 2) / 2 + 20 - 20).IsBetween(19).And(51).ListAsync());
532+
Assert.That(persons1.Count, Is.EqualTo(1));
533+
Assert.That(persons2.Count, Is.EqualTo(1));
534+
Assert.That(persons3.Count, Is.EqualTo(1));
535+
Assert.That(persons4.Count, Is.EqualTo(1));
536+
Assert.That(persons5.Count, Is.EqualTo(2));
537+
}
538+
}
513539
}
514540
}

src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,5 +498,31 @@ public void StatelessSession()
498498
Assert.That(statelessPerson2.Id, Is.EqualTo(personId));
499499
}
500500
}
501+
502+
[Test]
503+
public void QueryOverArithmetic()
504+
{
505+
using (ISession s = OpenSession())
506+
using (ITransaction t = s.BeginTransaction())
507+
{
508+
s.Save(new Person() {Name = "test person 1", Age = 20});
509+
s.Save(new Person() {Name = "test person 2", Age = 50});
510+
t.Commit();
511+
}
512+
513+
using (var s = OpenSession())
514+
{
515+
var persons1 = s.QueryOver<Person>().Where(p => (p.Age * 2) / 2 + 20 - 20 == 20).List();
516+
var persons2 = s.QueryOver<Person>().Where(p => (-(-p.Age)) > 20).List();
517+
var persons3 = s.QueryOver<Person>().WhereRestrictionOn(p => (p.Age * 2) / 2 + 20 - 20).IsBetween(19).And(21).List();
518+
var persons4 = s.QueryOver<Person>().WhereRestrictionOn(p => -(-p.Age)).IsBetween(19).And(21).List();
519+
var persons5 = s.QueryOver<Person>().WhereRestrictionOn(p => (p.Age * 2) / 2 + 20 - 20).IsBetween(19).And(51).List();
520+
Assert.That(persons1.Count, Is.EqualTo(1));
521+
Assert.That(persons2.Count, Is.EqualTo(1));
522+
Assert.That(persons3.Count, Is.EqualTo(1));
523+
Assert.That(persons4.Count, Is.EqualTo(1));
524+
Assert.That(persons5.Count, Is.EqualTo(2));
525+
}
526+
}
501527
}
502528
}

src/NHibernate/Impl/ExpressionProcessor.cs

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Runtime.CompilerServices;
77
using System.Text.RegularExpressions;
88
using NHibernate.Criterion;
9+
using NHibernate.Dialect.Function;
10+
using NHibernate.Type;
911
using NHibernate.Util;
1012
using Expression = System.Linq.Expressions.Expression;
1113

@@ -104,6 +106,8 @@ public string AsProperty()
104106
private static readonly Dictionary<LambdaSubqueryType, IDictionary<ExpressionType, Func<string, DetachedCriteria, AbstractCriterion>>> _subqueryExpressionCreatorTypes;
105107
private static readonly Dictionary<string, Func<MethodCallExpression, ICriterion>> _customMethodCallProcessors;
106108
private static readonly Dictionary<string, Func<Expression, IProjection>> _customProjectionProcessors;
109+
private static readonly Dictionary<ExpressionType, ISQLFunction> _binaryArithmethicTemplates = new Dictionary<ExpressionType, ISQLFunction>();
110+
private static readonly ISQLFunction _unaryNegateTemplate;
107111

108112
static ExpressionProcessor()
109113
{
@@ -198,6 +202,17 @@ static ExpressionProcessor()
198202
RegisterCustomProjection(() => Math.Round(default(double), default(int)), ProjectionsExtensions.ProcessRound);
199203
RegisterCustomProjection(() => Math.Round(default(decimal), default(int)), ProjectionsExtensions.ProcessRound);
200204
RegisterCustomProjection(() => ProjectionsExtensions.AsEntity(default(object)), ProjectionsExtensions.ProcessAsEntity);
205+
206+
RegisterBinaryArithmeticExpression(ExpressionType.Add, "+");
207+
RegisterBinaryArithmeticExpression(ExpressionType.Subtract, "-");
208+
RegisterBinaryArithmeticExpression(ExpressionType.Multiply, "*");
209+
RegisterBinaryArithmeticExpression(ExpressionType.Divide, "/");
210+
_unaryNegateTemplate = new VarArgsSQLFunction("(-", string.Empty, ")");
211+
}
212+
213+
private static void RegisterBinaryArithmeticExpression(ExpressionType type, string sqlOperand)
214+
{
215+
_binaryArithmethicTemplates[type] = new VarArgsSQLFunction("(", sqlOperand, ")");
201216
}
202217

203218
private static ICriterion Eq(ProjectionInfo property, object value)
@@ -248,15 +263,12 @@ public static object FindValue(Expression expression)
248263
public static ProjectionInfo FindMemberProjection(Expression expression)
249264
{
250265
if (!IsMemberExpression(expression))
251-
return ProjectionInfo.ForProjection(Projections.Constant(FindValue(expression)));
266+
return AsArithmeticExpression(expression) ?? ProjectionInfo.ForProjection(Projections.Constant(FindValue(expression)));
252267

253-
var unaryExpression = expression as UnaryExpression;
254-
if (unaryExpression != null)
268+
var unwrapExpression = UnwrapConvertExpression(expression);
269+
if (unwrapExpression != null)
255270
{
256-
if (!IsConversion(unaryExpression.NodeType))
257-
throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression));
258-
259-
return FindMemberProjection(unaryExpression.Operand);
271+
return FindMemberProjection(unwrapExpression);
260272
}
261273

262274
var methodCallExpression = expression as MethodCallExpression;
@@ -283,6 +295,55 @@ public static ProjectionInfo FindMemberProjection(Expression expression)
283295
return ProjectionInfo.ForProperty(FindMemberExpression(expression));
284296
}
285297

298+
private static Expression UnwrapConvertExpression(Expression expression)
299+
{
300+
if (expression is UnaryExpression unaryExpression)
301+
{
302+
if (!IsConversion(unaryExpression.NodeType))
303+
{
304+
if (IsSupportedUnaryExpression(unaryExpression))
305+
return null;
306+
307+
throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression));
308+
}
309+
return unaryExpression.Operand;
310+
}
311+
312+
return null;
313+
}
314+
315+
private static bool IsSupportedUnaryExpression(UnaryExpression expression)
316+
{
317+
return expression.NodeType == ExpressionType.Negate;
318+
}
319+
320+
private static ProjectionInfo AsArithmeticExpression(Expression expression)
321+
{
322+
if (!(expression is BinaryExpression be))
323+
{
324+
if (expression is UnaryExpression unary && unary.NodeType == ExpressionType.Negate)
325+
{
326+
return ProjectionInfo.ForProjection(
327+
new SqlFunctionProjection(_unaryNegateTemplate, TypeFactory.HeuristicType(unary.Type), FindMemberProjection(unary.Operand).AsProjection()));
328+
}
329+
330+
var unwrapExpression = UnwrapConvertExpression(expression);
331+
return unwrapExpression != null ? AsArithmeticExpression(unwrapExpression) : null;
332+
}
333+
334+
if (!_binaryArithmethicTemplates.TryGetValue(be.NodeType, out var template))
335+
{
336+
return null;
337+
}
338+
339+
return ProjectionInfo.ForProjection(
340+
new SqlFunctionProjection(
341+
template,
342+
TypeFactory.HeuristicType(be.Type),
343+
FindMemberProjection(be.Left).AsProjection(),
344+
FindMemberProjection(be.Right).AsProjection()));
345+
}
346+
286347
//http://stackoverflow.com/a/2509524/259946
287348
private static readonly Regex GeneratedMemberNameRegex = new Regex(@"^(CS\$)?<\w*>[1-9a-s]__[a-zA-Z]+[0-9]*$", RegexOptions.Compiled | RegexOptions.Singleline);
288349

@@ -410,13 +471,10 @@ private static System.Type FindMemberType(Expression expression)
410471
return memberExpression.Type;
411472
}
412473

413-
var unaryExpression = expression as UnaryExpression;
414-
if (unaryExpression != null)
474+
var unwrapExpression = UnwrapConvertExpression(expression);
475+
if (unwrapExpression != null)
415476
{
416-
if (!IsConversion(unaryExpression.NodeType))
417-
throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression));
418-
419-
return FindMemberType(unaryExpression.Operand);
477+
return FindMemberType(unwrapExpression);
420478
}
421479

422480
var methodCallExpression = expression as MethodCallExpression;
@@ -425,6 +483,9 @@ private static System.Type FindMemberType(Expression expression)
425483
return methodCallExpression.Method.ReturnType;
426484
}
427485

486+
if (expression is BinaryExpression || expression is UnaryExpression)
487+
return expression.Type;
488+
428489
throw new ArgumentException("Could not determine member type from " + expression, nameof(expression));
429490
}
430491

@@ -446,13 +507,10 @@ private static bool IsMemberExpression(Expression expression)
446507
return EvaluatesToNull(memberExpression.Expression);
447508
}
448509

449-
var unaryExpression = expression as UnaryExpression;
450-
if (unaryExpression != null)
510+
var unwrapExpression = UnwrapConvertExpression(expression);
511+
if (unwrapExpression != null)
451512
{
452-
if (!IsConversion(unaryExpression.NodeType))
453-
throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression));
454-
455-
return IsMemberExpression(unaryExpression.Operand);
513+
return IsMemberExpression(unwrapExpression);
456514
}
457515

458516
var methodCallExpression = expression as MethodCallExpression;

0 commit comments

Comments
 (0)