6
6
using System . Runtime . CompilerServices ;
7
7
using System . Text . RegularExpressions ;
8
8
using NHibernate . Criterion ;
9
+ using NHibernate . Dialect . Function ;
10
+ using NHibernate . Type ;
9
11
using NHibernate . Util ;
10
12
using Expression = System . Linq . Expressions . Expression ;
11
13
@@ -104,6 +106,8 @@ public string AsProperty()
104
106
private static readonly Dictionary < LambdaSubqueryType , IDictionary < ExpressionType , Func < string , DetachedCriteria , AbstractCriterion > > > _subqueryExpressionCreatorTypes ;
105
107
private static readonly Dictionary < string , Func < MethodCallExpression , ICriterion > > _customMethodCallProcessors ;
106
108
private static readonly Dictionary < string , Func < Expression , IProjection > > _customProjectionProcessors ;
109
+ private static readonly Dictionary < ExpressionType , ISQLFunction > _binaryArithmethicTemplates = new Dictionary < ExpressionType , ISQLFunction > ( ) ;
110
+ private static readonly ISQLFunction _unaryNegateTemplate ;
107
111
108
112
static ExpressionProcessor ( )
109
113
{
@@ -198,6 +202,17 @@ static ExpressionProcessor()
198
202
RegisterCustomProjection ( ( ) => Math . Round ( default ( double ) , default ( int ) ) , ProjectionsExtensions . ProcessRound ) ;
199
203
RegisterCustomProjection ( ( ) => Math . Round ( default ( decimal ) , default ( int ) ) , ProjectionsExtensions . ProcessRound ) ;
200
204
RegisterCustomProjection ( ( ) => ProjectionsExtensions . AsEntity ( default ( object ) ) , ProjectionsExtensions . ProcessAsEntity ) ;
205
+
206
+ RegisterBinaryArithmeticExpression ( ExpressionType . Add , "+" ) ;
207
+ RegisterBinaryArithmeticExpression ( ExpressionType . Subtract , "-" ) ;
208
+ RegisterBinaryArithmeticExpression ( ExpressionType . Multiply , "*" ) ;
209
+ RegisterBinaryArithmeticExpression ( ExpressionType . Divide , "/" ) ;
210
+ _unaryNegateTemplate = new VarArgsSQLFunction ( "(-" , string . Empty , ")" ) ;
211
+ }
212
+
213
+ private static void RegisterBinaryArithmeticExpression ( ExpressionType type , string sqlOperand )
214
+ {
215
+ _binaryArithmethicTemplates [ type ] = new VarArgsSQLFunction ( "(" , sqlOperand , ")" ) ;
201
216
}
202
217
203
218
private static ICriterion Eq ( ProjectionInfo property , object value )
@@ -248,15 +263,12 @@ public static object FindValue(Expression expression)
248
263
public static ProjectionInfo FindMemberProjection ( Expression expression )
249
264
{
250
265
if ( ! IsMemberExpression ( expression ) )
251
- return ProjectionInfo . ForProjection ( Projections . Constant ( FindValue ( expression ) ) ) ;
266
+ return AsArithmeticExpression ( expression ) ?? ProjectionInfo . ForProjection ( Projections . Constant ( FindValue ( expression ) ) ) ;
252
267
253
- var unaryExpression = expression as UnaryExpression ;
254
- if ( unaryExpression != null )
268
+ var unwrapExpression = UnwrapConvertExpression ( expression ) ;
269
+ if ( unwrapExpression != null )
255
270
{
256
- if ( ! IsConversion ( unaryExpression . NodeType ) )
257
- throw new ArgumentException ( "Cannot interpret member from " + expression , nameof ( expression ) ) ;
258
-
259
- return FindMemberProjection ( unaryExpression . Operand ) ;
271
+ return FindMemberProjection ( unwrapExpression ) ;
260
272
}
261
273
262
274
var methodCallExpression = expression as MethodCallExpression ;
@@ -283,6 +295,55 @@ public static ProjectionInfo FindMemberProjection(Expression expression)
283
295
return ProjectionInfo . ForProperty ( FindMemberExpression ( expression ) ) ;
284
296
}
285
297
298
+ private static Expression UnwrapConvertExpression ( Expression expression )
299
+ {
300
+ if ( expression is UnaryExpression unaryExpression )
301
+ {
302
+ if ( ! IsConversion ( unaryExpression . NodeType ) )
303
+ {
304
+ if ( IsSupportedUnaryExpression ( unaryExpression ) )
305
+ return null ;
306
+
307
+ throw new ArgumentException ( "Cannot interpret member from " + expression , nameof ( expression ) ) ;
308
+ }
309
+ return unaryExpression . Operand ;
310
+ }
311
+
312
+ return null ;
313
+ }
314
+
315
+ private static bool IsSupportedUnaryExpression ( UnaryExpression expression )
316
+ {
317
+ return expression . NodeType == ExpressionType . Negate ;
318
+ }
319
+
320
+ private static ProjectionInfo AsArithmeticExpression ( Expression expression )
321
+ {
322
+ if ( ! ( expression is BinaryExpression be ) )
323
+ {
324
+ if ( expression is UnaryExpression unary && unary . NodeType == ExpressionType . Negate )
325
+ {
326
+ return ProjectionInfo . ForProjection (
327
+ new SqlFunctionProjection ( _unaryNegateTemplate , TypeFactory . HeuristicType ( unary . Type ) , FindMemberProjection ( unary . Operand ) . AsProjection ( ) ) ) ;
328
+ }
329
+
330
+ var unwrapExpression = UnwrapConvertExpression ( expression ) ;
331
+ return unwrapExpression != null ? AsArithmeticExpression ( unwrapExpression ) : null ;
332
+ }
333
+
334
+ if ( ! _binaryArithmethicTemplates . TryGetValue ( be . NodeType , out var template ) )
335
+ {
336
+ return null ;
337
+ }
338
+
339
+ return ProjectionInfo . ForProjection (
340
+ new SqlFunctionProjection (
341
+ template ,
342
+ TypeFactory . HeuristicType ( be . Type ) ,
343
+ FindMemberProjection ( be . Left ) . AsProjection ( ) ,
344
+ FindMemberProjection ( be . Right ) . AsProjection ( ) ) ) ;
345
+ }
346
+
286
347
//http://stackoverflow.com/a/2509524/259946
287
348
private static readonly Regex GeneratedMemberNameRegex = new Regex ( @"^(CS\$)?<\w*>[1-9a-s]__[a-zA-Z]+[0-9]*$" , RegexOptions . Compiled | RegexOptions . Singleline ) ;
288
349
@@ -410,13 +471,10 @@ private static System.Type FindMemberType(Expression expression)
410
471
return memberExpression . Type ;
411
472
}
412
473
413
- var unaryExpression = expression as UnaryExpression ;
414
- if ( unaryExpression != null )
474
+ var unwrapExpression = UnwrapConvertExpression ( expression ) ;
475
+ if ( unwrapExpression != null )
415
476
{
416
- if ( ! IsConversion ( unaryExpression . NodeType ) )
417
- throw new ArgumentException ( "Cannot interpret member from " + expression , nameof ( expression ) ) ;
418
-
419
- return FindMemberType ( unaryExpression . Operand ) ;
477
+ return FindMemberType ( unwrapExpression ) ;
420
478
}
421
479
422
480
var methodCallExpression = expression as MethodCallExpression ;
@@ -425,6 +483,9 @@ private static System.Type FindMemberType(Expression expression)
425
483
return methodCallExpression . Method . ReturnType ;
426
484
}
427
485
486
+ if ( expression is BinaryExpression || expression is UnaryExpression )
487
+ return expression . Type ;
488
+
428
489
throw new ArgumentException ( "Could not determine member type from " + expression , nameof ( expression ) ) ;
429
490
}
430
491
@@ -446,13 +507,10 @@ private static bool IsMemberExpression(Expression expression)
446
507
return EvaluatesToNull ( memberExpression . Expression ) ;
447
508
}
448
509
449
- var unaryExpression = expression as UnaryExpression ;
450
- if ( unaryExpression != null )
510
+ var unwrapExpression = UnwrapConvertExpression ( expression ) ;
511
+ if ( unwrapExpression != null )
451
512
{
452
- if ( ! IsConversion ( unaryExpression . NodeType ) )
453
- throw new ArgumentException ( "Cannot interpret member from " + expression , nameof ( expression ) ) ;
454
-
455
- return IsMemberExpression ( unaryExpression . Operand ) ;
513
+ return IsMemberExpression ( unwrapExpression ) ;
456
514
}
457
515
458
516
var methodCallExpression = expression as MethodCallExpression ;
0 commit comments