From 63e7475522b8f992cba3b37a252d7019f80a10bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= Date: Sun, 9 Apr 2017 21:34:33 +0200 Subject: [PATCH 1/3] NH-3944 - preliminary work, Nh clauses eliminated. --- .../Linq/ByMethod/GroupByHavingTests.cs | 18 ++ src/NHibernate/Linq/Clauses/NhHavingClause.cs | 19 --- src/NHibernate/Linq/Clauses/NhJoinClause.cs | 62 ------- src/NHibernate/Linq/Clauses/NhWithClause.cs | 19 --- .../GroupBy/AggregatingGroupByRewriter.cs | 20 +-- src/NHibernate/Linq/GroupBy/PagingRewriter.cs | 23 +-- .../NestedSelects/NestedSelectRewriter.cs | 159 +++++++++--------- .../Linq/ReWriters/AddJoinsReWriter.cs | 8 +- src/NHibernate/Linq/Visitors/JoinBuilder.cs | 29 ++-- .../Linq/Visitors/LeftJoinRewriter.cs | 33 ++-- .../Linq/Visitors/QueryModelVisitor.cs | 78 ++++----- .../Linq/Visitors/VisitorParameters.cs | 98 ++++++++++- src/NHibernate/NHibernate.csproj | 3 - 13 files changed, 274 insertions(+), 295 deletions(-) delete mode 100644 src/NHibernate/Linq/Clauses/NhHavingClause.cs delete mode 100644 src/NHibernate/Linq/Clauses/NhJoinClause.cs delete mode 100644 src/NHibernate/Linq/Clauses/NhWithClause.cs diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs index 818d7593722..7b4e4cf3c50 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs @@ -135,5 +135,23 @@ public void SingleKeyGroupAndCountWithHavingClause() var hornRow = orderCounts.Single(row => row.CompanyName == "Around the Horn"); Assert.That(hornRow.OrderCount, Is.EqualTo(13)); } + + [Test, Explicit("Demonstrate an unsupported case for PagingRewriter")] + public void SingleKeyGroupAndCountWithHavingClausePagingAndOuterWhere() + { + var orderCounts = db.Orders + .GroupBy(o => o.Customer.CompanyName) + .Where(g => g.Count() > 10) + .Select(g => new { CompanyName = g.Key, OrderCount = g.Count() }) + .OrderBy(oc => oc.CompanyName) + .Skip(5) + .Take(10) + .Where(oc => oc.CompanyName.Contains("F")) + .ToList(); + + Assert.That(orderCounts, Has.Count.EqualTo(3)); + var frankRow = orderCounts.Single(row => row.CompanyName == "Frankenversand"); + Assert.That(frankRow.OrderCount, Is.EqualTo(15)); + } } } diff --git a/src/NHibernate/Linq/Clauses/NhHavingClause.cs b/src/NHibernate/Linq/Clauses/NhHavingClause.cs deleted file mode 100644 index 131043ecc92..00000000000 --- a/src/NHibernate/Linq/Clauses/NhHavingClause.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; -using System.Linq.Expressions; - -namespace NHibernate.Linq.Clauses -{ - public class NhHavingClause : WhereClause - { - public NhHavingClause(Expression predicate) - : base(predicate) - { - } - - public override string ToString() - { - return "having " + FormattingExpressionTreeVisitor.Format(Predicate); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs deleted file mode 100644 index 944a8248d4d..00000000000 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ /dev/null @@ -1,62 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Linq.Expressions; -using NHibernate.Linq.Visitors; -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.Expressions; - -namespace NHibernate.Linq.Clauses -{ - /// - /// All joins are created as outer joins. An optimization in finds - /// joins that may be inner joined and calls on them. - /// 's will - /// then emit the correct HQL join. - /// - public class NhJoinClause : AdditionalFromClause - { - public NhJoinClause(string itemName, System.Type itemType, Expression fromExpression) - : this(itemName, itemType, fromExpression, new NhWithClause[0]) - { - } - - public NhJoinClause(string itemName, System.Type itemType, Expression fromExpression, IEnumerable restrictions) - : base(itemName, itemType, fromExpression) - { - Restrictions = new ObservableCollection(); - foreach (var withClause in restrictions) - Restrictions.Add(withClause); - IsInner = false; - } - - public ObservableCollection Restrictions { get; private set; } - - public bool IsInner { get; private set; } - - public override AdditionalFromClause Clone(CloneContext cloneContext) - { - var joinClause = new NhJoinClause(ItemName, ItemType, FromExpression); - foreach (var withClause in Restrictions) - { - var withClause2 = new NhWithClause(withClause.Predicate); - joinClause.Restrictions.Add(withClause2); - } - - cloneContext.QuerySourceMapping.AddMapping(this, new QuerySourceReferenceExpression(joinClause)); - return base.Clone(cloneContext); - } - - public void MakeInner() - { - IsInner = true; - } - - public override void TransformExpressions(Func transformation) - { - foreach (var withClause in Restrictions) - withClause.TransformExpressions(transformation); - base.TransformExpressions(transformation); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Clauses/NhWithClause.cs b/src/NHibernate/Linq/Clauses/NhWithClause.cs deleted file mode 100644 index ae21bc0f4a6..00000000000 --- a/src/NHibernate/Linq/Clauses/NhWithClause.cs +++ /dev/null @@ -1,19 +0,0 @@ -using System.Linq.Expressions; -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; - -namespace NHibernate.Linq.Clauses -{ - public class NhWithClause : WhereClause - { - public NhWithClause(Expression predicate) - : base(predicate) - { - } - - public override string ToString() - { - return "with " + FormattingExpressionTreeVisitor.Format(Predicate); - } - } -} diff --git a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs index 8150d0ae5df..316ba170add 100644 --- a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using NHibernate.Linq.Clauses; using NHibernate.Linq.ReWriters; using NHibernate.Linq.Visitors; using Remotion.Linq; @@ -43,11 +42,9 @@ public static class AggregatingGroupByRewriter typeof (CacheableResultOperator) }; - public static void ReWrite(QueryModel queryModel) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if (subQueryExpression != null) + if (queryModel.MainFromClause.FromExpression is SubQueryExpression subQueryExpression) { var operators = subQueryExpression.QueryModel.ResultOperators .Where(x => !QueryReferenceExpressionFlattener.FlattenableResultOperators.Contains(x.GetType())) @@ -55,17 +52,16 @@ public static void ReWrite(QueryModel queryModel) if (operators.Length == 1) { - var groupBy = operators[0] as GroupResultOperator; - if (groupBy != null) + if (operators[0] is GroupResultOperator groupBy) { - FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy); + FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy, parameters); RemoveCostantGroupByKeys(queryModel, groupBy); } } } } - private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryModel, GroupResultOperator groupBy) + private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryModel, GroupResultOperator groupBy, VisitorParameters parameters) { foreach (var resultOperator in queryModel.ResultOperators.Where(resultOperator => !AcceptableOuterResultOperators.Contains(resultOperator.GetType()))) { @@ -81,11 +77,9 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo clause.TransformExpressions(s => GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryModel)); //all outer where clauses actually are having clauses - var whereClause = clause as WhereClause; - if (whereClause != null) + if (clause is WhereClause whereClause) { - queryModel.BodyClauses.RemoveAt(i); - queryModel.BodyClauses.Insert(i, new NhHavingClause(whereClause.Predicate)); + parameters.AddHavingClause(whereClause); } } diff --git a/src/NHibernate/Linq/GroupBy/PagingRewriter.cs b/src/NHibernate/Linq/GroupBy/PagingRewriter.cs index 922f1790a10..87d89e4f2ee 100644 --- a/src/NHibernate/Linq/GroupBy/PagingRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/PagingRewriter.cs @@ -9,17 +9,16 @@ namespace NHibernate.Linq.GroupBy { internal static class PagingRewriter { - private static readonly System.Type[] PagingResultOperators = new[] - { - typeof (SkipResultOperator), - typeof (TakeResultOperator), - }; + private static readonly System.Type[] PagingResultOperators = + new[] + { + typeof (SkipResultOperator), + typeof (TakeResultOperator), + }; public static void ReWrite(QueryModel queryModel) { - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if (subQueryExpression != null && + if (queryModel.MainFromClause.FromExpression is SubQueryExpression subQueryExpression && subQueryExpression.QueryModel.ResultOperators.All(x => PagingResultOperators.Contains(x.GetType()))) { FlattenSubQuery(subQueryExpression, queryModel); @@ -28,7 +27,7 @@ public static void ReWrite(QueryModel queryModel) private static void FlattenSubQuery(SubQueryExpression subQueryExpression, QueryModel queryModel) { - // we can not flattern subquery if outer query has body clauses. + // we can not flatten subquery if outer query has body clauses. var subQueryModel = subQueryExpression.QueryModel; var subQueryMainFromClause = subQueryModel.MainFromClause; if (queryModel.BodyClauses.Count == 0) @@ -46,9 +45,13 @@ private static void FlattenSubQuery(SubQueryExpression subQueryExpression, Query { var cro = new ContainsResultOperator(new QuerySourceReferenceExpression(subQueryMainFromClause)); + // Cloning may cause having/join/with clauses listed in VisitorParameters to no more be matched. + // Not a problem for now, because those clauses imply a projection, which is not supported + // by the "new WhereClause(new SubQueryExpression(newSubQueryModel))" below. See + // SingleKeyGroupAndCountWithHavingClausePagingAndOuterWhere test by example. var newSubQueryModel = subQueryModel.Clone(); newSubQueryModel.ResultOperators.Add(cro); - newSubQueryModel.ResultTypeOverride = typeof (bool); + newSubQueryModel.ResultTypeOverride = typeof(bool); var where = new WhereClause(new SubQueryExpression(newSubQueryModel)); queryModel.BodyClauses.Add(where); diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs index 37ab62de54b..61a057d97e6 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs @@ -3,9 +3,9 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; -using NHibernate.Linq.Clauses; using NHibernate.Linq.GroupBy; using NHibernate.Linq.Visitors; +using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.Clauses; @@ -26,20 +26,20 @@ static class NestedSelectRewriter private static readonly PropertyInfo IGroupingKeyProperty = (PropertyInfo) ReflectHelper.GetProperty, Tuple>(g => g.Key); - public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - var nsqmv = new NestedSelectDetector(sessionFactory); + var nsqmv = new NestedSelectDetector(parameters.SessionFactory); nsqmv.VisitExpression(queryModel.SelectClause.Selector); if (!nsqmv.HasSubqueries) return; var elementExpression = new List(); - var group = Expression.Parameter(typeof (IGrouping), "g"); - + var group = Expression.Parameter(typeof(IGrouping), "g"); + var replacements = new Dictionary(); foreach (var expression in nsqmv.Expressions) { - var processed = ProcessExpression(queryModel, sessionFactory, expression, elementExpression, group); + var processed = ProcessExpression(queryModel, expression, elementExpression, group, parameters); if (processed != null) replacements.Add(expression, processed); } @@ -48,7 +48,7 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory var expressions = new List(); - var identifier = GetIdentifier(sessionFactory, new QuerySourceReferenceExpression(queryModel.MainFromClause)); + var identifier = GetIdentifier(parameters.SessionFactory, new QuerySourceReferenceExpression(queryModel.MainFromClause)); var rewriter = new SelectClauseRewriter(key, expressions, identifier, replacements); @@ -59,98 +59,104 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory var keySelector = CreateSelector(elementExpression, 0); var elementSelector = CreateSelector(elementExpression, 1); - - var input = Expression.Parameter(typeof (IEnumerable), "input"); + + var input = Expression.Parameter(typeof(IEnumerable), "input"); var lambda = Expression.Lambda( Expression.Call(GroupByMethod, - Expression.Call(CastMethod, input), - keySelector, - elementSelector), + Expression.Call(CastMethod, input), + keySelector, + elementSelector), input); queryModel.ResultOperators.Add(new ClientSideSelect2(lambda)); - queryModel.ResultOperators.Add(new ClientSideSelect(Expression.Lambda(resultSelector, @group))); + queryModel.ResultOperators.Add(new ClientSideSelect(Expression.Lambda(resultSelector, group))); var initializers = elementExpression.Select(e => ConvertToObject(e.Expression)); - queryModel.SelectClause.Selector = Expression.NewArrayInit(typeof (object), initializers); + queryModel.SelectClause.Selector = Expression.NewArrayInit(typeof(object), initializers); } - private static Expression ProcessExpression(QueryModel queryModel, ISessionFactory sessionFactory, Expression expression, List elementExpression, ParameterExpression @group) + private static Expression ProcessExpression(QueryModel queryModel, Expression expression, List elementExpression, + ParameterExpression group, VisitorParameters parameters) { - var memberExpression = expression as MemberExpression; - if (memberExpression != null) - return ProcessMemberExpression(sessionFactory, elementExpression, queryModel, @group, memberExpression); - - var subQueryExpression = expression as SubQueryExpression; - if (subQueryExpression != null) - return ProcessSubquery(sessionFactory, elementExpression, queryModel, @group, subQueryExpression.QueryModel); - + if (expression is MemberExpression memberExpression) + return ProcessMemberExpression(elementExpression, queryModel, group, memberExpression, parameters); + + if (expression is SubQueryExpression subQueryExpression) + return ProcessSubquery(elementExpression, queryModel, group, subQueryExpression.QueryModel, parameters); + return null; } - private static Expression ProcessSubquery(ISessionFactory sessionFactory, ICollection elementExpression, QueryModel queryModel, Expression @group, QueryModel subQueryModel) + private static Expression ProcessSubquery(ICollection elementExpression, QueryModel queryModel, Expression group, + QueryModel subQueryModel, VisitorParameters parameters) { var subQueryMainFromClause = subQueryModel.MainFromClause; - var restrictions = subQueryModel.BodyClauses - .OfType() - .Select(w => new NhWithClause(w.Predicate)); - - var join = new NhJoinClause(subQueryMainFromClause.ItemName, - subQueryMainFromClause.ItemType, - subQueryMainFromClause.FromExpression, - restrictions); + var restrictions = subQueryModel.BodyClauses.OfType().ToList(); + var join = new AdditionalFromClause( + subQueryMainFromClause.ItemName, + subQueryMainFromClause.ItemType, + subQueryMainFromClause.FromExpression); - queryModel.BodyClauses.Add(@join); + parameters.AddLeftJoin(join, restrictions); + queryModel.BodyClauses.Add(join); + foreach (var withClause in restrictions) + { + queryModel.BodyClauses.Add(withClause); + } - var visitor = new SwapQuerySourceVisitor(subQueryMainFromClause, @join); + var visitor = new SwapQuerySourceVisitor(subQueryMainFromClause, join); queryModel.TransformExpressions(visitor.Swap); var selector = subQueryModel.SelectClause.Selector; var collectionType = subQueryModel.GetResultType(); - + var elementType = selector.Type; - var source = new QuerySourceReferenceExpression(@join); + var source = new QuerySourceReferenceExpression(join); - return BuildSubCollectionQuery(sessionFactory, elementExpression, @group, source, selector, elementType, collectionType); + return BuildSubCollectionQuery(parameters.SessionFactory, elementExpression, group, source, selector, elementType, collectionType); } - private static Expression ProcessMemberExpression(ISessionFactory sessionFactory, ICollection elementExpression, QueryModel queryModel, Expression @group, Expression memberExpression) + private static Expression ProcessMemberExpression(ICollection elementExpression, QueryModel queryModel, Expression group, + Expression memberExpression, VisitorParameters parameters) { - var join = new NhJoinClause(new NameGenerator(queryModel).GetNewName(), - GetElementType(memberExpression.Type), - memberExpression); + var join = new AdditionalFromClause( + new NameGenerator(queryModel).GetNewName(), + GetElementType(memberExpression.Type), + memberExpression); + parameters.AddLeftJoin(join, null); - queryModel.BodyClauses.Add(@join); + queryModel.BodyClauses.Add(join); - var source = new QuerySourceReferenceExpression(@join); + var source = new QuerySourceReferenceExpression(join); - return BuildSubCollectionQuery(sessionFactory, elementExpression, @group, source, source, source.Type, memberExpression.Type); + return BuildSubCollectionQuery(parameters.SessionFactory, elementExpression, group, source, source, source.Type, memberExpression.Type); } - private static Expression BuildSubCollectionQuery(ISessionFactory sessionFactory, ICollection expressions, Expression @group, Expression source, Expression select, System.Type elementType, System.Type collectionType) + private static Expression BuildSubCollectionQuery(ISessionFactory sessionFactory, ICollection expressions, Expression group, + Expression source, Expression select, System.Type elementType, System.Type collectionType) { var predicate = MakePredicate(expressions.Count); var identifier = GetIdentifier(sessionFactory, source); - var selector = MakeSelector(expressions, @select, identifier); + var selector = MakeSelector(expressions, select, identifier); - return SubCollectionQuery(collectionType, elementType, @group, predicate, selector); + return SubCollectionQuery(collectionType, elementType, group, predicate, selector); } - private static LambdaExpression MakeSelector(ICollection elementExpression, Expression @select, Expression identifier) + private static LambdaExpression MakeSelector(ICollection elementExpression, Expression select, Expression identifier) { - var parameter = Expression.Parameter(typeof (Tuple), "value"); + var parameter = Expression.Parameter(typeof(Tuple), "value"); var rewriter = new SelectClauseRewriter(parameter, elementExpression, identifier, 1, new Dictionary()); - var selectorBody = rewriter.VisitExpression(@select); + var selectorBody = rewriter.VisitExpression(select); return Expression.Lambda(selectorBody, parameter); } @@ -161,27 +167,29 @@ private static Expression SubCollectionQuery(System.Type collectionType, System. var selectMethod = ReflectionCache.EnumerableMethods.SelectDefinition.MakeGenericMethod(new[] { typeof(Tuple), elementType }); - var select = Expression.Call(selectMethod, - Expression.Call(WhereMethod, source, predicate), - selector); + var select = Expression.Call( + selectMethod, + Expression.Call(WhereMethod, source, predicate), + selector); if (collectionType.IsArray) { var toArrayMethod = ReflectionCache.EnumerableMethods.ToArrayDefinition.MakeGenericMethod(new[] { elementType }); - var array = Expression.Call(toArrayMethod, @select); + var array = Expression.Call(toArrayMethod, select); return array; } var constructor = GetCollectionConstructor(collectionType, elementType); if (constructor != null) - return Expression.New(constructor, (Expression) @select); + return Expression.New(constructor, select); var toListMethod = ReflectionCache.EnumerableMethods.ToListDefinition.MakeGenericMethod(new[] { elementType }); - return Expression.Call(Expression.Call(toListMethod, @select), - "AsReadonly", - System.Type.EmptyTypes); + return Expression.Call( + Expression.Call(toListMethod, select), + "AsReadonly", + System.Type.EmptyTypes); } private static ConstructorInfo GetCollectionConstructor(System.Type collectionType, System.Type elementType) @@ -201,12 +209,13 @@ private static ConstructorInfo GetCollectionConstructor(System.Type collectionTy private static LambdaExpression MakePredicate(int index) { // t => Not(ReferenceEquals(t.Items[index], null)) - var t = Expression.Parameter(typeof (Tuple), "t"); + var t = Expression.Parameter(typeof(Tuple), "t"); return Expression.Lambda( Expression.Not( - Expression.Call(ObjectReferenceEquals, - ArrayIndex(Expression.Property(t, Tuple.ItemsProperty), index), - Expression.Constant(null))), + Expression.Call( + ObjectReferenceEquals, + ArrayIndex(Expression.Property(t, Tuple.ItemsProperty), index), + Expression.Constant(null))), t); } @@ -219,26 +228,26 @@ private static Expression GetIdentifier(ISessionFactory sessionFactory, Expressi if (classMetadata == null) return Expression.Constant(null); - var propertyName=classMetadata.IdentifierPropertyName; - NHibernate.Type.EmbeddedComponentType componentType; - if (propertyName == null && (componentType=classMetadata.IdentifierType as NHibernate.Type.EmbeddedComponentType)!=null) - { - //The identifier is an embedded composite key. We only need one property from it for a null check - propertyName = componentType.PropertyNames.First(); - } + var propertyName = classMetadata.IdentifierPropertyName; + EmbeddedComponentType componentType; + if (propertyName == null && (componentType = classMetadata.IdentifierType as EmbeddedComponentType) != null) + { + //The identifier is an embedded composite key. We only need one property from it for a null check + propertyName = componentType.PropertyNames.First(); + } - return ConvertToObject(Expression.PropertyOrField(expression, propertyName)); + return ConvertToObject(Expression.PropertyOrField(expression, propertyName)); } private static LambdaExpression CreateSelector(IEnumerable expressions, int tuple) { - var parameter = Expression.Parameter(typeof (object[]), "x"); + var parameter = Expression.Parameter(typeof(object[]), "x"); - var initializers = expressions.Select((x, index) => new { x.Tuple, index}) + var initializers = expressions.Select((x, index) => new { x.Tuple, index }) .Where(x => x.Tuple == tuple) .Select(x => ArrayIndex(parameter, x.index)); - var newArrayInit = Expression.NewArrayInit(typeof (object), initializers); + var newArrayInit = Expression.NewArrayInit(typeof(object), initializers); return Expression.Lambda( Expression.New(Tuple.Constructor, newArrayInit), @@ -252,7 +261,7 @@ private static Expression ArrayIndex(Expression param, int value) private static Expression ConvertToObject(Expression expression) { - return Expression.Convert(expression, typeof (object)); + return Expression.Convert(expression, typeof(object)); } private static System.Type GetElementType(System.Type type) @@ -261,6 +270,6 @@ private static System.Type GetElementType(System.Type type) if (elementType == null) throw new NotSupportedException("Unknown collection type " + type.FullName); return elementType; - } + } } } diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 0dc17e068b1..1670f152227 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -18,17 +18,17 @@ public class AddJoinsReWriter : QueryModelVisitorBase, IIsEntityDecider private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; - private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) + private AddJoinsReWriter(QueryModel queryModel, VisitorParameters parameters) { - _sessionFactory = sessionFactory; - var joiner = new Joiner(queryModel); + _sessionFactory = parameters.SessionFactory; + var joiner = new Joiner(queryModel, parameters); _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); _whereJoinDetector = new WhereJoinDetector(this, joiner); } public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - var visitor = new AddJoinsReWriter(parameters.SessionFactory, queryModel); + var visitor = new AddJoinsReWriter(queryModel, parameters); visitor.VisitQueryModel(queryModel); } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index f41ccea5022..b785d6e18d5 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.Linq.Expressions; -using NHibernate.Linq.Clauses; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -17,43 +16,38 @@ public interface IJoiner public class Joiner : IJoiner { - private readonly Dictionary _joins = new Dictionary(); + private readonly Dictionary _joins = new Dictionary(); private readonly NameGenerator _nameGenerator; + private readonly VisitorParameters _parameters; private readonly QueryModel _queryModel; - internal Joiner(QueryModel queryModel) + internal Joiner(QueryModel queryModel, VisitorParameters parameters) { _nameGenerator = new NameGenerator(queryModel); + _parameters = parameters; _queryModel = queryModel; } - public IEnumerable Joins - { - get { return _joins.Values; } - } - public Expression AddJoin(Expression expression, string key) { - NhJoinClause join; - - if (!_joins.TryGetValue(key, out join)) + if (!_joins.TryGetValue(key, out AdditionalFromClause join)) { - join = new NhJoinClause(_nameGenerator.GetNewName(), expression.Type, expression); + join = new AdditionalFromClause(_nameGenerator.GetNewName(), expression.Type, expression); + _parameters.AddLeftJoin(join, null); _queryModel.BodyClauses.Add(join); _joins.Add(key, join); } - return new QuerySourceReferenceExpression(@join); + return new QuerySourceReferenceExpression(join); } public void MakeInnerIfJoined(string key) { // key is not joined if it occurs only at tails of expressions, e.g. // a.B == null, a.B != null, a.B == c.D etc. - NhJoinClause nhJoinClause; - if (_joins.TryGetValue(key, out nhJoinClause)) + if (_joins.TryGetValue(key, out AdditionalFromClause join)) { - nhJoinClause.MakeInner(); + _parameters.MakeInnerJoin(join); } } @@ -64,8 +58,7 @@ public bool CanAddJoin(Expression expression) if (_queryModel.MainFromClause == source) return true; - var bodyClause = source as IBodyClause; - if (bodyClause != null && _queryModel.BodyClauses.Contains(bodyClause)) + if (source is IBodyClause bodyClause && _queryModel.BodyClauses.Contains(bodyClause)) return true; var resultOperatorBase = source as ResultOperatorBase; diff --git a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs index 83a2facc0b5..e8cf6299208 100644 --- a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs +++ b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.Linq; -using NHibernate.Linq.Clauses; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.ExpressionTreeVisitors; @@ -11,9 +10,16 @@ namespace NHibernate.Linq.Visitors { public class LeftJoinRewriter : QueryModelVisitorBase { - public static void ReWrite(QueryModel queryModel) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - new LeftJoinRewriter().VisitQueryModel(queryModel); + new LeftJoinRewriter(parameters).VisitQueryModel(queryModel); + } + + private readonly VisitorParameters _parameters; + + public LeftJoinRewriter(VisitorParameters parameters) + { + _parameters = parameters; } public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) @@ -28,14 +34,9 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, var mainFromClause = subQueryModel.MainFromClause; - var restrictions = subQueryModel.BodyClauses - .OfType() - .Select(w => new NhWithClause(w.Predicate)); - - var join = new NhJoinClause(mainFromClause.ItemName, - mainFromClause.ItemType, - mainFromClause.FromExpression, - restrictions); + var join = new AdditionalFromClause(mainFromClause.ItemName, mainFromClause.ItemType, mainFromClause.FromExpression); + var restrictions = subQueryModel.BodyClauses.OfType().ToList(); + _parameters.AddLeftJoin(join, restrictions); var innerSelectorMapping = new QuerySourceMapping(); innerSelectorMapping.AddMapping(fromClause, subQueryModel.SelectClause.Selector); @@ -43,11 +44,11 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); queryModel.BodyClauses.RemoveAt(index); - queryModel.BodyClauses.Insert(index, @join); - InsertBodyClauses(subQueryModel.BodyClauses.Where(b => !(b is WhereClause)), queryModel, index + 1); + queryModel.BodyClauses.Insert(index, join); + InsertBodyClauses(subQueryModel.BodyClauses, queryModel, index + 1); var innerBodyClauseMapping = new QuerySourceMapping(); - innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(@join)); + innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(join)); queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); } @@ -57,14 +58,14 @@ private static void InsertBodyClauses(IEnumerable bodyClauses, Quer foreach (var bodyClause in bodyClauses) { destinationQueryModel.BodyClauses.Insert(destinationIndex, bodyClause); - ++destinationIndex; + destinationIndex++; } } private static bool IsLeftJoin(QueryModel subQueryModel) { return subQueryModel.ResultOperators.Count == 1 && - subQueryModel.ResultOperators[0] is DefaultIfEmptyResultOperator; + subQueryModel.ResultOperators[0] is DefaultIfEmptyResultOperator; } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 3e4a1c863cf..ee5bcceb8d0 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -3,7 +3,6 @@ using System.Linq.Expressions; using System.Reflection; using NHibernate.Hql.Ast; -using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.GroupBy; using NHibernate.Linq.GroupJoin; @@ -25,7 +24,7 @@ public class QueryModelVisitor : QueryModelVisitorBase public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root, NhLinqExpressionReturnType? rootReturnType) { - NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory); + NestedSelectRewriter.ReWrite(queryModel, parameters); // Remove unnecessary body operators RemoveUnnecessaryBodyOperators.ReWrite(queryModel); @@ -37,7 +36,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer NonAggregatingGroupByRewriter.ReWrite(queryModel); // Rewrite aggregate group-by statements - AggregatingGroupByRewriter.ReWrite(queryModel); + AggregatingGroupByRewriter.ReWrite(queryModel, parameters); // Rewrite aggregating group-joins AggregatingGroupJoinRewriter.ReWrite(queryModel); @@ -48,7 +47,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer SubQueryFromClauseFlattener.ReWrite(queryModel); // Rewrite left-joins - LeftJoinRewriter.ReWrite(queryModel); + LeftJoinRewriter.ReWrite(queryModel, parameters); // Rewrite paging PagingRewriter.ReWrite(queryModel); @@ -280,10 +279,7 @@ private MethodCallExpression GetAggregateMethodCall(MethodInfo aggregateMethodTe public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - var hqlExpressionTree = HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters); - - _hqlTree.AddFromClause(_hqlTree.TreeBuilder.Range(hqlExpressionTree, _hqlTree.TreeBuilder.Alias(querySourceName))); + AddFromClause(fromClause); // apply any result operators that were rewritten if (RewrittenOperatorResult != null) @@ -300,56 +296,39 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - - var joinClause = fromClause as NhJoinClause; - if (joinClause != null) - { - VisitNhJoinClause(querySourceName, joinClause); - } - else if (fromClause.FromExpression is MemberExpression) + if (fromClause.FromExpression is MemberExpression) { // It's a join - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Join( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), - _hqlTree.TreeBuilder.Alias(querySourceName))); + var expression = HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(); + var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); + var alias = _hqlTree.TreeBuilder.Alias(querySourceName); + var hqlJoin = VisitorParameters.IsLeftJoin(fromClause) ? + (HqlTreeNode)_hqlTree.TreeBuilder.LeftJoin(expression, alias) : + _hqlTree.TreeBuilder.Join(expression, alias); + + foreach (var withClause in VisitorParameters.GetRestrictions(fromClause)) + { + var booleanExpression = HqlGeneratorExpressionTreeVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); + hqlJoin.AddChild(_hqlTree.TreeBuilder.With(booleanExpression)); + } + + _hqlTree.AddFromClause(hqlJoin); } else { - // TODO - exact same code as in MainFromClause; refactor this out - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); - + AddFromClause(fromClause); } base.VisitAdditionalFromClause(fromClause, queryModel, index); } - private void VisitNhJoinClause(string querySourceName, NhJoinClause joinClause) + private void AddFromClause(FromClauseBase fromClause) { - var expression = HqlGeneratorExpressionTreeVisitor.Visit(joinClause.FromExpression, VisitorParameters).AsExpression(); - var alias = _hqlTree.TreeBuilder.Alias(querySourceName); - - HqlTreeNode hqlJoin; - if (joinClause.IsInner) - { - hqlJoin = _hqlTree.TreeBuilder.Join(expression, @alias); - } - else - { - hqlJoin = _hqlTree.TreeBuilder.LeftJoin(expression, @alias); - } - - foreach (var withClause in joinClause.Restrictions) - { - var booleanExpression = HqlGeneratorExpressionTreeVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); - hqlJoin.AddChild(_hqlTree.TreeBuilder.With(booleanExpression)); - } - - _hqlTree.AddFromClause(hqlJoin); + var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); + _hqlTree.AddFromClause( + _hqlTree.TreeBuilder.Range( + HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), + _hqlTree.TreeBuilder.Alias(querySourceName))); } public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) @@ -397,11 +376,12 @@ public override void VisitWhereClause(WhereClause whereClause, QueryModel queryM // Visit the predicate to build the query var expression = HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); - if (whereClause is NhHavingClause) + if (VisitorParameters.IsHavingClause(whereClause)) { _hqlTree.AddHavingClause(expression); } - else + // With clauses are handled with joins. + else if (!VisitorParameters.IsWithClause(whereClause)) { _hqlTree.AddWhereClause(expression); } diff --git a/src/NHibernate/Linq/Visitors/VisitorParameters.cs b/src/NHibernate/Linq/Visitors/VisitorParameters.cs index 27ef5de7029..5e676a21c22 100644 --- a/src/NHibernate/Linq/Visitors/VisitorParameters.cs +++ b/src/NHibernate/Linq/Visitors/VisitorParameters.cs @@ -3,23 +3,29 @@ using NHibernate.Engine; using NHibernate.Engine.Query; using NHibernate.Param; +using Remotion.Linq.Clauses; namespace NHibernate.Linq.Visitors { public class VisitorParameters { - public ISessionFactoryImplementor SessionFactory { get; private set; } + public ISessionFactoryImplementor SessionFactory { get; } - public IDictionary ConstantToParameterMap { get; private set; } + public IDictionary ConstantToParameterMap { get; } - public List RequiredHqlParameters { get; private set; } + public List RequiredHqlParameters { get; } - public QuerySourceNamer QuerySourceNamer { get; set; } + public QuerySourceNamer QuerySourceNamer { get; } + + private readonly HashSet _havingClauses = new HashSet(); + private readonly HashSet _leftJoins = new HashSet(); + private readonly HashSet _withClauses = new HashSet(); + private readonly Dictionary> _joinRestrictions = new Dictionary>(); public VisitorParameters( - ISessionFactoryImplementor sessionFactory, - IDictionary constantToParameterMap, - List requiredHqlParameters, + ISessionFactoryImplementor sessionFactory, + IDictionary constantToParameterMap, + List requiredHqlParameters, QuerySourceNamer querySourceNamer) { SessionFactory = sessionFactory; @@ -27,5 +33,83 @@ public VisitorParameters( RequiredHqlParameters = requiredHqlParameters; QuerySourceNamer = querySourceNamer; } + + /// + /// Indicates if a Linq where clause needs to be converted to a HQL having clause. + /// + /// The clause to test. + /// true if the clause needs to be converted to a HQL having clause, false otherwise. + public bool IsHavingClause(WhereClause clause) + { + return _havingClauses.Contains(clause); + } + + /// + /// Indicates if a Linq where clause needs to be converted to a HQL with clause. + /// + /// The clause to test. + /// true if the clause needs to be converted to a HQL with clause, false otherwise. + public bool IsWithClause(WhereClause clause) + { + return _withClauses.Contains(clause); + } + + /// + /// Indicates if a Linq join clause needs to be converted to a HQL left join. + /// + /// The join to test. + /// true if the clause needs to be converted to a HQL left join, false otherwise. + public bool IsLeftJoin(AdditionalFromClause join) + { + return _leftJoins.Contains(join); + } + + /// + /// Get the clauses to apply to the join as HQL with clauses. + /// + /// The join. + /// A list of where clauses to apply as HQL with to the join. + public IEnumerable GetRestrictions(AdditionalFromClause join) + { + if (_joinRestrictions.TryGetValue(join, out var restrictions)) + return restrictions; + return new List(); + } + + /// + /// Add a detected having clause. + /// + /// The clause to add. + public void AddHavingClause(WhereClause clause) + { + _havingClauses.Add(clause); + } + + /// + /// Add a detected join. + /// + /// The join to add. + /// Its restrictions if any. + public void AddLeftJoin(AdditionalFromClause join, IEnumerable restrictions) + { + _leftJoins.Add(join); + if (restrictions != null) + { + _joinRestrictions.Add(join, restrictions); + foreach (var with in restrictions) + { + _withClauses.Add(with); + } + } + } + + /// + /// Remove a join clause from left join clauses if it was one. + /// + /// The join clause to handle as inner. + public void MakeInnerJoin(AdditionalFromClause join) + { + _leftJoins.Remove(join); + } } } \ No newline at end of file diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index b52b617a9c5..3d80f7fcab3 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -296,9 +296,6 @@ - - - From 12b65d6e22a91df8cb23be1e50ff4029aca1402f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= Date: Sat, 15 Apr 2017 13:29:16 +0200 Subject: [PATCH 2/3] NH-3944 - migration to Relinq v2 and modernization to .Net Framework v4 expression practices. --- doc/reference/modules/query_linq.xml | 17 +- .../NHibernate.DomainModel.csproj | 1 + src/NHibernate.DomainModel/app.config | 11 + src/NHibernate.Test/App.config | 14 +- .../Linq/CustomQueryModelRewriterTests.cs | 12 +- .../NHSpecificTest/NH3386/Fixture.cs | 1 + src/NHibernate.Test/NHibernate.Test.csproj | 10 +- src/NHibernate.Test/packages.config | 14 +- src/NHibernate.TestDatabaseSetup/App.config | 16 +- src/NHibernate/Linq/DBOnlyAttribute.cs | 14 + src/NHibernate/Linq/DefaultQueryProvider.cs | 8 +- .../RemoveCharToIntConversion.cs | 14 +- .../RemoveRedundantCast.cs | 2 +- .../SimplifyCompareTransformer.cs | 2 +- .../Expressions/NhAggregatedExpression.cs | 33 +- .../Linq/Expressions/NhAverageExpression.cs | 14 +- .../Linq/Expressions/NhCountExpression.cs | 20 +- .../Linq/Expressions/NhDistinctExpression.cs | 11 +- .../Linq/Expressions/NhExpression.cs | 48 +++ .../Linq/Expressions/NhMaxExpression.cs | 11 +- .../Linq/Expressions/NhMinExpression.cs | 11 +- .../Linq/Expressions/NhNewExpression.cs | 58 +--- .../Linq/Expressions/NhNominatedExpression.cs | 25 +- .../Linq/Expressions/NhStarExpression.cs | 18 + .../Linq/Expressions/NhSumExpression.cs | 11 +- .../GroupBy/GroupBySelectClauseRewriter.cs | 42 ++- .../Linq/GroupBy/GroupKeyNominator.cs | 30 +- ...IsNonAggregatingGroupByDetectionVisitor.cs | 18 +- .../Linq/GroupBy/KeySelectorVisitor.cs | 6 +- .../GroupBy/NonAggregatingGroupByRewriter.cs | 34 +- .../GroupJoin/AggregatingGroupJoinRewriter.cs | 4 +- .../GroupJoinAggregateDetectionVisitor.cs | 26 +- .../GroupJoinSelectClauseRewriter.cs | 6 +- .../GroupJoin/LocateGroupJoinQuerySource.cs | 11 +- .../NonAggregatingGroupJoinRewriter.cs | 40 +-- src/NHibernate/Linq/LinqExtensionMethods.cs | 12 +- src/NHibernate/Linq/LinqLogging.cs | 13 +- .../NestedSelects/NestedSelectDetector.cs | 26 +- .../NestedSelects/NestedSelectRewriter.cs | 6 +- .../NestedSelects/SelectClauseRewriter.cs | 54 ++- src/NHibernate/Linq/NestedSelects/Tuple.cs | 8 +- src/NHibernate/Linq/NhLinqExpression.cs | 8 +- src/NHibernate/Linq/NhRelinqQueryParser.cs | 7 +- .../ArrayIndexExpressionFlattener.cs | 16 +- .../MergeAggregatingResultsRewriter.cs | 130 ++++---- .../ReWriters/MoveOrderByToEndRewriter.cs | 6 +- .../QueryReferenceExpressionFlattener.cs | 22 +- .../Linq/ReWriters/ResultOperatorRewriter.cs | 65 ++-- .../ReWriters/ResultOperatorRewriterResult.cs | 49 ++- .../Linq/Visitors/EqualityHqlGenerator.cs | 10 +- .../Linq/Visitors/ExpressionKeyVisitor.cs | 72 ++-- .../Visitors/ExpressionParameterVisitor.cs | 23 +- ...or.cs => HqlGeneratorExpressionVisitor.cs} | 264 ++++++--------- src/NHibernate/Linq/Visitors/JoinBuilder.cs | 8 +- .../Linq/Visitors/LeftJoinRewriter.cs | 6 +- .../Visitors/MemberExpressionJoinDetector.cs | 38 +-- .../Linq/Visitors/NhExpressionTreeVisitor.cs | 98 ------ .../Linq/Visitors/NhExpressionVisitor.cs | 70 ++++ ...hPartialEvaluatingExpressionTreeVisitor.cs | 32 -- .../NhPartialEvaluatingExpressionVisitor.cs | 42 +++ .../PagingRewriterSelectClauseVisitor.cs | 18 +- .../QueryExpressionSourceIdentifer.cs | 6 +- .../Linq/Visitors/QueryModelVisitor.cs | 39 ++- .../Linq/Visitors/QuerySourceIdentifier.cs | 7 +- .../Linq/Visitors/QuerySourceLocator.cs | 53 ++- .../ProcessAggregate.cs | 4 +- .../ProcessAggregateFromSeed.cs | 4 +- .../ResultOperatorProcessors/ProcessAll.cs | 2 +- .../ProcessContains.cs | 2 +- .../ProcessGroupBy.cs | 2 +- .../ProcessNonAggregatingGroupBy.cs | 9 +- .../ResultOperatorProcessors/ProcessOfType.cs | 2 +- .../Linq/Visitors/SelectClauseNominator.cs | 49 ++- .../Linq/Visitors/SelectClauseVisitor.cs | 29 +- .../Visitors/SimplifyConditionalVisitor.cs | 33 +- .../Visitors/SubQueryFromClauseFlattener.cs | 20 +- .../Linq/Visitors/SwapQuerySourceVisitor.cs | 20 +- .../Linq/Visitors/VisitorParameters.cs | 18 +- src/NHibernate/Linq/Visitors/VisitorUtil.cs | 30 +- .../Linq/Visitors/WhereJoinDetector.cs | 314 +++++++++--------- src/NHibernate/NHibernate.csproj | 22 +- src/NHibernate/app.config | 11 + src/NHibernate/packages.config | 14 +- 83 files changed, 1142 insertions(+), 1264 deletions(-) create mode 100644 src/NHibernate.DomainModel/app.config create mode 100644 src/NHibernate/Linq/DBOnlyAttribute.cs create mode 100644 src/NHibernate/Linq/Expressions/NhExpression.cs create mode 100644 src/NHibernate/Linq/Expressions/NhStarExpression.cs rename src/NHibernate/Linq/Visitors/{HqlGeneratorExpressionTreeVisitor.cs => HqlGeneratorExpressionVisitor.cs} (59%) delete mode 100644 src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs create mode 100644 src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs delete mode 100644 src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs create mode 100644 src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs create mode 100644 src/NHibernate/app.config diff --git a/doc/reference/modules/query_linq.xml b/doc/reference/modules/query_linq.xml index 44f413df1af..8120234d31b 100644 --- a/doc/reference/modules/query_linq.xml +++ b/doc/reference/modules/query_linq.xml @@ -539,16 +539,18 @@ IList cats = - The method call will always be translated to SQL if at least one of the parameters of the - method call has its value originating from an entity. Otherwise, the Linq provider will try to - evaluate the method call with .Net runtime instead. Since NHibernate 5.0, if this runtime - evaluation fails (throws an exception), then the method call will be translated to SQL too. + By default, the Linq provider will try to evaluate the method call with .Net runtime + whenever possible, instead of translating it to SQL. It will not do it if at least one + of the parameters of the method call has its value originating from an entity, or if + the method is marked with the DBOnly attribute (available since + NHibernate 5.0). @@ -656,6 +658,13 @@ cfg.LinqToHqlGeneratorsRegistry(); (Of course, the same result could be obtained with (DateTime?)(c.BirthDate).) + + By default, the Linq provider will try to evaluate the method call with .Net runtime + whenever possible, instead of translating it to SQL. It will not do it if at least one + of the parameters of the method call has its value originating from an entity, or if + the method is marked with the DBOnly attribute (available since + NHibernate 5.0). + \ No newline at end of file diff --git a/src/NHibernate.DomainModel/NHibernate.DomainModel.csproj b/src/NHibernate.DomainModel/NHibernate.DomainModel.csproj index 0f4e5ed7867..2b78c00ad19 100644 --- a/src/NHibernate.DomainModel/NHibernate.DomainModel.csproj +++ b/src/NHibernate.DomainModel/NHibernate.DomainModel.csproj @@ -302,6 +302,7 @@ + diff --git a/src/NHibernate.DomainModel/app.config b/src/NHibernate.DomainModel/app.config new file mode 100644 index 00000000000..246cc8bf759 --- /dev/null +++ b/src/NHibernate.DomainModel/app.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/src/NHibernate.Test/App.config b/src/NHibernate.Test/App.config index 83a2785b7d4..3e2c689c08b 100644 --- a/src/NHibernate.Test/App.config +++ b/src/NHibernate.Test/App.config @@ -1,4 +1,4 @@ - + @@ -15,6 +15,10 @@ + + + + @@ -66,7 +70,7 @@ - + @@ -98,12 +102,12 @@ - + - - + + diff --git a/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs b/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs index a22305a9966..4831b099863 100644 --- a/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs +++ b/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs @@ -44,12 +44,12 @@ public class CustomVisitor : QueryModelVisitorBase { public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { - whereClause.TransformExpressions(new Visitor().VisitExpression); + whereClause.TransformExpressions(new Visitor().Visit); } - private class Visitor : ExpressionTreeVisitor + private class Visitor : RelinqExpressionVisitor { - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { if ( expression.NodeType == ExpressionType.Equal || @@ -68,9 +68,7 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) reverse = true; } - var constant = left as ConstantExpression; - - if (constant != null && constant.Value == null) + if (left is ConstantExpression constant && constant.Value == null) { left = Expression.Constant("Thomas Hardy"); @@ -82,7 +80,7 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) } } - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); } } } diff --git a/src/NHibernate.Test/NHSpecificTest/NH3386/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH3386/Fixture.cs index 0078739a8db..1b797a94993 100644 --- a/src/NHibernate.Test/NHSpecificTest/NH3386/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/NH3386/Fixture.cs @@ -61,6 +61,7 @@ public void ShouldSupportNonRuntimeExtensionWithoutEntityReference() public static class SqlServerFunction { [LinqExtensionMethod] + [DBOnly] public static Guid NewID() { throw new InvalidOperationException("To be translated to SQL only"); diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index f46c51074ba..02414f8b2c2 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -83,7 +83,14 @@ ..\packages\NUnit.3.6.0\lib\net45\nunit.framework.dll True + + ..\packages\Remotion.Linq.2.1.1\lib\net45\Remotion.Linq.dll + + + ..\packages\Remotion.Linq.EagerFetching.2.0.1\lib\net45\Remotion.Linq.EagerFetching.dll + + 3.5 @@ -100,9 +107,6 @@ 3.5 - - ..\packages\Remotion.Linq.1.15.15.0\lib\portable-net45+wp80+wpa81+win\Remotion.Linq.dll - diff --git a/src/NHibernate.Test/packages.config b/src/NHibernate.Test/packages.config index 7e19d547e61..6448f0abb7a 100644 --- a/src/NHibernate.Test/packages.config +++ b/src/NHibernate.Test/packages.config @@ -4,6 +4,18 @@ - + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/NHibernate.TestDatabaseSetup/App.config b/src/NHibernate.TestDatabaseSetup/App.config index 12027f3e8e8..c90c0b4bcef 100644 --- a/src/NHibernate.TestDatabaseSetup/App.config +++ b/src/NHibernate.TestDatabaseSetup/App.config @@ -1,8 +1,7 @@ - + -
+
@@ -13,4 +12,13 @@ - + + + + + + + + + + diff --git a/src/NHibernate/Linq/DBOnlyAttribute.cs b/src/NHibernate/Linq/DBOnlyAttribute.cs new file mode 100644 index 00000000000..680b4dff735 --- /dev/null +++ b/src/NHibernate/Linq/DBOnlyAttribute.cs @@ -0,0 +1,14 @@ +using System; + +namespace NHibernate.Linq +{ + /// + /// Indicates to the Linq-to-NHibernate provider a method that must not be evaluated. If supported, + /// it will always be converted to the corresponding SQL statement. + /// + public class DBOnlyAttribute: Attribute + { + public DBOnlyAttribute() + { } + } +} \ No newline at end of file diff --git a/src/NHibernate/Linq/DefaultQueryProvider.cs b/src/NHibernate/Linq/DefaultQueryProvider.cs index e2c9d09f441..9b46040dc31 100644 --- a/src/NHibernate/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Linq/DefaultQueryProvider.cs @@ -35,9 +35,7 @@ protected virtual ISessionImplementor Session public virtual object Execute(Expression expression) { - IQuery query; - NhLinqExpression nhQuery; - NhLinqExpression nhLinqExpression = PrepareQuery(expression, out query, out nhQuery); + var nhLinqExpression = PrepareQuery(expression, out IQuery query, out NhLinqExpression nhQuery); return ExecuteQuery(nhLinqExpression, query, nhQuery); } @@ -61,9 +59,7 @@ public virtual IQueryable CreateQuery(Expression expression) public virtual object ExecuteFuture(Expression expression) { - IQuery query; - NhLinqExpression nhQuery; - NhLinqExpression nhLinqExpression = PrepareQuery(expression, out query, out nhQuery); + var nhLinqExpression = PrepareQuery(expression, out IQuery query, out NhLinqExpression nhQuery); return ExecuteFutureQuery(nhLinqExpression, query, nhQuery); } diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs index e411da5d7ab..53d46a9d905 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs @@ -1,6 +1,6 @@ using System; using System.Linq.Expressions; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers { @@ -57,18 +57,12 @@ public Expression Transform(BinaryExpression expression) } private static bool IsConvertExpression(Expression expression) - { - return (expression.NodeType == ExpressionType.Convert); - } + => expression.NodeType == ExpressionType.Convert; private static bool IsConstantExpression(Expression expression) - { - return (expression.NodeType == ExpressionType.Constant); - } + => expression.NodeType == ExpressionType.Constant; public ExpressionType[] SupportedExpressionTypes - { - get { return _supportedExpressionTypes; } - } + => _supportedExpressionTypes; } } \ No newline at end of file diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs index d5a87a97eb9..538e46cb828 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs @@ -1,5 +1,5 @@ using System.Linq.Expressions; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers { diff --git a/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs b/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs index 73ee8d974f6..5dc00b7fe4b 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs @@ -5,7 +5,7 @@ using System.Reflection; using NHibernate.Linq.Functions; using NHibernate.Util; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers { diff --git a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs index e564884e771..a3a8c214fd5 100644 --- a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs @@ -1,34 +1,17 @@ using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { - public abstract class NhAggregatedExpression : ExtensionExpression + public abstract class NhAggregatedExpression : NhSimpleExpression { - public Expression Expression { get; set; } + protected NhAggregatedExpression(Expression expression) + : base(expression) { } - protected NhAggregatedExpression(Expression expression, NhExpressionType type) - : base(expression.Type, (ExpressionType)type) - { - Expression = expression; - } + protected NhAggregatedExpression(Expression expression, System.Type expressionType) + : base(expression, expressionType) { } - protected NhAggregatedExpression(Expression expression, System.Type expressionType, NhExpressionType type) - : base(expressionType, (ExpressionType)type) - { - Expression = expression; - } - - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) - { - var newExpression = visitor.VisitExpression(Expression); - - return newExpression != Expression - ? CreateNew(newExpression) - : this; - } - - public abstract Expression CreateNew(Expression expression); + protected override Expression Accept(NhExpressionVisitor visitor) + => visitor.VisitNhAggregate(this); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhAverageExpression.cs b/src/NHibernate/Linq/Expressions/NhAverageExpression.cs index 9dffa5f67fe..9496f57619e 100644 --- a/src/NHibernate/Linq/Expressions/NhAverageExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhAverageExpression.cs @@ -6,9 +6,8 @@ namespace NHibernate.Linq.Expressions { public class NhAverageExpression : NhAggregatedExpression { - public NhAverageExpression(Expression expression) : base(expression, CalculateAverageType(expression.Type), NhExpressionType.Average) - { - } + public NhAverageExpression(Expression expression) + : base(expression, CalculateAverageType(expression.Type)) { } private static System.Type CalculateAverageType(System.Type inputType) { @@ -27,7 +26,7 @@ private static System.Type CalculateAverageType(System.Type inputType) case TypeCode.Int64: case TypeCode.Single: case TypeCode.Double: - return isNullable ? typeof(double?) : typeof (double); + return isNullable ? typeof(double?) : typeof(double); case TypeCode.Decimal: return isNullable ? typeof(decimal?) : typeof(decimal); } @@ -35,9 +34,8 @@ private static System.Type CalculateAverageType(System.Type inputType) throw new NotSupportedException(inputType.FullName); } - public override Expression CreateNew(Expression expression) - { - return new NhAverageExpression(expression); - } + public override NhExpressionType NhNodeType => NhExpressionType.Average; + + public override Expression CreateNew(Expression expression) => new NhAverageExpression(expression); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhCountExpression.cs b/src/NHibernate/Linq/Expressions/NhCountExpression.cs index 8c9024e0280..daee8bd3262 100644 --- a/src/NHibernate/Linq/Expressions/NhCountExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhCountExpression.cs @@ -5,28 +5,26 @@ namespace NHibernate.Linq.Expressions public abstract class NhCountExpression : NhAggregatedExpression { protected NhCountExpression(Expression expression, System.Type type) - : base(expression, type, NhExpressionType.Count) {} + : base(expression, type) { } + + public override NhExpressionType NhNodeType => NhExpressionType.Count; + + public bool IsCountStar => Expression is NhStarExpression; } public class NhShortCountExpression : NhCountExpression { public NhShortCountExpression(Expression expression) - : base(expression, typeof (int)) {} + : base(expression, typeof(int)) { } - public override Expression CreateNew(Expression expression) - { - return new NhShortCountExpression(expression); - } + public override Expression CreateNew(Expression expression) => new NhShortCountExpression(expression); } public class NhLongCountExpression : NhCountExpression { public NhLongCountExpression(Expression expression) - : base(expression, typeof (long)) {} + : base(expression, typeof(long)) { } - public override Expression CreateNew(Expression expression) - { - return new NhLongCountExpression(expression); - } + public override Expression CreateNew(Expression expression) => new NhLongCountExpression(expression); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs b/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs index 4ddc970adfd..2f73aa8b0e2 100644 --- a/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs @@ -5,13 +5,10 @@ namespace NHibernate.Linq.Expressions public class NhDistinctExpression : NhAggregatedExpression { public NhDistinctExpression(Expression expression) - : base(expression, NhExpressionType.Distinct) - { - } + : base(expression) { } - public override Expression CreateNew(Expression expression) - { - return new NhDistinctExpression(expression); - } + public override NhExpressionType NhNodeType => NhExpressionType.Distinct; + + public override Expression CreateNew(Expression expression) => new NhDistinctExpression(expression); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhExpression.cs b/src/NHibernate/Linq/Expressions/NhExpression.cs new file mode 100644 index 00000000000..3729009315f --- /dev/null +++ b/src/NHibernate/Linq/Expressions/NhExpression.cs @@ -0,0 +1,48 @@ +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; + +namespace NHibernate.Linq.Expressions +{ + public abstract class NhExpression : Expression + { + public override ExpressionType NodeType => ExpressionType.Extension; + + public abstract NhExpressionType NhNodeType { get; } + + protected override Expression Accept(ExpressionVisitor visitor) + { + if (visitor is NhExpressionVisitor nhVisitor) + Accept(nhVisitor); + return base.Accept(visitor); + } + + protected abstract Expression Accept(NhExpressionVisitor visitor); + } + + public abstract class NhSimpleExpression : NhExpression + { + protected NhSimpleExpression(Expression expression) + : this(expression, expression.Type) { } + + protected NhSimpleExpression(Expression expression, System.Type expressionType) + { + Expression = expression; + Type = expressionType; + } + + public Expression Expression { get; } + + public override System.Type Type { get; } + + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var newExpression = visitor.Visit(Expression); + + return newExpression != Expression + ? CreateNew(newExpression) + : this; + } + + public abstract Expression CreateNew(Expression expression); + } +} \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhMaxExpression.cs b/src/NHibernate/Linq/Expressions/NhMaxExpression.cs index b4b536fabd0..3ebb9b45f51 100644 --- a/src/NHibernate/Linq/Expressions/NhMaxExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhMaxExpression.cs @@ -5,13 +5,10 @@ namespace NHibernate.Linq.Expressions public class NhMaxExpression : NhAggregatedExpression { public NhMaxExpression(Expression expression) - : base(expression, NhExpressionType.Max) - { - } + : base(expression) { } - public override Expression CreateNew(Expression expression) - { - return new NhMaxExpression(expression); - } + public override NhExpressionType NhNodeType => NhExpressionType.Max; + + public override Expression CreateNew(Expression expression) => new NhMaxExpression(expression); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhMinExpression.cs b/src/NHibernate/Linq/Expressions/NhMinExpression.cs index e8eb33570dc..cdc560b1577 100644 --- a/src/NHibernate/Linq/Expressions/NhMinExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhMinExpression.cs @@ -5,13 +5,10 @@ namespace NHibernate.Linq.Expressions public class NhMinExpression : NhAggregatedExpression { public NhMinExpression(Expression expression) - : base(expression, NhExpressionType.Min) - { - } + : base(expression) { } - public override Expression CreateNew(Expression expression) - { - return new NhMinExpression(expression); - } + public override NhExpressionType NhNodeType => NhExpressionType.Min; + + public override Expression CreateNew(Expression expression) => new NhMinExpression(expression); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhNewExpression.cs b/src/NHibernate/Linq/Expressions/NhNewExpression.cs index 5c55e020f17..037a83defcc 100644 --- a/src/NHibernate/Linq/Expressions/NhNewExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhNewExpression.cs @@ -1,64 +1,36 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { - public class NhNewExpression : ExtensionExpression + public class NhNewExpression : NhExpression { - private readonly ReadOnlyCollection _members; - private readonly ReadOnlyCollection _arguments; - public NhNewExpression(IList members, IList arguments) - : base(typeof(object), (ExpressionType)NhExpressionType.New) { - _members = new ReadOnlyCollection(members); - _arguments = new ReadOnlyCollection(arguments); + Members = new ReadOnlyCollection(members); + Arguments = new ReadOnlyCollection(arguments); } - public ReadOnlyCollection Arguments - { - get { return _arguments; } - } + public override System.Type Type => typeof(object); - public ReadOnlyCollection Members - { - get { return _members; } - } + public override NhExpressionType NhNodeType => NhExpressionType.New; - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) - { - var arguments = visitor.VisitAndConvert(Arguments, "VisitNhNew"); + public ReadOnlyCollection Arguments { get; } - return arguments != Arguments - ? new NhNewExpression(Members, arguments) - : this; - } - } + public ReadOnlyCollection Members { get; } - public class NhStarExpression : ExtensionExpression - { - public NhStarExpression(Expression expression) - : base(expression.Type, (ExpressionType)NhExpressionType.Star) + protected override Expression VisitChildren(ExpressionVisitor visitor) { - Expression = expression; - } + var arguments = visitor.VisitAndConvert(Arguments, "VisitNhNew"); - public Expression Expression - { - get; - private set; + return arguments != Arguments + ? new NhNewExpression(Members, arguments) + : this; } - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) - { - var newExpression = visitor.VisitExpression(Expression); - - return newExpression != Expression - ? new NhStarExpression(newExpression) - : this; - } + protected override Expression Accept(NhExpressionVisitor visitor) + => visitor.VisitNhNew(this); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs b/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs index 15997729997..d4693ac3877 100644 --- a/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs @@ -1,34 +1,27 @@ using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { /// /// Represents an expression that has been nominated for direct inclusion in the SELECT clause. - /// This bypasses the standard nomination process and assumes that the expression can be converted + /// This bypasses the standard nomination process and assumes that the expression can be converted /// directly to SQL. /// /// /// Used in the nomination of GroupBy key expressions to ensure that matching select clauses /// are generated the same way. /// - internal class NhNominatedExpression : ExtensionExpression + internal class NhNominatedExpression : NhSimpleExpression { - public Expression Expression { get; private set; } + public NhNominatedExpression(Expression expression) + : base(expression) { } - public NhNominatedExpression(Expression expression) : base(expression.Type, (ExpressionType)NhExpressionType.Nominator) - { - Expression = expression; - } + public override NhExpressionType NhNodeType => NhExpressionType.Nominator; - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) - { - var newExpression = visitor.VisitExpression(Expression); + public override Expression CreateNew(Expression expression) => new NhNominatedExpression(expression); - return newExpression != Expression - ? new NhNominatedExpression(newExpression) - : this; - } + protected override Expression Accept(NhExpressionVisitor visitor) + => visitor.VisitNhNominated(this); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhStarExpression.cs b/src/NHibernate/Linq/Expressions/NhStarExpression.cs new file mode 100644 index 00000000000..88478dbd7e9 --- /dev/null +++ b/src/NHibernate/Linq/Expressions/NhStarExpression.cs @@ -0,0 +1,18 @@ +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; + +namespace NHibernate.Linq.Expressions +{ + public class NhStarExpression : NhSimpleExpression + { + public NhStarExpression(Expression expression) + : base(expression) { } + + public override NhExpressionType NhNodeType => NhExpressionType.Star; + + public override Expression CreateNew(Expression expression) => new NhStarExpression(expression); + + protected override Expression Accept(NhExpressionVisitor visitor) + => visitor.VisitNhStar(this); + } +} \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhSumExpression.cs b/src/NHibernate/Linq/Expressions/NhSumExpression.cs index d8e7326eda0..4ae80a1d9d2 100644 --- a/src/NHibernate/Linq/Expressions/NhSumExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhSumExpression.cs @@ -5,13 +5,10 @@ namespace NHibernate.Linq.Expressions public class NhSumExpression : NhAggregatedExpression { public NhSumExpression(Expression expression) - : base(expression, NhExpressionType.Sum) - { - } + : base(expression) { } - public override Expression CreateNew(Expression expression) - { - return new NhSumExpression(expression); - } + public override NhExpressionType NhNodeType => NhExpressionType.Sum; + + public override Expression CreateNew(Expression expression) => new NhSumExpression(expression); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs index 5290f3189dd..b5102b70bd6 100644 --- a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs @@ -7,17 +7,17 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.GroupBy { - //This should be renamed. It handles entire querymodels, not just select clauses - internal class GroupBySelectClauseRewriter : ExpressionTreeVisitor + //This should be renamed. It handles entire query models, not just select clauses + internal class GroupBySelectClauseRewriter : RelinqExpressionVisitor { public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model) { var visitor = new GroupBySelectClauseRewriter(groupBy, model); - return TransparentIdentifierRemovingExpressionTreeVisitor.ReplaceTransparentIdentifiers(visitor.VisitExpression(expression)); + return TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers(visitor.Visit(expression)); } private readonly GroupResultOperator _groupBy; @@ -31,11 +31,11 @@ private GroupBySelectClauseRewriter(GroupResultOperator groupBy, QueryModel mode _nominatedKeySelector = GroupKeyNominator.Visit(groupBy); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { if (!IsMemberOfModel(expression)) { - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } if (expression.IsGroupingElementOf(_groupBy)) @@ -43,14 +43,14 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef return _groupBy.ElementSelector; } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { if (!IsMemberOfModel(expression)) { - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } if (expression.IsGroupingKeyOf(_groupBy)) @@ -64,7 +64,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) if ((elementSelector is MemberExpression) || (elementSelector is QuerySourceReferenceExpression)) { // If ElementSelector is MemberExpression, just return - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } if ((elementSelector is NewExpression || elementSelector.NodeType == ExpressionType.Convert) @@ -77,7 +77,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) throw new NotImplementedException(); } - // TODO - dislike this code intensly. Should probably be a tree-walk in its own right + // TODO - dislike this code intensely. Should probably be a tree-walk in its own right private bool IsMemberOfModel(MemberExpression expression) { var querySourceRef = expression.Expression as QuerySourceReferenceExpression; @@ -99,9 +99,7 @@ private bool IsMemberOfModel(QuerySourceReferenceExpression expression) return false; } - var subQuery = fromClause.FromExpression as SubQueryExpression; - - if (subQuery != null) + if (fromClause.FromExpression is SubQueryExpression subQuery) { return subQuery.QueryModel == _model; } @@ -120,20 +118,18 @@ private bool IsMemberOfModel(QuerySourceReferenceExpression expression) return subQuery2 != null && subQuery2.QueryModel == _model; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { //If the subquery is a Count(*) aggregate with a condition if (expression.QueryModel.MainFromClause.FromExpression.Type == _groupBy.ItemType) { var where = expression.QueryModel.BodyClauses.OfType().FirstOrDefault(); - NhCountExpression countExpression; - if (where != null && (countExpression = expression.QueryModel.SelectClause.Selector as NhCountExpression) != - null && countExpression.Expression.NodeType == (ExpressionType)NhExpressionType.Star) + if (where != null && expression.QueryModel.SelectClause.Selector is NhCountExpression countExpression && + countExpression.IsCountStar) { //return it as a CASE [column] WHEN [predicate] THEN 1 ELSE NULL END - return - countExpression.CreateNew(Expression.Condition(where.Predicate, Expression.Constant(1, typeof(int?)), - Expression.Constant(null, typeof(int?)))); + return countExpression.CreateNew(Expression.Condition(where.Predicate, Expression.Constant(1, typeof(int?)), + Expression.Constant(null, typeof(int?)))); } } @@ -144,9 +140,9 @@ protected override Expression VisitSubQueryExpression(SubQueryExpression express { foreach (var bodyClause in expression.QueryModel.BodyClauses) { - bodyClause.TransformExpressions((e) => new KeySelectorVisitor(_groupBy).VisitExpression(e)); + bodyClause.TransformExpressions((e) => new KeySelectorVisitor(_groupBy).Visit(e)); } - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } diff --git a/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs b/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs index c550f9aa0e6..b2ecae95243 100644 --- a/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs +++ b/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs @@ -12,7 +12,7 @@ namespace NHibernate.Linq.GroupBy /// This class nominates sub-expression trees on the GroupBy Key expression /// for inclusion in the Select clause. /// - internal class GroupKeyNominator : ExpressionTreeVisitor + internal class GroupKeyNominator : RelinqExpressionVisitor { private GroupKeyNominator() { } @@ -27,13 +27,13 @@ public static Expression Visit(GroupResultOperator groupBy) private static Expression VisitInternal(Expression expr) { - return new GroupKeyNominator().VisitExpression(expr); + return new GroupKeyNominator().Visit(expr); } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { _depth++; - var expr = base.VisitExpression(expression); + var expr = base.Visit(expression); _depth--; // At the root expression, wrap it in the nominator expression if needed @@ -44,45 +44,45 @@ public override Expression VisitExpression(Expression expression) return expr; } - protected override Expression VisitNewArrayExpression(NewArrayExpression expression) + protected override Expression VisitNewArray(NewArrayExpression expression) { _transformed = true; // Transform each initializer recursively (to allow for nested initializers) return Expression.NewArrayInit(expression.Type.GetElementType(), expression.Expressions.Select(VisitInternal)); } - protected override Expression VisitNewExpression(NewExpression expression) + protected override Expression VisitNew(NewExpression expression) { _transformed = true; // Transform each initializer recursively (to allow for nested initializers) return Expression.New(expression.Constructor, expression.Arguments.Select(VisitInternal), expression.Members); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { // If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated _requiresRootNomination = true; - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { // If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated _requiresRootNomination = true; - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { if (expression.NodeType != ExpressionType.ArrayIndex) - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); // If we encounter an array index then we need to attempt to flatten it before nomination - var flattenedExpression = new ArrayIndexExpressionFlattener().VisitExpression(expression); + var flattenedExpression = new ArrayIndexExpressionFlattener().Visit(expression); if (flattenedExpression != expression) - return base.VisitExpression(flattenedExpression); + return base.Visit(flattenedExpression); - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs b/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs index 97478d29fbf..12a355bdfe3 100644 --- a/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs @@ -8,7 +8,7 @@ namespace NHibernate.Linq.GroupBy /// /// Detects if an expression tree contains naked QuerySourceReferenceExpression /// - internal class IsNonAggregatingGroupByDetectionVisitor : NhExpressionTreeVisitor + internal class IsNonAggregatingGroupByDetectionVisitor : NhExpressionVisitor { private bool _containsNakedQuerySourceReferenceExpression; @@ -16,24 +16,22 @@ public bool IsNonAggregatingGroupBy(Expression expression) { _containsNakedQuerySourceReferenceExpression = false; - VisitExpression(expression); + Visit(expression); return _containsNakedQuerySourceReferenceExpression; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { return expression.IsGroupingKey() - ? expression - : base.VisitMemberExpression(expression); + ? expression + : base.VisitMember(expression); } - protected override Expression VisitNhAggregate(NhAggregatedExpression expression) - { - return expression; - } + public override Expression VisitNhAggregate(NhAggregatedExpression expression) + => expression; - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { _containsNakedQuerySourceReferenceExpression = true; return expression; diff --git a/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs b/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs index 7de5756ae72..db9c4d098bf 100644 --- a/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs +++ b/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs @@ -4,7 +4,7 @@ namespace NHibernate.Linq.GroupBy { - internal class KeySelectorVisitor : ExpressionTreeVisitor + internal class KeySelectorVisitor : RelinqExpressionVisitor { private readonly GroupResultOperator _groupBy; @@ -13,13 +13,13 @@ public KeySelectorVisitor(GroupResultOperator groupBy) _groupBy = groupBy; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { if (expression.IsGroupingKeyOf(_groupBy)) { return _groupBy.KeySelector; } - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs index 120d332a7f4..10c70e401ce 100644 --- a/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs @@ -3,7 +3,7 @@ using NHibernate.Linq.ResultOperators; using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; @@ -13,8 +13,8 @@ public static class NonAggregatingGroupByRewriter { public static void ReWrite(QueryModel queryModel) { - if (queryModel.ResultOperators.Count == 1 - && queryModel.ResultOperators[0] is GroupResultOperator + if (queryModel.ResultOperators.Count == 1 + && queryModel.ResultOperators[0] is GroupResultOperator && IsNonAggregatingGroupBy(queryModel)) { var resultOperator = (GroupResultOperator)queryModel.ResultOperators[0]; @@ -23,11 +23,9 @@ public static void ReWrite(QueryModel queryModel) return; } - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if ((subQueryExpression != null) - && (subQueryExpression.QueryModel.ResultOperators.Count == 1) - && (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) + if ((queryModel.MainFromClause.FromExpression is SubQueryExpression subQueryExpression) + && (subQueryExpression.QueryModel.ResultOperators.Count == 1) + && (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) && (IsNonAggregatingGroupBy(queryModel))) { FlattenSubQuery(subQueryExpression, queryModel); @@ -58,7 +56,7 @@ private static void FlattenSubQuery(SubQueryExpression subQueryExpression, Query throw new NotImplementedException(); } - queryModel.ResultOperators.Add(new NonAggregatingGroupBy((GroupResultOperator) subQueryModel.ResultOperators[0])); + queryModel.ResultOperators.Add(new NonAggregatingGroupBy((GroupResultOperator)subQueryModel.ResultOperators[0])); queryModel.ResultOperators.Add(clientSideSelect); } @@ -67,26 +65,24 @@ private static ClientSideSelect CreateClientSideSelect(Expression expression, Qu // TODO - don't like calling GetGenericArguments here... var parameter = Expression.Parameter(expression.Type.GetGenericArguments()[0], "inputParameter"); - + var mapping = new QuerySourceMapping(); mapping.AddMapping(queryModel.MainFromClause, parameter); - - var body = ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(queryModel.SelectClause.Selector, mapping, false); - + + var body = ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(queryModel.SelectClause.Selector, mapping, false); + var lambda = Expression.Lambda(body, parameter); - + return new ClientSideSelect(lambda); } private static bool IsNonAggregatingGroupBy(QueryModel queryModel) - { - return new IsNonAggregatingGroupByDetectionVisitor().IsNonAggregatingGroupBy(queryModel.SelectClause.Selector); - } + => new IsNonAggregatingGroupByDetectionVisitor().IsNonAggregatingGroupBy(queryModel.SelectClause.Selector); } public class ClientSideSelect : ClientSideTransformOperator { - public LambdaExpression SelectClause { get; private set; } + public LambdaExpression SelectClause { get; } public ClientSideSelect(LambdaExpression selectClause) { @@ -96,7 +92,7 @@ public ClientSideSelect(LambdaExpression selectClause) public class ClientSideSelect2 : ClientSideTransformOperator { - public LambdaExpression SelectClause { get; private set; } + public LambdaExpression SelectClause { get; } public ClientSideSelect2(LambdaExpression selectClause) { diff --git a/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs index 678711730fb..6ac4628bb26 100644 --- a/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs @@ -55,8 +55,6 @@ public static void ReWrite(QueryModel model) } private static IsAggregatingResults IsAggregatingGroupJoin(QueryModel model, IEnumerable clause) - { - return GroupJoinAggregateDetectionVisitor.Visit(clause, model.SelectClause.Selector); - } + => GroupJoinAggregateDetectionVisitor.Visit(clause, model.SelectClause.Selector); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index 27f0bf8805f..4f005ccc19a 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -8,7 +8,7 @@ namespace NHibernate.Linq.GroupJoin { - internal class GroupJoinAggregateDetectionVisitor : NhExpressionTreeVisitor + internal class GroupJoinAggregateDetectionVisitor : NhExpressionVisitor { private readonly HashSet _groupJoinClauses; private readonly StackFlag _inAggregate = new StackFlag(); @@ -27,18 +27,18 @@ public static IsAggregatingResults Visit(IEnumerable groupJoinC { var visitor = new GroupJoinAggregateDetectionVisitor(groupJoinClause); - visitor.VisitExpression(selectExpression); + visitor.Visit(selectExpression); return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - VisitExpression(expression.QueryModel.SelectClause.Selector); + Visit(expression.QueryModel.SelectClause.Selector); return expression; } - protected override Expression VisitNhAggregate(NhAggregatedExpression expression) + public override Expression VisitNhAggregate(NhAggregatedExpression expression) { using (_inAggregate.SetFlag()) { @@ -46,7 +46,7 @@ protected override Expression VisitNhAggregate(NhAggregatedExpression expression } } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { if (_inAggregate.FlagIsFalse && _parentExpressionProcessed.FlagIsFalse) { @@ -55,19 +55,17 @@ protected override Expression VisitMemberExpression(MemberExpression expression) using (_parentExpressionProcessed.SetFlag()) { - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { - var fromClause = (FromClauseBase) expression.ReferencedQuerySource; + var fromClause = (FromClauseBase)expression.ReferencedQuerySource; - var querySourceReference = fromClause.FromExpression as QuerySourceReferenceExpression; - if (querySourceReference != null) + if (fromClause.FromExpression is QuerySourceReferenceExpression querySourceReference) { - var groupJoinClause = querySourceReference.ReferencedQuerySource as GroupJoinClause; - if (groupJoinClause != null && _groupJoinClauses.Contains(groupJoinClause)) + if (querySourceReference.ReferencedQuerySource is GroupJoinClause groupJoinClause && _groupJoinClauses.Contains(groupJoinClause)) { if (_inAggregate.FlagIsFalse) { @@ -80,7 +78,7 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef } } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } internal class StackFlag diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs index ad7ac347467..546e7cd514e 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs @@ -7,13 +7,13 @@ namespace NHibernate.Linq.GroupJoin { - public class GroupJoinSelectClauseRewriter : ExpressionTreeVisitor + public class GroupJoinSelectClauseRewriter : RelinqExpressionVisitor { private readonly IsAggregatingResults _results; public static Expression ReWrite(Expression expression, IsAggregatingResults results) { - return new GroupJoinSelectClauseRewriter(results).VisitExpression(expression); + return new GroupJoinSelectClauseRewriter(results).Visit(expression); } private GroupJoinSelectClauseRewriter(IsAggregatingResults results) @@ -21,7 +21,7 @@ private GroupJoinSelectClauseRewriter(IsAggregatingResults results) _results = results; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { // If the sub query's main (and only) from clause is one of our aggregating group bys, then swap it GroupJoinClause groupJoin = LocateGroupJoinQuerySource(expression.QueryModel); diff --git a/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs b/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs index dc55aae5812..a149b90fcdc 100644 --- a/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs +++ b/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs @@ -5,7 +5,7 @@ namespace NHibernate.Linq.GroupJoin { - public class LocateGroupJoinQuerySource : ExpressionTreeVisitor + public class LocateGroupJoinQuerySource : RelinqExpressionVisitor { private readonly IsAggregatingResults _results; private GroupJoinClause _groupJoin; @@ -17,19 +17,18 @@ public LocateGroupJoinQuerySource(IsAggregatingResults results) public GroupJoinClause Detect(Expression expression) { - VisitExpression(expression); + Visit(expression); return _groupJoin; } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { - var groupJoinClause = expression.ReferencedQuerySource as GroupJoinClause; - if (groupJoinClause != null && _results.AggregatingClauses.Contains(groupJoinClause)) + if (expression.ReferencedQuerySource is GroupJoinClause groupJoinClause && _results.AggregatingClauses.Contains(groupJoinClause)) { _groupJoin = groupJoinClause; } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index 73989092dcb..b62b449cefb 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -64,7 +64,7 @@ private void ReWrite() // join s in dc.Status on o.StatusId equals s.Id into os // from y in os.DefaultIfEmpty() // select new { o.OrderNumber, x.VendorName, y.StatusName } - // This is used to repesent an outer join, and again the "from" is removing the hierarchy. So + // This is used to represent an outer join, and again the "from" is removing the hierarchy. So // simply change the group join to an outer join _locator = new QuerySourceUsageLocator(nonAggregatingJoin); @@ -122,39 +122,20 @@ private void SwapClause(IBodyClause oldClause, IBodyClause newClause) } private bool IsOuterJoin(GroupJoinClause nonAggregatingJoin) - { - return false; - } + => false; private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin) - { - if (_locator.Clauses.Count == 1) - { - var from = _locator.Clauses[0] as AdditionalFromClause; - - if (from != null) - { - return true; - } - } - - return false; - } + => _locator.Clauses.Count == 1 && _locator.Clauses[0] is AdditionalFromClause from; private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin) - { - return _locator.Clauses.Count == 0; - } + => _locator.Clauses.Count == 0; // TODO - rename this and share with the AggregatingGroupJoinRewriter private IsAggregatingResults GetGroupJoinInformation(IEnumerable clause) - { - return GroupJoinAggregateDetectionVisitor.Visit(clause, _model.SelectClause.Selector); - } - + => GroupJoinAggregateDetectionVisitor.Visit(clause, _model.SelectClause.Selector); } - internal class QuerySourceUsageLocator : ExpressionTreeVisitor + internal class QuerySourceUsageLocator : RelinqExpressionVisitor { private readonly IQuerySource _querySource; private bool _references; @@ -165,10 +146,7 @@ public QuerySourceUsageLocator(IQuerySource querySource) _querySource = querySource; } - public IList Clauses - { - get { return _clauses.AsReadOnly(); } - } + public IList Clauses => _clauses.AsReadOnly(); public void Search(IBodyClause clause) { @@ -184,11 +162,11 @@ public void Search(IBodyClause clause) private Expression ExpressionSearcher(Expression arg) { - VisitExpression(arg); + Visit(arg); return arg; } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { if (expression.ReferencedQuerySource == _querySource) { diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index e783502aceb..94db3fc5f6e 100755 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -7,7 +7,7 @@ using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq { @@ -104,6 +104,14 @@ public static IFutureValue ToFutureValue(this IQueryable query) return (IFutureValue) future; } + /// + /// Allows to specify the parameter NHibernate type to use for a literal in a queryable expression. + /// + /// The type of the literal. + /// The literal value. + /// The NHibernate type, usually obtained from NHibernateUtil properties. + /// The literal value. + [DBOnly] public static T MappedAs(this T parameter, IType type) { throw new InvalidOperationException("The method should be used inside Linq to indicate a type of a parameter"); @@ -117,7 +125,7 @@ public static IFutureValue ToFutureValue(this IQueryable var provider = (INhQueryProvider) query.Provider; - var expression = ReplacingExpressionTreeVisitor.Replace(selector.Parameters.Single(), + var expression = ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), query.Expression, selector.Body); diff --git a/src/NHibernate/Linq/LinqLogging.cs b/src/NHibernate/Linq/LinqLogging.cs index 942fc007340..b0d041ad9ba 100644 --- a/src/NHibernate/Linq/LinqLogging.cs +++ b/src/NHibernate/Linq/LinqLogging.cs @@ -21,30 +21,29 @@ internal static void LogExpression(string msg, Expression expression) // generated by a class internal to System.Linq.Expression, so we cannot // actually override that logic. Circumvent it by replacing such ConstantExpressions // with ParameterExpression, having their name set to the string we wish to display. - var visitor = new ProxyReplacingExpressionTreeVisitor(); - var preparedExpression = visitor.VisitExpression(expression); + var visitor = new ProxyReplacingExpressionVisitor(); + var preparedExpression = visitor.Visit(expression); Log.DebugFormat("{0}: {1}", msg, preparedExpression.ToString()); } } - /// /// Replace all occurrences of ConstantExpression where the value is an NHibernate /// proxy with a ParameterExpression. The name of the parameter will be a string /// representing the proxied entity, without initializing it. /// - private class ProxyReplacingExpressionTreeVisitor : NhExpressionTreeVisitor + private class ProxyReplacingExpressionVisitor : NhExpressionVisitor { - // See also e.g. Remotion.Linq.Clauses.ExpressionTreeVisitors.FormattingExpressionTreeVisitor + // See also e.g. Remotion.Linq.Clauses.ExpressionVisitors.FormattingExpressionTreeVisitor // for another example of this technique. - protected override Expression VisitConstantExpression(ConstantExpression expression) + protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value.IsProxy()) return Expression.Parameter(expression.Type, ObjectHelpers.IdentityToString(expression.Value)); - return base.VisitConstantExpression(expression); + return base.VisitConstant(expression); } } } diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs index 0d6639a6ef5..c7ed5a53e1d 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs @@ -7,34 +7,28 @@ namespace NHibernate.Linq.NestedSelects { - internal class NestedSelectDetector : ExpressionTreeVisitor + internal class NestedSelectDetector : RelinqExpressionVisitor { - private readonly ISessionFactory sessionFactory; + private readonly ISessionFactory _sessionFactory; private readonly ICollection _expressions = new List(); public NestedSelectDetector(ISessionFactory sessionFactory) { - this.sessionFactory = sessionFactory; + _sessionFactory = sessionFactory; } - public ICollection Expressions - { - get { return _expressions; } - } + public ICollection Expressions => _expressions; - public bool HasSubqueries - { - get { return Expressions.Count > 0; } - } + public bool HasSubqueries => Expressions.Count > 0; - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { if (expression.QueryModel.ResultOperators.Count == 0) Expressions.Add(expression); - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { var memberType = ReflectHelper.GetPropertyOrFieldType(expression.Member); @@ -45,14 +39,14 @@ protected override Expression VisitMemberExpression(MemberExpression expression) Expressions.Add(expression); } - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } private bool IsMappedCollection(MemberInfo memberInfo) { var collectionRole = memberInfo.DeclaringType.FullName + "." + memberInfo.Name; - return sessionFactory.GetCollectionMetadata(collectionRole) != null; + return _sessionFactory.GetCollectionMetadata(collectionRole) != null; } private bool IsChainedFromQuerySourceReference(MemberExpression expression) diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs index 61a057d97e6..7d6e099a63d 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs @@ -29,7 +29,7 @@ static class NestedSelectRewriter public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { var nsqmv = new NestedSelectDetector(parameters.SessionFactory); - nsqmv.VisitExpression(queryModel.SelectClause.Selector); + nsqmv.Visit(queryModel.SelectClause.Selector); if (!nsqmv.HasSubqueries) return; @@ -52,7 +52,7 @@ public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) var rewriter = new SelectClauseRewriter(key, expressions, identifier, replacements); - var resultSelector = rewriter.VisitExpression(queryModel.SelectClause.Selector); + var resultSelector = rewriter.Visit(queryModel.SelectClause.Selector); elementExpression.AddRange(expressions); @@ -156,7 +156,7 @@ private static LambdaExpression MakeSelector(ICollection eleme var rewriter = new SelectClauseRewriter(parameter, elementExpression, identifier, 1, new Dictionary()); - var selectorBody = rewriter.VisitExpression(select); + var selectorBody = rewriter.Visit(select); return Expression.Lambda(selectorBody, parameter); } diff --git a/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs b/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs index 78337496bf0..6c03c7c8227 100644 --- a/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs @@ -5,57 +5,51 @@ namespace NHibernate.Linq.NestedSelects { - class SelectClauseRewriter : ExpressionTreeVisitor + class SelectClauseRewriter : RelinqExpressionVisitor { private readonly Dictionary _dictionary; + private readonly ICollection _expressions; + private readonly Expression _parameter; + private readonly int _tuple; - readonly ICollection expressions; - readonly Expression parameter; - readonly int tuple; + public SelectClauseRewriter(Expression parameter, ICollection expressions, Expression expression, + Dictionary dictionary) + : this(parameter, expressions, expression, 0, dictionary) { } - public SelectClauseRewriter(Expression parameter, ICollection expressions, Expression expression, Dictionary dictionary) - : this(parameter, expressions, expression, 0, dictionary) + public SelectClauseRewriter(Expression parameter, ICollection expressions, Expression expression, + int tuple, Dictionary dictionary) { - } - - public SelectClauseRewriter(Expression parameter, ICollection expressions, Expression expression, int tuple, Dictionary dictionary) - { - this.expressions = expressions; - this.parameter = parameter; - this.tuple = tuple; - this.expressions.Add(new ExpressionHolder { Expression = expression, Tuple = tuple }); //ID placeholder + _expressions = expressions; + _parameter = parameter; + _tuple = tuple; + _expressions.Add(new ExpressionHolder { Expression = expression, Tuple = tuple }); //ID placeholder _dictionary = dictionary; } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) return null; - Expression replacement; - if (_dictionary.TryGetValue(expression, out replacement)) + if (_dictionary.TryGetValue(expression, out Expression replacement)) return replacement; - return base.VisitExpression(expression); + return base.Visit(expression); } - protected override Expression VisitMemberExpression(MemberExpression expression) - { - return AddAndConvertExpression(expression); - } + protected override Expression VisitMember(MemberExpression expression) + => AddAndConvert(expression); - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) - { - return AddAndConvertExpression(expression); - } + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) + => AddAndConvert(expression); - private Expression AddAndConvertExpression(Expression expression) + private Expression AddAndConvert(Expression expression) { - expressions.Add(new ExpressionHolder { Expression = expression, Tuple = tuple }); + _expressions.Add(new ExpressionHolder { Expression = expression, Tuple = _tuple }); return Expression.Convert( Expression.ArrayIndex( - Expression.Property(parameter, Tuple.ItemsProperty), - Expression.Constant(expressions.Count - 1)), + Expression.Property(_parameter, Tuple.ItemsProperty), + Expression.Constant(_expressions.Count - 1)), expression.Type); } } diff --git a/src/NHibernate/Linq/NestedSelects/Tuple.cs b/src/NHibernate/Linq/NestedSelects/Tuple.cs index 4b4068a8ab4..8f239ba0973 100644 --- a/src/NHibernate/Linq/NestedSelects/Tuple.cs +++ b/src/NHibernate/Linq/NestedSelects/Tuple.cs @@ -6,14 +6,14 @@ namespace NHibernate.Linq.NestedSelects { internal class Tuple : IEquatable { - public static readonly ConstructorInfo Constructor = typeof (Tuple).GetConstructor(new[] { typeof (object[]) }); - public static readonly PropertyInfo ItemsProperty = typeof (Tuple).GetProperty("Items"); + public static readonly ConstructorInfo Constructor = typeof(Tuple).GetConstructor(new[] { typeof(object[]) }); + public static readonly PropertyInfo ItemsProperty = typeof(Tuple).GetProperty(nameof(Items)); + private readonly object[] _items; public Tuple(object[] items) { - if (items == null) throw new ArgumentNullException("items"); - _items = items; + _items = items ?? throw new ArgumentNullException(nameof(items)); } public object[] Items diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index c4c06f8197a..9733841ec6f 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -41,7 +41,7 @@ public NhLinqExpression(Expression expression, ISessionFactoryImplementor sessio _constantToParameterMap = ExpressionParameterVisitor.Visit(ref _expression, sessionFactory); ParameterValuesByName = _constantToParameterMap.Values.ToDictionary(p => p.Name, - p => System.Tuple.Create(p.Value, p.Type)); + p => System.Tuple.Create(p.Value, p.Type)); Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); @@ -62,15 +62,15 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter var requiredHqlParameters = new List(); var querySourceNamer = new QuerySourceNamer(); var queryModel = NhRelinqQueryParser.Parse(_expression); - var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, querySourceNamer); + var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, querySourceNamer, ReturnType); - ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true, ReturnType); + ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true); if (ExpressionToHqlTranslationResults.ExecuteResultTypeOverride != null) Type = ExpressionToHqlTranslationResults.ExecuteResultTypeOverride; ParameterDescriptors = requiredHqlParameters.AsReadOnly(); - + return ExpressionToHqlTranslationResults.Statement.AstNode; } diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index 4de1dd3d9d1..388d3fd9110 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -10,7 +10,7 @@ using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.StreamedData; using Remotion.Linq.EagerFetching.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; using Remotion.Linq.Parsing.Structure; using Remotion.Linq.Parsing.Structure.ExpressionTreeProcessors; using Remotion.Linq.Parsing.Structure.IntermediateModel; @@ -55,7 +55,7 @@ static NhRelinqQueryParser() /// The transformed expression. public static Expression PreTransform(Expression expression) { - var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionTreeVisitor.EvaluateIndependentSubtrees(expression); + var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression); return PreProcessor.Process(partiallyEvaluatedExpression); } @@ -140,9 +140,8 @@ public override Expression Resolve(ParameterExpression inputParameter, Expressio return Source.Resolve(inputParameter, expressionToBeResolved, clauseGenerationContext); } - protected override QueryModel ApplyNodeSpecificSemantics(QueryModel queryModel, ClauseGenerationContext clauseGenerationContext) + protected override void ApplyNodeSpecificSemantics(QueryModel queryModel, ClauseGenerationContext clauseGenerationContext) { - return queryModel; } } diff --git a/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs index 0b739b078e3..e74305be0dd 100644 --- a/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs @@ -5,17 +5,17 @@ namespace NHibernate.Linq.ReWriters { - public class ArrayIndexExpressionFlattener : ExpressionTreeVisitor + public class ArrayIndexExpressionFlattener : RelinqExpressionVisitor { public static void ReWrite(QueryModel model) { var visitor = new ArrayIndexExpressionFlattener(); - model.TransformExpressions(visitor.VisitExpression); + model.TransformExpressions(visitor.Visit); } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { - var visitedExpression = base.VisitBinaryExpression(expression); + var visitedExpression = base.VisitBinary(expression); if (visitedExpression.NodeType != ExpressionType.ArrayIndex) return visitedExpression; @@ -25,16 +25,16 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) return visitedExpression; var expressionList = expression.Left as NewArrayExpression; - if (expressionList == null || expressionList.NodeType != ExpressionType.NewArrayInit) + if (expressionList == null || expressionList.NodeType != ExpressionType.NewArrayInit) return visitedExpression; - return VisitExpression(expressionList.Expressions[(int)index.Value]); + return Visit(expressionList.Expressions[(int)index.Value]); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { ReWrite(expression.QueryModel); - return expression; // Note that we modifiy the (mutable) QueryModel, we return an unchanged expression + return expression; // Note that we modify the (mutable) QueryModel, we return an unchanged expression } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs b/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs index 2e2b3c52405..c6731f6c513 100644 --- a/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs @@ -8,15 +8,14 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.ReWriters { public class MergeAggregatingResultsRewriter : QueryModelVisitorBase { private MergeAggregatingResultsRewriter() - { - } + { } public static void ReWrite(QueryModel model) { @@ -27,42 +26,37 @@ public static void ReWrite(QueryModel model) public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) { - if (resultOperator is SumResultOperator) - { - queryModel.SelectClause.Selector = new NhSumExpression(queryModel.SelectClause.Selector); - queryModel.ResultOperators.Remove(resultOperator); - } - else if (resultOperator is AverageResultOperator) - { - queryModel.SelectClause.Selector = new NhAverageExpression(queryModel.SelectClause.Selector); - queryModel.ResultOperators.Remove(resultOperator); - } - else if (resultOperator is MinResultOperator) - { - queryModel.SelectClause.Selector = new NhMinExpression(queryModel.SelectClause.Selector); - queryModel.ResultOperators.Remove(resultOperator); - } - else if (resultOperator is MaxResultOperator) + switch (resultOperator) { - queryModel.SelectClause.Selector = new NhMaxExpression(queryModel.SelectClause.Selector); - queryModel.ResultOperators.Remove(resultOperator); + case SumResultOperator sum: + queryModel.SelectClause.Selector = new NhSumExpression(queryModel.SelectClause.Selector); + queryModel.ResultOperators.Remove(resultOperator); + break; + case AverageResultOperator avg: + queryModel.SelectClause.Selector = new NhAverageExpression(queryModel.SelectClause.Selector); + queryModel.ResultOperators.Remove(resultOperator); + break; + case MinResultOperator min: + queryModel.SelectClause.Selector = new NhMinExpression(queryModel.SelectClause.Selector); + queryModel.ResultOperators.Remove(resultOperator); + break; + case MaxResultOperator max: + queryModel.SelectClause.Selector = new NhMaxExpression(queryModel.SelectClause.Selector); + queryModel.ResultOperators.Remove(resultOperator); + break; + case DistinctResultOperator distinct: + queryModel.SelectClause.Selector = new NhDistinctExpression(queryModel.SelectClause.Selector); + queryModel.ResultOperators.Remove(resultOperator); + break; + case CountResultOperator count: + queryModel.SelectClause.Selector = new NhShortCountExpression(TransformCountExpression(queryModel.SelectClause.Selector)); + queryModel.ResultOperators.Remove(resultOperator); + break; + case LongCountResultOperator longCount: + queryModel.SelectClause.Selector = new NhLongCountExpression(TransformCountExpression(queryModel.SelectClause.Selector)); + queryModel.ResultOperators.Remove(resultOperator); + break; } - else if (resultOperator is DistinctResultOperator) - { - queryModel.SelectClause.Selector = new NhDistinctExpression(queryModel.SelectClause.Selector); - queryModel.ResultOperators.Remove(resultOperator); - } - else if (resultOperator is CountResultOperator) - { - queryModel.SelectClause.Selector = new NhShortCountExpression(TransformCountExpression(queryModel.SelectClause.Selector)); - queryModel.ResultOperators.Remove(resultOperator); - } - else if (resultOperator is LongCountResultOperator) - { - queryModel.SelectClause.Selector = new NhLongCountExpression(TransformCountExpression(queryModel.SelectClause.Selector)); - queryModel.ResultOperators.Remove(resultOperator); - } - base.VisitResultOperator(resultOperator, queryModel, index); } @@ -83,9 +77,9 @@ public override void VisitOrdering(Ordering ordering, QueryModel queryModel, Ord private static Expression TransformCountExpression(Expression expression) { - if (expression.NodeType == ExpressionType.MemberInit || + if (expression.NodeType == ExpressionType.MemberInit || expression.NodeType == ExpressionType.New || - expression.NodeType == QuerySourceReferenceExpression.ExpressionType) + expression is QuerySourceReferenceExpression) { //Probably it should be done by CountResultOperatorProcessor return new NhStarExpression(expression); @@ -95,7 +89,7 @@ private static Expression TransformCountExpression(Expression expression) } } - internal class MergeAggregatingResultsInExpressionRewriter : ExpressionTreeVisitor + internal class MergeAggregatingResultsInExpressionRewriter : RelinqExpressionVisitor { private readonly NameGenerator _nameGenerator; @@ -108,62 +102,60 @@ public static Expression Rewrite(Expression expression, NameGenerator nameGenera { var visitor = new MergeAggregatingResultsInExpressionRewriter(nameGenerator); - return visitor.VisitExpression(expression); + return visitor.Visit(expression); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { MergeAggregatingResultsRewriter.ReWrite(expression.QueryModel); return expression; } - protected override Expression VisitMethodCallExpression(MethodCallExpression m) + protected override Expression VisitMethodCall(MethodCallExpression m) { if (m.Method.DeclaringType == typeof(Queryable) || m.Method.DeclaringType == typeof(Enumerable)) { - // TODO - dynamic name generation needed here switch (m.Method.Name) { - case "Count": + case nameof(Queryable.Count): return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhShortCountExpression(e), - () => new CountResultOperator()); - case "LongCount": + e => new NhShortCountExpression(e), + () => new CountResultOperator()); + case nameof(Queryable.LongCount): return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhLongCountExpression(e), - () => new LongCountResultOperator()); - case "Min": + e => new NhLongCountExpression(e), + () => new LongCountResultOperator()); + case nameof(Queryable.Min): return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhMinExpression(e), - () => new MinResultOperator()); - case "Max": + e => new NhMinExpression(e), + () => new MinResultOperator()); + case nameof(Queryable.Max): return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhMaxExpression(e), - () => new MaxResultOperator()); - case "Sum": + e => new NhMaxExpression(e), + () => new MaxResultOperator()); + case nameof(Queryable.Sum): return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhSumExpression(e), - () => new SumResultOperator()); - case "Average": + e => new NhSumExpression(e), + () => new SumResultOperator()); + case nameof(Queryable.Average): return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhAverageExpression(e), - () => new AverageResultOperator()); + e => new NhAverageExpression(e), + () => new AverageResultOperator()); } } - return base.VisitMethodCallExpression(m); + return base.VisitMethodCall(m); } - private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func aggregateFactory, Func resultOperatorFactory) + private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, + Func aggregateFactory, Func resultOperatorFactory) { var fromClause = new MainFromClause(_nameGenerator.GetNewName(), body.Parameters[0].Type, fromClauseExpression); var selectClause = body.Body; - selectClause = ReplacingExpressionTreeVisitor.Replace(body.Parameters[0], - new QuerySourceReferenceExpression( - fromClause), selectClause); - var queryModel = new QueryModel(fromClause, - new SelectClause(aggregateFactory(selectClause))); + selectClause = ReplacingExpressionVisitor.Replace(body.Parameters[0], + new QuerySourceReferenceExpression(fromClause), selectClause); + var queryModel = new QueryModel(fromClause, new SelectClause(aggregateFactory(selectClause))); // TODO - this sucks, but we use it to get the Type of the SubQueryExpression correct queryModel.ResultOperators.Add(resultOperatorFactory()); diff --git a/src/NHibernate/Linq/ReWriters/MoveOrderByToEndRewriter.cs b/src/NHibernate/Linq/ReWriters/MoveOrderByToEndRewriter.cs index d5434b6e347..52a91ef412a 100644 --- a/src/NHibernate/Linq/ReWriters/MoveOrderByToEndRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/MoveOrderByToEndRewriter.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using Remotion.Linq; +using Remotion.Linq; using Remotion.Linq.Clauses; namespace NHibernate.Linq.ReWriters diff --git a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs index 412f1c8fabe..98d02e4481e 100644 --- a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs @@ -9,7 +9,7 @@ namespace NHibernate.Linq.ReWriters { - public class QueryReferenceExpressionFlattener : ExpressionTreeVisitor + public class QueryReferenceExpressionFlattener : RelinqExpressionVisitor { private readonly QueryModel _model; @@ -29,24 +29,22 @@ private QueryReferenceExpressionFlattener(QueryModel model) public static void ReWrite(QueryModel model) { var visitor = new QueryReferenceExpressionFlattener(model); - model.TransformExpressions(visitor.VisitExpression); + model.TransformExpressions(visitor.Visit); } - protected override Expression VisitSubQueryExpression(SubQueryExpression subQuery) + protected override Expression VisitSubQuery(SubQueryExpression subQuery) { var subQueryModel = subQuery.QueryModel; var hasBodyClauses = subQueryModel.BodyClauses.Count > 0; if (hasBodyClauses) { - return base.VisitSubQueryExpression(subQuery); + return base.VisitSubQuery(subQuery); } var resultOperators = subQueryModel.ResultOperators; if (resultOperators.Count == 0 || HasJustAllFlattenableOperator(resultOperators)) { - var selectQuerySource = subQueryModel.SelectClause.Selector as QuerySourceReferenceExpression; - - if (selectQuerySource != null && selectQuerySource.ReferencedQuerySource == subQueryModel.MainFromClause) + if (subQueryModel.SelectClause.Selector is QuerySourceReferenceExpression selectQuerySource && selectQuerySource.ReferencedQuerySource == subQueryModel.MainFromClause) { foreach (var resultOperator in resultOperators) { @@ -57,7 +55,7 @@ protected override Expression VisitSubQueryExpression(SubQueryExpression subQuer } } - return base.VisitSubQueryExpression(subQuery); + return base.VisitSubQuery(subQuery); } private static bool HasJustAllFlattenableOperator(IEnumerable resultOperators) @@ -65,18 +63,16 @@ private static bool HasJustAllFlattenableOperator(IEnumerable FlattenableResultOperators.Contains(x.GetType())); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { - var fromClauseBase = expression.ReferencedQuerySource as FromClauseBase; - - if (fromClauseBase != null && + if (expression.ReferencedQuerySource is FromClauseBase fromClauseBase && fromClauseBase.FromExpression is QuerySourceReferenceExpression && expression.Type == fromClauseBase.FromExpression.Type) { return fromClauseBase.FromExpression; } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs index fe1e50f7c40..8d8c8570b1f 100644 --- a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs @@ -1,32 +1,27 @@ +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; +using Remotion.Linq.Clauses.StreamedData; +using Remotion.Linq.EagerFetching; using Remotion.Linq.Parsing; namespace NHibernate.Linq.ReWriters { - using System.Collections.Generic; - using System.Linq; - using System.Linq.Expressions; - - using NHibernate.Linq.Visitors; - - using Remotion.Linq; - using Remotion.Linq.Clauses; - using Remotion.Linq.Clauses.Expressions; - using Remotion.Linq.Clauses.ResultOperators; - using Remotion.Linq.Clauses.StreamedData; - using Remotion.Linq.EagerFetching; - /// /// Removes various result operators from a query so that they can be processed at the same /// tree level as the query itself. /// public class ResultOperatorRewriter : QueryModelVisitorBase { - private readonly List resultOperators = new List(); - private IStreamedDataInfo evaluationType; + private readonly List _resultOperators = new List(); + private IStreamedDataInfo _evaluationType; private ResultOperatorRewriter() - { - } + { } public static ResultOperatorRewriterResult Rewrite(QueryModel queryModel) { @@ -34,7 +29,7 @@ public static ResultOperatorRewriterResult Rewrite(QueryModel queryModel) rewriter.VisitQueryModel(queryModel); - return new ResultOperatorRewriterResult(rewriter.resultOperators, rewriter.evaluationType); + return new ResultOperatorRewriterResult(rewriter._resultOperators, rewriter._evaluationType); } public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) @@ -52,14 +47,14 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q } } - this.resultOperators.AddRange(rewriter.ResultOperators); - this.evaluationType = rewriter.EvaluationType; + _resultOperators.AddRange(rewriter.ResultOperators); + _evaluationType = rewriter.EvaluationType; } /// /// Rewrites expressions so that they sit in the outermost portion of the query. /// - private class ResultOperatorExpressionRewriter : ExpressionTreeVisitor + private class ResultOperatorExpressionRewriter : RelinqExpressionVisitor { private static readonly System.Type[] rewrittenTypes = new[] { @@ -70,45 +65,37 @@ private class ResultOperatorExpressionRewriter : ExpressionTreeVisitor typeof(CastResultOperator), // see ProcessCast class }; - private readonly List resultOperators = new List(); - private IStreamedDataInfo evaluationType; + private readonly List _resultOperators = new List(); + private IStreamedDataInfo _evaluationType; /// /// Gets an of that were rewritten. /// - public IEnumerable ResultOperators - { - get { return resultOperators; } - } + public IEnumerable ResultOperators => _resultOperators; /// /// Gets the representing the type of data that the operator works upon. /// - public IStreamedDataInfo EvaluationType - { - get { return evaluationType; } - } + public IStreamedDataInfo EvaluationType => _evaluationType; public Expression Rewrite(Expression expression) - { - return VisitExpression(expression); - } + => Visit(expression); - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - resultOperators.AddRange( + _resultOperators.AddRange( expression.QueryModel.ResultOperators .Where(r => rewrittenTypes.Any(t => t.IsInstanceOfType(r)))); - resultOperators.ForEach(f => expression.QueryModel.ResultOperators.Remove(f)); - evaluationType = expression.QueryModel.SelectClause.GetOutputDataInfo(); + _resultOperators.ForEach(f => expression.QueryModel.ResultOperators.Remove(f)); + _evaluationType = expression.QueryModel.SelectClause.GetOutputDataInfo(); if (expression.QueryModel.ResultOperators.Count == 0 && expression.QueryModel.BodyClauses.Count == 0) { return expression.QueryModel.MainFromClause.FromExpression; } - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } } } diff --git a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs index 59d945e664a..356eefb26d0 100644 --- a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs +++ b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs @@ -1,30 +1,29 @@ +using System.Collections.Generic; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.StreamedData; + namespace NHibernate.Linq.ReWriters { - using System.Collections.Generic; - - using Remotion.Linq.Clauses; - using Remotion.Linq.Clauses.StreamedData; - - /// - /// Result of . - /// - public class ResultOperatorRewriterResult - { - public ResultOperatorRewriterResult(IEnumerable rewrittenOperators, IStreamedDataInfo evaluationType) - { - this.RewrittenOperators = rewrittenOperators; - this.EvaluationType = evaluationType; - } + /// + /// Result of . + /// + public class ResultOperatorRewriterResult + { + public ResultOperatorRewriterResult(IEnumerable rewrittenOperators, IStreamedDataInfo evaluationType) + { + RewrittenOperators = rewrittenOperators; + EvaluationType = evaluationType; + } - /// - /// Gets an of implementations that were - /// rewritten. - /// - public IEnumerable RewrittenOperators { get; private set; } + /// + /// Gets an of implementations that were + /// rewritten. + /// + public IEnumerable RewrittenOperators { get; } - /// - /// Gets the representing the type of data that the operator works upon. - /// - public IStreamedDataInfo EvaluationType { get; private set; } - } + /// + /// Gets the representing the type of data that the operator works upon. + /// + public IStreamedDataInfo EvaluationType { get; } + } } diff --git a/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs b/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs index 13b36cb2ef6..1741e380609 100644 --- a/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs +++ b/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs @@ -24,7 +24,7 @@ public HqlBooleanExpression Visit(Expression innerKeySelector, Expression outerK var outerNewExpression = outerKeySelector as NewExpression; return innerNewExpression != null && outerNewExpression != null ? VisitNew(innerNewExpression, outerNewExpression) - : GenerateEqualityNode(innerKeySelector, outerKeySelector, new HqlGeneratorExpressionTreeVisitor(_parameters)); + : GenerateEqualityNode(innerKeySelector, outerKeySelector, new HqlGeneratorExpressionVisitor(_parameters)); } private HqlBooleanExpression VisitNew(NewExpression innerKeySelector, NewExpression outerKeySelector) @@ -45,15 +45,11 @@ private HqlBooleanExpression VisitNew(NewExpression innerKeySelector, NewExpress } private HqlEquality GenerateEqualityNode(NewExpression innerKeySelector, NewExpression outerKeySelector, int index) - { - return GenerateEqualityNode(innerKeySelector.Arguments[index], outerKeySelector.Arguments[index], new HqlGeneratorExpressionTreeVisitor(_parameters)); - } + => GenerateEqualityNode(innerKeySelector.Arguments[index], outerKeySelector.Arguments[index], new HqlGeneratorExpressionVisitor(_parameters)); private HqlEquality GenerateEqualityNode(Expression leftExpr, Expression rightExpr, IHqlExpressionVisitor visitor) - { - return _hqlTreeBuilder.Equality( + => _hqlTreeBuilder.Equality( visitor.Visit(leftExpr).ToArithmeticExpression(), visitor.Visit(rightExpr).ToArithmeticExpression()); - } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index ab4618feb71..a2ab8086076 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -6,6 +6,7 @@ using System.Reflection; using System.Text; using NHibernate.Param; +using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -17,7 +18,7 @@ namespace NHibernate.Linq.Visitors /// generate the same key as /// from c in Customers where c.City = "Madrid" /// - public class ExpressionKeyVisitor : ExpressionTreeVisitor + public class ExpressionKeyVisitor : RelinqExpressionVisitor { private readonly IDictionary _constantToParameterMap; readonly StringBuilder _string = new StringBuilder(); @@ -31,17 +32,14 @@ public static string Visit(Expression expression, IDictionary _string.ToString(); - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { if (expression.Method != null) { @@ -56,31 +54,29 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) _string.Append("("); - VisitExpression(expression.Left); + Visit(expression.Left); _string.Append(", "); - VisitExpression(expression.Right); + Visit(expression.Right); _string.Append(")"); return expression; } - protected override Expression VisitConditionalExpression(ConditionalExpression expression) + protected override Expression VisitConditional(ConditionalExpression expression) { - VisitExpression(expression.Test); + Visit(expression.Test); _string.Append(" ? "); - VisitExpression(expression.IfTrue); + Visit(expression.IfTrue); _string.Append(" : "); - VisitExpression(expression.IfFalse); + Visit(expression.IfFalse); return expression; } - protected override Expression VisitConstantExpression(ConstantExpression expression) + protected override Expression VisitConstant(ConstantExpression expression) { - NamedParameter param; - - if (_constantToParameterMap.TryGetValue(expression, out param) && insideSelectClause == false) + if (_constantToParameterMap.TryGetValue(expression, out NamedParameter param) && insideSelectClause == false) { // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. if (param.Value == null) @@ -89,8 +85,7 @@ protected override Expression VisitConstantExpression(ConstantExpression express } else { - var value = param.Value as IEnumerable; - if (value != null && !(value is string) && !value.Cast().Any()) + if (param.Value is IEnumerable value && !(value is string) && !value.Cast().Any()) { _string.Append("EmptyList"); } @@ -108,8 +103,7 @@ protected override Expression VisitConstantExpression(ConstantExpression express } else { - var value = expression.Value as IEnumerable; - if (value != null && !(value is string) && !(value is IQueryable)) + if (expression.Value is IEnumerable value && !(value is string) && !(value is IQueryable)) { _string.Append("{"); _string.Append(String.Join(",", value.Cast())); @@ -122,32 +116,32 @@ protected override Expression VisitConstantExpression(ConstantExpression express } } - return base.VisitConstantExpression(expression); + return base.VisitConstant(expression); } private T AppendCommas(T expression) where T : Expression { - VisitExpression(expression); + Visit(expression); _string.Append(", "); return expression; } - protected override Expression VisitLambdaExpression(LambdaExpression expression) + protected override Expression VisitLambda(Expression expression) { _string.Append('('); - VisitList(expression.Parameters, AppendCommas); + Visit(expression.Parameters, AppendCommas); _string.Append(") => ("); - VisitExpression(expression.Body); + Visit(expression.Body); _string.Append(')'); return expression; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { - base.VisitMemberExpression(expression); + base.VisitMember(expression); _string.Append('.'); _string.Append(expression.Member.Name); @@ -156,7 +150,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) } private bool insideSelectClause; - protected override Expression VisitMethodCallExpression(MethodCallExpression expression) + protected override Expression VisitMethodCall(MethodCallExpression expression) { var old = insideSelectClause; @@ -175,39 +169,39 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp break; } - VisitExpression(expression.Object); + Visit(expression.Object); _string.Append('.'); VisitMethod(expression.Method); _string.Append('('); - VisitList(expression.Arguments, AppendCommas); + Visit(expression.Arguments, AppendCommas); _string.Append(')'); insideSelectClause = old; return expression; } - protected override Expression VisitNewExpression(NewExpression expression) + protected override Expression VisitNew(NewExpression expression) { _string.Append("new "); _string.Append(expression.Constructor.DeclaringType.Name); _string.Append('('); - VisitList(expression.Arguments, AppendCommas); + Visit(expression.Arguments, AppendCommas); _string.Append(')'); return expression; } - protected override Expression VisitParameterExpression(ParameterExpression expression) + protected override Expression VisitParameter(ParameterExpression expression) { _string.Append(expression.Name); return expression; } - protected override Expression VisitTypeBinaryExpression(TypeBinaryExpression expression) + protected override Expression VisitTypeBinary(TypeBinaryExpression expression) { _string.Append("IsType("); - VisitExpression(expression.Expression); + Visit(expression.Expression); _string.Append(", "); _string.Append(expression.TypeOperand.FullName); _string.Append(")"); @@ -215,17 +209,17 @@ protected override Expression VisitTypeBinaryExpression(TypeBinaryExpression exp return expression; } - protected override Expression VisitUnaryExpression(UnaryExpression expression) + protected override Expression VisitUnary(UnaryExpression expression) { _string.Append(expression.NodeType); _string.Append('('); - VisitExpression(expression.Operand); + Visit(expression.Operand); _string.Append(')'); return expression; } - protected override Expression VisitQuerySourceReferenceExpression(Remotion.Linq.Clauses.Expressions.QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { _string.Append(expression.ReferencedQuerySource.ItemName); return expression; diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index cc3662fb149..f937ce83013 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -14,7 +14,7 @@ namespace NHibernate.Linq.Visitors /// /// Locates constants in the expression tree and generates parameters for each one /// - public class ExpressionParameterVisitor : ExpressionTreeVisitor + public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); private readonly ISessionFactoryImplementor _sessionFactory; @@ -28,7 +28,8 @@ public class ExpressionParameterVisitor : ExpressionTreeVisitor private static readonly MethodInfo EnumerableTakeDefinition = ReflectHelper.GetMethodDefinition(() => Enumerable.Take(null, 0)); - private readonly ICollection _pagingMethods = new HashSet + private readonly ICollection _pagingMethods = + new HashSet { QueryableSkipDefinition, QueryableTakeDefinition, EnumerableSkipDefinition, EnumerableTakeDefinition @@ -48,16 +49,16 @@ internal static IDictionary Visit(ref Expres { var visitor = new ExpressionParameterVisitor(sessionFactory); - expression = visitor.VisitExpression(expression); + expression = visitor.Visit(expression); return visitor._parameters; } - protected override Expression VisitMethodCallExpression(MethodCallExpression expression) + protected override Expression VisitMethodCall(MethodCallExpression expression) { if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) { - var rawParameter = VisitExpression(expression.Arguments[0]); + var rawParameter = Visit(expression.Arguments[0]); var parameter = rawParameter as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) @@ -81,7 +82,7 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp if (_pagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) { //TODO: find a way to make this code cleaner - var query = VisitExpression(expression.Arguments[0]); + var query = Visit(expression.Arguments[0]); var arg = expression.Arguments[1]; if (query == expression.Arguments[0]) @@ -95,10 +96,10 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp return expression; } - return base.VisitMethodCallExpression(expression); + return base.VisitMethodCall(expression); } - protected override Expression VisitConstantExpression(ConstantExpression expression) + protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) { @@ -125,12 +126,10 @@ protected override Expression VisitConstantExpression(ConstantExpression express _parameters.Add(expression, new NamedParameter("p" + (_parameters.Count + 1), value, type)); } - return base.VisitConstantExpression(expression); + return base.VisitConstant(expression); } private static bool IsNullObject(ConstantExpression expression) - { - return expression.Type == typeof(Object) && expression.Value == null; - } + => expression.Type == typeof(Object) && expression.Value == null; } } diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs similarity index 59% rename from src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs rename to src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index b009496ec56..71105e900bd 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -8,37 +8,27 @@ using NHibernate.Param; using NHibernate.Util; using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Clauses.ResultOperators; namespace NHibernate.Linq.Visitors { - public class HqlGeneratorExpressionTreeVisitor : IHqlExpressionVisitor + public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor { private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) - { - return new HqlGeneratorExpressionTreeVisitor(parameters).VisitExpression(expression); - } + => new HqlGeneratorExpressionVisitor(parameters).Visit(expression); - public HqlGeneratorExpressionTreeVisitor(VisitorParameters parameters) + public HqlGeneratorExpressionVisitor(VisitorParameters parameters) { _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; _parameters = parameters; } - - public ISessionFactory SessionFactory { get { return _parameters.SessionFactory; } } - - + public ISessionFactory SessionFactory => _parameters.SessionFactory; + public HqlTreeNode Visit(Expression expression) - { - return VisitExpression(expression); - } - - protected HqlTreeNode VisitExpression(Expression expression) { if (expression == null) return null; @@ -54,7 +44,7 @@ protected HqlTreeNode VisitExpression(Expression expression) case ExpressionType.Quote: case ExpressionType.TypeAs: case ExpressionType.UnaryPlus: - return VisitUnaryExpression((UnaryExpression) expression); + return VisitUnaryExpression((UnaryExpression)expression); case ExpressionType.Add: case ExpressionType.AddChecked: case ExpressionType.Divide: @@ -79,66 +69,63 @@ protected HqlTreeNode VisitExpression(Expression expression) case ExpressionType.LessThanOrEqual: case ExpressionType.Coalesce: case ExpressionType.ArrayIndex: - return VisitBinaryExpression((BinaryExpression) expression); + return VisitBinaryExpression((BinaryExpression)expression); case ExpressionType.Conditional: - return VisitConditionalExpression((ConditionalExpression) expression); + return VisitConditionalExpression((ConditionalExpression)expression); case ExpressionType.Constant: - return VisitConstantExpression((ConstantExpression) expression); + return VisitConstantExpression((ConstantExpression)expression); case ExpressionType.Invoke: - return VisitInvocationExpression((InvocationExpression) expression); + return VisitInvocationExpression((InvocationExpression)expression); case ExpressionType.Lambda: - return VisitLambdaExpression((LambdaExpression) expression); + return VisitLambdaExpression((LambdaExpression)expression); case ExpressionType.MemberAccess: - return VisitMemberExpression((MemberExpression) expression); + return VisitMemberExpression((MemberExpression)expression); case ExpressionType.Call: - return VisitMethodCallExpression((MethodCallExpression) expression); - //case ExpressionType.New: - // return VisitNewExpression((NewExpression)expression); - //case ExpressionType.NewArrayBounds: + return VisitMethodCallExpression((MethodCallExpression)expression); + //case ExpressionType.New: + // return VisitNewExpression((NewExpression)expression); + //case ExpressionType.NewArrayBounds: case ExpressionType.NewArrayInit: - return VisitNewArrayExpression((NewArrayExpression) expression); - //case ExpressionType.MemberInit: - // return VisitMemberInitExpression((MemberInitExpression)expression); - //case ExpressionType.ListInit: - // return VisitListInitExpression((ListInitExpression)expression); + return VisitNewArrayExpression((NewArrayExpression)expression); + //case ExpressionType.MemberInit: + // return VisitMemberInitExpression((MemberInitExpression)expression); + //case ExpressionType.ListInit: + // return VisitListInitExpression((ListInitExpression)expression); case ExpressionType.Parameter: - return VisitParameterExpression((ParameterExpression) expression); + return VisitParameterExpression((ParameterExpression)expression); case ExpressionType.TypeIs: - return VisitTypeBinaryExpression((TypeBinaryExpression) expression); + return VisitTypeBinaryExpression((TypeBinaryExpression)expression); default: - var subQueryExpression = expression as SubQueryExpression; - if (subQueryExpression != null) - return VisitSubQueryExpression(subQueryExpression); - - var querySourceReferenceExpression = expression as QuerySourceReferenceExpression; - if (querySourceReferenceExpression != null) - return VisitQuerySourceReferenceExpression(querySourceReferenceExpression); - - var vbStringComparisonExpression = expression as VBStringComparisonExpression; - if (vbStringComparisonExpression != null) - return VisitVBStringComparisonExpression(vbStringComparisonExpression); - - switch ((NhExpressionType) expression.NodeType) + switch (expression) { - case NhExpressionType.Average: - return VisitNhAverage((NhAverageExpression) expression); - case NhExpressionType.Min: - return VisitNhMin((NhMinExpression) expression); - case NhExpressionType.Max: - return VisitNhMax((NhMaxExpression) expression); - case NhExpressionType.Sum: - return VisitNhSum((NhSumExpression) expression); - case NhExpressionType.Count: - return VisitNhCount((NhCountExpression) expression); - case NhExpressionType.Distinct: - return VisitNhDistinct((NhDistinctExpression) expression); - case NhExpressionType.Star: - return VisitNhStar((NhStarExpression) expression); - case NhExpressionType.Nominator: - return VisitExpression(((NhNominatedExpression) expression).Expression); - //case NhExpressionType.New: - // return VisitNhNew((NhNewExpression)expression); + case SubQueryExpression subQueryExpression: + return VisitSubQueryExpression(subQueryExpression); + case QuerySourceReferenceExpression querySourceReferenceExpression: + return VisitQuerySourceReferenceExpression(querySourceReferenceExpression); + case VBStringComparisonExpression vbStringComparisonExpression: + return VisitVBStringComparisonExpression(vbStringComparisonExpression); + case NhSimpleExpression nhExpression: + switch (nhExpression.NhNodeType) + { + case NhExpressionType.Average: + return VisitNhAverage(nhExpression); + case NhExpressionType.Min: + return VisitNhMin(nhExpression); + case NhExpressionType.Max: + return VisitNhMax(nhExpression); + case NhExpressionType.Sum: + return VisitNhSum(nhExpression); + case NhExpressionType.Count: + return VisitNhCount(nhExpression); + case NhExpressionType.Distinct: + return VisitNhDistinct(nhExpression); + case NhExpressionType.Star: + return VisitNhStar(nhExpression); + case NhExpressionType.Nominator: + return Visit(nhExpression.Expression); + } + break; } throw new NotSupportedException(expression.ToString()); @@ -146,20 +133,17 @@ protected HqlTreeNode VisitExpression(Expression expression) } private HqlTreeNode VisitTypeBinaryExpression(TypeBinaryExpression expression) - { - return BuildOfType(expression.Expression, expression.TypeOperand); - } + => BuildOfType(expression.Expression, expression.TypeOperand); internal HqlBooleanExpression BuildOfType(Expression expression, System.Type type) { var sessionFactory = _parameters.SessionFactory; - var meta = sessionFactory.GetClassMetadata(type) as Persister.Entity.AbstractEntityPersister; - if (meta != null && !meta.IsExplicitPolymorphism) + if (sessionFactory.GetClassMetadata(type) is Persister.Entity.AbstractEntityPersister meta && !meta.IsExplicitPolymorphism) { //Adapted the logic found in SingleTableEntityPersister.DiscriminatorFilterFragment var nodes = meta .SubclassClosure - .Select(typeName => (NHibernate.Persister.Entity.IQueryable) sessionFactory.GetEntityPersister(typeName)) + .Select(typeName => (NHibernate.Persister.Entity.IQueryable)sessionFactory.GetEntityPersister(typeName)) .Where(persister => !persister.IsAbstract) .Select(persister => _hqlTreeBuilder.Ident(persister.EntityName)) .ToList(); @@ -183,7 +167,7 @@ internal HqlBooleanExpression BuildOfType(Expression expression, System.Type typ if (nodes.Count == 0) { const string abstractClassWithNoSubclassExceptionMessageTemplate = -@"The class {0} can't be instatiated and does not have mapped subclasses; +@"The class {0} can't be instantiated and does not have mapped subclasses; possible solutions: - don't map the abstract class - map its subclasses."; @@ -197,61 +181,45 @@ internal HqlBooleanExpression BuildOfType(Expression expression, System.Type typ _hqlTreeBuilder.Ident(type.FullName)); } - protected HqlTreeNode VisitNhStar(NhStarExpression expression) - { - return _hqlTreeBuilder.Star(); - } + protected HqlTreeNode VisitNhStar(NhSimpleExpression expression) + => _hqlTreeBuilder.Star(); private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) - { - return VisitExpression(expression.Expression); - } + => Visit(expression.Expression); - protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) + protected HqlTreeNode VisitNhAverage(NhSimpleExpression expression) { - var hqlExpression = VisitExpression(expression.Expression).AsExpression(); + var hqlExpression = Visit(expression.Expression).AsExpression(); if (expression.Type != expression.Expression.Type) hqlExpression = _hqlTreeBuilder.Cast(hqlExpression, expression.Type); return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type); } - protected HqlTreeNode VisitNhCount(NhCountExpression expression) - { - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type); - } + protected HqlTreeNode VisitNhCount(NhSimpleExpression expression) + => _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(Visit(expression.Expression).AsExpression()), expression.Type); - protected HqlTreeNode VisitNhMin(NhMinExpression expression) - { - return _hqlTreeBuilder.Min(VisitExpression(expression.Expression).AsExpression()); - } + protected HqlTreeNode VisitNhMin(NhSimpleExpression expression) + =>_hqlTreeBuilder.Min(Visit(expression.Expression).AsExpression()); - protected HqlTreeNode VisitNhMax(NhMaxExpression expression) - { - return _hqlTreeBuilder.Max(VisitExpression(expression.Expression).AsExpression()); - } + protected HqlTreeNode VisitNhMax(NhSimpleExpression expression) + => _hqlTreeBuilder.Max(Visit(expression.Expression).AsExpression()); - protected HqlTreeNode VisitNhSum(NhSumExpression expression) - { - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); - } + protected HqlTreeNode VisitNhSum(NhSimpleExpression expression) + => _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(Visit(expression.Expression).AsExpression()), expression.Type); - protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression) + protected HqlTreeNode VisitNhDistinct(NhSimpleExpression expression) { - var visitor = new HqlGeneratorExpressionTreeVisitor(_parameters); - return _hqlTreeBuilder.ExpressionSubTreeHolder(_hqlTreeBuilder.Distinct(), visitor.VisitExpression(expression.Expression)); + var visitor = new HqlGeneratorExpressionVisitor(_parameters); + return _hqlTreeBuilder.ExpressionSubTreeHolder(_hqlTreeBuilder.Distinct(), visitor.Visit(expression.Expression)); } protected HqlTreeNode VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) - { - return _hqlTreeBuilder.Ident(_parameters.QuerySourceNamer.GetName(expression.ReferencedQuerySource)); - } + => _hqlTreeBuilder.Ident(_parameters.QuerySourceNamer.GetName(expression.ReferencedQuerySource)); private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpression expression) - { // We ignore the case sensitivity flag in the same way that == does. - return VisitExpression(expression.Comparison); - } + => Visit(expression.Comparison); protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { @@ -264,8 +232,8 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return TranslateInequalityComparison(expression); } - var lhs = VisitExpression(expression.Left).AsExpression(); - var rhs = VisitExpression(expression.Right).AsExpression(); + var lhs = Visit(expression.Left).AsExpression(); + var rhs = Visit(expression.Right).AsExpression(); switch (expression.NodeType) { @@ -282,7 +250,7 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return _hqlTreeBuilder.BooleanOr(lhs.ToBooleanExpression(), rhs.ToBooleanExpression()); case ExpressionType.Add: - if (expression.Left.Type == typeof (string) && expression.Right.Type == typeof(string)) + if (expression.Left.Type == typeof(string) && expression.Right.Type == typeof(string)) { return _hqlTreeBuilder.MethodCall("concat", lhs, rhs); } @@ -321,8 +289,8 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) { - var lhs = VisitExpression(expression.Left).ToArithmeticExpression(); - var rhs = VisitExpression(expression.Right).ToArithmeticExpression(); + var lhs = Visit(expression.Left).ToArithmeticExpression(); + var rhs = Visit(expression.Right).ToArithmeticExpression(); // Check for nulls on left or right. if (VisitorUtil.IsNullConstant(expression.Right)) @@ -355,8 +323,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) return inequality; } - var lhs2 = VisitExpression(expression.Left).ToArithmeticExpression(); - var rhs2 = VisitExpression(expression.Right).ToArithmeticExpression(); + var lhs2 = Visit(expression.Left).ToArithmeticExpression(); + var rhs2 = Visit(expression.Right).ToArithmeticExpression(); HqlBooleanExpression booleanExpression; if (lhsNullable && rhsNullable) @@ -379,8 +347,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) { - var lhs = VisitExpression(expression.Left).ToArithmeticExpression(); - var rhs = VisitExpression(expression.Right).ToArithmeticExpression(); + var lhs = Visit(expression.Left).ToArithmeticExpression(); + var rhs = Visit(expression.Right).ToArithmeticExpression(); // Check for nulls on left or right. if (VisitorUtil.IsNullConstant(expression.Right)) @@ -418,8 +386,8 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) return equality; } - var lhs2 = VisitExpression(expression.Left).ToArithmeticExpression(); - var rhs2 = VisitExpression(expression.Right).ToArithmeticExpression(); + var lhs2 = Visit(expression.Left).ToArithmeticExpression(); + var rhs2 = Visit(expression.Right).ToArithmeticExpression(); return _hqlTreeBuilder.BooleanOr( equality, @@ -429,31 +397,28 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) } static bool IsNullable(HqlExpression original) - { - var hqlDot = original as HqlDot; - return hqlDot != null && hqlDot.Children.Last() is HqlIdent; - } + => original is HqlDot hqlDot && hqlDot.Children.Last() is HqlIdent; protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) { switch (expression.NodeType) { case ExpressionType.Negate: - return _hqlTreeBuilder.Negate(VisitExpression(expression.Operand).AsExpression()); + return _hqlTreeBuilder.Negate(Visit(expression.Operand).AsExpression()); case ExpressionType.UnaryPlus: - return VisitExpression(expression.Operand).AsExpression(); + return Visit(expression.Operand).AsExpression(); case ExpressionType.Not: - return _hqlTreeBuilder.BooleanNot(VisitExpression(expression.Operand).ToBooleanExpression()); + return _hqlTreeBuilder.BooleanNot(Visit(expression.Operand).ToBooleanExpression()); case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: if ((expression.Operand.Type.IsPrimitive || expression.Operand.Type == typeof(Decimal)) && (expression.Type.IsPrimitive || expression.Type == typeof(Decimal))) { - return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type); + return _hqlTreeBuilder.Cast(Visit(expression.Operand).AsExpression(), expression.Type); } - return VisitExpression(expression.Operand); + return Visit(expression.Operand); } throw new NotSupportedException(expression.ToString()); @@ -462,37 +427,32 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) protected HqlTreeNode VisitMemberExpression(MemberExpression expression) { // Strip out the .Value property of a nullable type, HQL doesn't need that - if (expression.Member.Name == "Value" && expression.Expression.Type.IsNullable()) + if (expression.Member.Name == nameof(Nullable.Value) && expression.Expression.Type.IsNullable()) { - return VisitExpression(expression.Expression); + return Visit(expression.Expression); } // Look for "special" properties (DateTime.Month etc) - IHqlGeneratorForProperty generator; - - if (_functionRegistry.TryGetGenerator(expression.Member, out generator)) + if (_functionRegistry.TryGetGenerator(expression.Member, out IHqlGeneratorForProperty generator)) { return generator.BuildHql(expression.Member, expression.Expression, _hqlTreeBuilder, this); } // Else just emit standard HQL for a property reference - return _hqlTreeBuilder.Dot(VisitExpression(expression.Expression).AsExpression(), _hqlTreeBuilder.Ident(expression.Member.Name)); + return _hqlTreeBuilder.Dot(Visit(expression.Expression).AsExpression(), _hqlTreeBuilder.Ident(expression.Member.Name)); } protected HqlTreeNode VisitConstantExpression(ConstantExpression expression) { if (expression.Value != null) { - IEntityNameProvider entityName = expression.Value as IEntityNameProvider; - if (entityName != null) + if (expression.Value is IEntityNameProvider entityName) { return _hqlTreeBuilder.Ident(entityName.EntityName); } } - NamedParameter namedParameter; - - if (_parameters.ConstantToParameterMap.TryGetValue(expression, out namedParameter)) + if (_parameters.ConstantToParameterMap.TryGetValue(expression, out NamedParameter namedParameter)) { _parameters.RequiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, false)); @@ -504,10 +464,8 @@ protected HqlTreeNode VisitConstantExpression(ConstantExpression expression) protected HqlTreeNode VisitMethodCallExpression(MethodCallExpression expression) { - IHqlGeneratorForMethod generator; - var method = expression.Method; - if (!_functionRegistry.TryGetGenerator(method, out generator)) + if (!_functionRegistry.TryGetGenerator(method, out IHqlGeneratorForMethod generator)) { throw new NotSupportedException(method.ToString()); } @@ -516,39 +474,35 @@ protected HqlTreeNode VisitMethodCallExpression(MethodCallExpression expression) } protected HqlTreeNode VisitLambdaExpression(LambdaExpression expression) - { - return VisitExpression(expression.Body); - } + => Visit(expression.Body); protected HqlTreeNode VisitParameterExpression(ParameterExpression expression) - { - return _hqlTreeBuilder.Ident(expression.Name); - } + => _hqlTreeBuilder.Ident(expression.Name); protected HqlTreeNode VisitConditionalExpression(ConditionalExpression expression) { - var test = VisitExpression(expression.Test).ToBooleanExpression(); - var ifTrue = VisitExpression(expression.IfTrue).ToArithmeticExpression(); + var test = Visit(expression.Test).ToBooleanExpression(); + var ifTrue = Visit(expression.IfTrue).ToArithmeticExpression(); var ifFalse = (expression.IfFalse != null - ? VisitExpression(expression.IfFalse).ToArithmeticExpression() - : null); + ? Visit(expression.IfFalse).ToArithmeticExpression() + : null); - HqlExpression @case = _hqlTreeBuilder.Case(new[] {_hqlTreeBuilder.When(test, ifTrue)}, ifFalse); + HqlExpression @case = _hqlTreeBuilder.Case(new[] { _hqlTreeBuilder.When(test, ifTrue) }, ifFalse); - return (expression.Type == typeof (bool) || expression.Type == (typeof (bool?))) - ? @case - : _hqlTreeBuilder.Cast(@case, expression.Type); + return (expression.Type == typeof(bool) || expression.Type == (typeof(bool?))) + ? @case + : _hqlTreeBuilder.Cast(@case, expression.Type); } protected HqlTreeNode VisitSubQueryExpression(SubQueryExpression expression) { - ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false, null); + var query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false); return query.Statement; } protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression) { - var expressionSubTree = expression.Expressions.Select(exp => VisitExpression(exp)).ToArray(); + var expressionSubTree = expression.Expressions.Select(exp => Visit(exp)).ToArray(); return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree); } } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index b785d6e18d5..c020c6644fa 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -65,21 +65,21 @@ public bool CanAddJoin(Expression expression) return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase); } - private class QuerySourceExtractor : ExpressionTreeVisitor + private class QuerySourceExtractor : RelinqExpressionVisitor { private IQuerySource _querySource; public static IQuerySource GetQuerySource(Expression expression) { var sourceExtractor = new QuerySourceExtractor(); - sourceExtractor.VisitExpression(expression); + sourceExtractor.Visit(expression); return sourceExtractor._querySource; } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { _querySource = expression.ReferencedQuerySource; - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } } } diff --git a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs index e8cf6299208..b39ad521759 100644 --- a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs +++ b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs @@ -2,7 +2,7 @@ using System.Linq; using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; @@ -41,7 +41,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, var innerSelectorMapping = new QuerySourceMapping(); innerSelectorMapping.AddMapping(fromClause, subQueryModel.SelectClause.Selector); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); queryModel.BodyClauses.RemoveAt(index); queryModel.BodyClauses.Insert(index, join); @@ -50,7 +50,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, var innerBodyClauseMapping = new QuerySourceMapping(); innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(join)); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); } private static void InsertBodyClauses(IEnumerable bodyClauses, QueryModel destinationQueryModel, int destinationIndex) diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 537b02c674d..7f6b8dd1a1f 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -1,5 +1,3 @@ -using System.Collections; -using System.Collections.Generic; using System.Linq.Expressions; using NHibernate.Linq.Expressions; using NHibernate.Linq.ReWriters; @@ -14,7 +12,7 @@ namespace NHibernate.Linq.Visitors /// Replaces them with appropriate joins, maintaining reference equality between different clauses. /// This allows extracted GroupBy key expression to also be replaced so that they can continue to match replaced Select expressions /// - internal class MemberExpressionJoinDetector : ExpressionTreeVisitor + internal class MemberExpressionJoinDetector : RelinqExpressionVisitor { private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; @@ -30,7 +28,7 @@ public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner jo _joiner = joiner; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { var isIdentifier = _isEntityDecider.IsIdentifier(expression.Expression.Type, expression.Member.Name); if (isIdentifier) @@ -38,8 +36,8 @@ protected override Expression VisitMemberExpression(MemberExpression expression) if (!isIdentifier) _memberExpressionDepth++; - var result = base.VisitMemberExpression(expression); - + var result = base.VisitMember(expression); + if (!isIdentifier) _memberExpressionDepth--; @@ -55,33 +53,33 @@ protected override Expression VisitMemberExpression(MemberExpression expression) return result; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(VisitExpression); + expression.QueryModel.TransformExpressions(Visit); return expression; } - protected override Expression VisitConditionalExpression(ConditionalExpression expression) + protected override Expression VisitConditional(ConditionalExpression expression) { var oldRequiresJoinForNonIdentifier = _requiresJoinForNonIdentifier; _requiresJoinForNonIdentifier = !_preventJoinsInConditionalTest && _requiresJoinForNonIdentifier; - var newTest = VisitExpression(expression.Test); + var newTest = Visit(expression.Test); _requiresJoinForNonIdentifier = oldRequiresJoinForNonIdentifier; - var newFalse = VisitExpression(expression.IfFalse); - var newTrue = VisitExpression(expression.IfTrue); + var newFalse = Visit(expression.IfFalse); + var newTrue = Visit(expression.IfTrue); if ((newTest != expression.Test) || (newFalse != expression.IfFalse) || (newTrue != expression.IfTrue)) return Expression.Condition(newTest, newTrue, newFalse); return expression; } - protected override Expression VisitExtensionExpression(ExtensionExpression expression) + protected override Expression VisitExtension(Expression expression) { // Nominated expressions need to prevent joins on non-Identifier member expressions // (for the test expression of conditional expressions only) // Otherwise an extra join is created and the GroupBy and Select clauses will not match var old = _preventJoinsInConditionalTest; - _preventJoinsInConditionalTest = (NhExpressionType)expression.NodeType == NhExpressionType.Nominator; - var expr = base.VisitExtensionExpression(expression); + _preventJoinsInConditionalTest = (expression as NhSimpleExpression)?.NhNodeType == NhExpressionType.Nominator; + var expr = base.VisitExtension(expression); _preventJoinsInConditionalTest = old; return expr; } @@ -90,18 +88,14 @@ public void Transform(SelectClause selectClause) { // The select clause typically requires joins for non-Identifier member access _requiresJoinForNonIdentifier = true; - selectClause.TransformExpressions(VisitExpression); + selectClause.TransformExpressions(Visit); _requiresJoinForNonIdentifier = false; } public void Transform(ResultOperatorBase resultOperator) - { - resultOperator.TransformExpressions(VisitExpression); - } + => resultOperator.TransformExpressions(Visit); public void Transform(Ordering ordering) - { - ordering.TransformExpressions(VisitExpression); - } + => ordering.TransformExpressions(Visit); } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs b/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs deleted file mode 100644 index 3fb10519975..00000000000 --- a/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs +++ /dev/null @@ -1,98 +0,0 @@ -using System; -using System.Linq.Expressions; -using NHibernate.Linq.Expressions; -using Remotion.Linq.Parsing; - -namespace NHibernate.Linq.Visitors -{ - public class NhExpressionTreeVisitor : ExpressionTreeVisitor - { - public override Expression VisitExpression(Expression expression) - { - if (expression == null) - { - return null; - } - - switch ((NhExpressionType)expression.NodeType) - { - case NhExpressionType.Average: - case NhExpressionType.Min: - case NhExpressionType.Max: - case NhExpressionType.Sum: - case NhExpressionType.Count: - case NhExpressionType.Distinct: - return VisitNhAggregate((NhAggregatedExpression)expression); - case NhExpressionType.New: - return VisitNhNew((NhNewExpression)expression); - case NhExpressionType.Star: - return VisitNhStar((NhStarExpression)expression); - } - - // Keep this variable for easy examination during debug. - var expr = base.VisitExpression(expression); - return expr; - } - - protected virtual Expression VisitNhStar(NhStarExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhNew(NhNewExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhAggregate(NhAggregatedExpression expression) - { - switch ((NhExpressionType)expression.NodeType) - { - case NhExpressionType.Average: - return VisitNhAverage((NhAverageExpression)expression); - case NhExpressionType.Min: - return VisitNhMin((NhMinExpression)expression); - case NhExpressionType.Max: - return VisitNhMax((NhMaxExpression)expression); - case NhExpressionType.Sum: - return VisitNhSum((NhSumExpression)expression); - case NhExpressionType.Count: - return VisitNhCount((NhCountExpression)expression); - case NhExpressionType.Distinct: - return VisitNhDistinct((NhDistinctExpression)expression); - default: - throw new ArgumentException(); - } - } - - protected virtual Expression VisitNhDistinct(NhDistinctExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhCount(NhCountExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhSum(NhSumExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhMax(NhMaxExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhMin(NhMinExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhAverage(NhAverageExpression expression) - { - return expression.Accept(this); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs new file mode 100644 index 00000000000..8c20a7c0db8 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs @@ -0,0 +1,70 @@ +using System; +using System.Linq.Expressions; +using NHibernate.Linq.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + public class NhExpressionVisitor : RelinqExpressionVisitor + { + public override Expression Visit(Expression expression) + { + if (expression == null) + { + return null; + } + + // Keep this variable for easy examination during debug. + var expr = base.Visit(expression); + return expr; + } + + public virtual Expression VisitNhStar(NhStarExpression expression) + => VisitExtension(expression); + + public virtual Expression VisitNhNew(NhNewExpression expression) + => VisitExtension(expression); + + internal virtual Expression VisitNhNominated(NhNominatedExpression expression) + => VisitExtension(expression); + + public virtual Expression VisitNhAggregate(NhAggregatedExpression expression) + { + switch (expression.NhNodeType) + { + case NhExpressionType.Average: + return VisitNhAverage(expression); + case NhExpressionType.Min: + return VisitNhMin(expression); + case NhExpressionType.Max: + return VisitNhMax(expression); + case NhExpressionType.Sum: + return VisitNhSum(expression); + case NhExpressionType.Count: + return VisitNhCount(expression); + case NhExpressionType.Distinct: + return VisitNhDistinct(expression); + default: + throw new ArgumentException($"Unsupported NH node type {expression.NhNodeType}.", nameof(expression)); + } + } + + protected virtual Expression VisitNhDistinct(NhSimpleExpression expression) + => VisitExtension(expression); + + protected virtual Expression VisitNhCount(NhSimpleExpression expression) + => VisitExtension(expression); + + protected virtual Expression VisitNhSum(NhSimpleExpression expression) + => VisitExtension(expression); + + protected virtual Expression VisitNhMax(NhSimpleExpression expression) + => VisitExtension(expression); + + protected virtual Expression VisitNhMin(NhSimpleExpression expression) + => VisitExtension(expression); + + protected virtual Expression VisitNhAverage(NhSimpleExpression expression) + => VisitExtension(expression); + } +} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs deleted file mode 100644 index 462051d75b9..00000000000 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; - -namespace NHibernate.Linq.Visitors -{ - internal class NhPartialEvaluatingExpressionTreeVisitor : ExpressionTreeVisitor, IPartialEvaluationExceptionExpressionVisitor - { - protected override Expression VisitConstantExpression(ConstantExpression expression) - { - var value = expression.Value as Expression; - if (value == null) - { - return base.VisitConstantExpression(expression); - } - - return EvaluateIndependentSubtrees(value); - } - - public static Expression EvaluateIndependentSubtrees(Expression expression) - { - var evaluatedExpression = PartialEvaluatingExpressionTreeVisitor.EvaluateIndependentSubtrees(expression); - return new NhPartialEvaluatingExpressionTreeVisitor().VisitExpression(evaluatedExpression); - } - - public Expression VisitPartialEvaluationExceptionExpression(PartialEvaluationExceptionExpression expression) - { - return VisitExpression(expression.Reduce()); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs new file mode 100644 index 00000000000..35865aa73f5 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -0,0 +1,42 @@ +using System; +using System.Linq.Expressions; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; +using Remotion.Linq.Parsing.ExpressionVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation; + +namespace NHibernate.Linq.Visitors +{ + internal class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor + { + protected override Expression VisitConstant(ConstantExpression expression) + { + if (expression.Value as Expression == null) + { + return base.VisitConstant(expression); + } + + return EvaluateIndependentSubtrees(expression.Value as Expression); + } + + public static Expression EvaluateIndependentSubtrees(Expression expression) + { + var evaluatedExpression = PartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression, new NhEvaluatableExpressionFilter()); + return new NhPartialEvaluatingExpressionVisitor().Visit(evaluatedExpression); + } + + internal class NhEvaluatableExpressionFilter : EvaluatableExpressionFilterBase + { + public NhEvaluatableExpressionFilter() + { } + + public override bool IsEvaluatableMethodCall(MethodCallExpression node) + { + if (node == null) + throw new ArgumentNullException(nameof(node)); + + return node.Method.GetCustomAttributes(typeof(DBOnlyAttribute), false).Length == 0; + } + } + } +} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs index 87c0f90359f..fb07051f441 100644 --- a/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs @@ -2,33 +2,33 @@ using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors { - internal class PagingRewriterSelectClauseVisitor : ExpressionTreeVisitor + internal class PagingRewriterSelectClauseVisitor : RelinqExpressionVisitor { - private readonly FromClauseBase querySource; + private readonly FromClauseBase _querySource; public PagingRewriterSelectClauseVisitor(FromClauseBase querySource) { - this.querySource = querySource; + _querySource = querySource; } public Expression Swap(Expression expression) { - return TransparentIdentifierRemovingExpressionTreeVisitor.ReplaceTransparentIdentifiers(VisitExpression(expression)); + return TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers(Visit(expression)); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { var innerSelector = GetSubQuerySelectorOrNull(expression); if (innerSelector != null) { - return VisitExpression(innerSelector); + return Visit(innerSelector); } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } /// @@ -37,7 +37,7 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef /// private Expression GetSubQuerySelectorOrNull(QuerySourceReferenceExpression expression) { - if (expression.ReferencedQuerySource != querySource) + if (expression.ReferencedQuerySource != _querySource) return null; var fromClause = expression.ReferencedQuerySource as FromClauseBase; diff --git a/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs b/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs index d0dcd2a0668..87003308674 100644 --- a/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs +++ b/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs @@ -4,7 +4,7 @@ namespace NHibernate.Linq.Visitors { - public class QueryExpressionSourceIdentifer : ExpressionTreeVisitor + public class QueryExpressionSourceIdentifer : RelinqExpressionVisitor { private readonly QuerySourceIdentifier _identifier; @@ -13,10 +13,10 @@ public QueryExpressionSourceIdentifer(QuerySourceIdentifier identifier) _identifier = identifier; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { _identifier.VisitQueryModel(expression.QueryModel); - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index ee5bcceb8d0..7410f9d0960 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -21,8 +21,7 @@ namespace NHibernate.Linq.Visitors { public class QueryModelVisitor : QueryModelVisitorBase { - public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root, - NhLinqExpressionReturnType? rootReturnType) + public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root) { NestedSelectRewriter.ReWrite(queryModel, parameters); @@ -81,17 +80,14 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer // Identify and name query sources QuerySourceIdentifier.Visit(parameters.QuerySourceNamer, queryModel); - var visitor = new QueryModelVisitor(parameters, root, queryModel, rootReturnType) - { - RewrittenOperatorResult = result, - }; + var visitor = new QueryModelVisitor(parameters, root, queryModel, result); visitor.Visit(); return visitor._hqlTree.GetTranslation(); } private readonly IntermediateHqlTree _hqlTree; - private readonly NhLinqExpressionReturnType? _rootReturnType; + private readonly NhLinqExpressionReturnType _rootReturnType; private static readonly ResultOperatorMap ResultOperatorMap; private bool _serverSide = true; @@ -103,7 +99,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer public QueryModel Model { get; } - public ResultOperatorRewriterResult RewrittenOperatorResult { get; private set; } + public ResultOperatorRewriterResult RewrittenOperatorResult { get; } static QueryModelVisitor() { @@ -132,11 +128,12 @@ static QueryModelVisitor() } private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryModel queryModel, - NhLinqExpressionReturnType? rootReturnType) + ResultOperatorRewriterResult rewrittenOperatorResult) { VisitorParameters = visitorParameters; Model = queryModel; - _rootReturnType = root ? rootReturnType : null; + _rootReturnType = visitorParameters.RootReturnType; + RewrittenOperatorResult = rewrittenOperatorResult; _hqlTree = new IntermediateHqlTree(root); } @@ -148,10 +145,12 @@ private void Visit() private void AddAdditionalPostExecuteTransformer() { - if (_rootReturnType == NhLinqExpressionReturnType.Scalar && Model.ResultTypeOverride != null) + if (_hqlTree.IsRoot && _rootReturnType == NhLinqExpressionReturnType.Scalar && Model.ResultTypeOverride != null && + Model.SelectClause.Selector.NodeType == ExpressionType.Extension && + Model.SelectClause.Selector is NhExpression nhExpression) { // NH-3850: handle polymorphic scalar results aggregation - switch ((NhExpressionType)Model.SelectClause.Selector.NodeType) + switch (nhExpression.NhNodeType) { case NhExpressionType.Average: // Polymorphic case complex to handle and not implemented. (HQL query must be reshaped for adding @@ -299,7 +298,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, if (fromClause.FromExpression is MemberExpression) { // It's a join - var expression = HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(); + var expression = HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(); var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); var alias = _hqlTree.TreeBuilder.Alias(querySourceName); var hqlJoin = VisitorParameters.IsLeftJoin(fromClause) ? @@ -308,7 +307,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, foreach (var withClause in VisitorParameters.GetRestrictions(fromClause)) { - var booleanExpression = HqlGeneratorExpressionTreeVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); + var booleanExpression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); hqlJoin.AddChild(_hqlTree.TreeBuilder.With(booleanExpression)); } @@ -327,7 +326,7 @@ private void AddFromClause(FromClauseBase fromClause) var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); _hqlTree.AddFromClause( _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), + HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters), _hqlTree.TreeBuilder.Alias(querySourceName))); } @@ -357,7 +356,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters); - visitor.Visit(selectClause.Selector); + visitor.VisitRoot(selectClause.Selector); if (visitor.ProjectionExpression != null) { @@ -372,10 +371,10 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { var visitor = new SimplifyConditionalVisitor(); - whereClause.Predicate = visitor.VisitExpression(whereClause.Predicate); + whereClause.Predicate = visitor.Visit(whereClause.Predicate); // Visit the predicate to build the query - var expression = HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); + var expression = HqlGeneratorExpressionVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); if (VisitorParameters.IsHavingClause(whereClause)) { _hqlTree.AddHavingClause(expression); @@ -391,7 +390,7 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel { foreach (var clause in orderByClause.Orderings) { - _hqlTree.AddOrderByClause(HqlGeneratorExpressionTreeVisitor.Visit(clause.Expression, VisitorParameters).AsExpression(), + _hqlTree.AddOrderByClause(HqlGeneratorExpressionVisitor.Visit(clause.Expression, VisitorParameters).AsExpression(), clause.OrderingDirection == OrderingDirection.Asc ? _hqlTree.TreeBuilder.Ascending() : (HqlDirectionStatement)_hqlTree.TreeBuilder.Descending()); @@ -407,7 +406,7 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode _hqlTree.AddFromClause( _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(joinClause.InnerSequence, VisitorParameters), + HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters), _hqlTree.TreeBuilder.Alias(joinClause.ItemName))); } diff --git a/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs b/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs index 72a06822512..3ae7d02c16f 100644 --- a/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs +++ b/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs @@ -54,17 +54,16 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) { - var groupBy = resultOperator as GroupResultOperator; - if (groupBy != null) + if (resultOperator is GroupResultOperator groupBy) _namer.Add(groupBy); } public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { //Find nested query sources - new QueryExpressionSourceIdentifer(this).VisitExpression(selectClause.Selector); + new QueryExpressionSourceIdentifer(this).Visit(selectClause.Selector); } - public QuerySourceNamer Namer { get { return _namer; } } + public QuerySourceNamer Namer => _namer; } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs b/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs index c6a80ddd380..d5b2aa5bb98 100644 --- a/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs +++ b/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs @@ -1,27 +1,26 @@ using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Collections; namespace NHibernate.Linq.Visitors { - public class QuerySourceLocator : QueryModelVisitorBase - { - private readonly System.Type _type; - private IQuerySource _querySource; + public class QuerySourceLocator : QueryModelVisitorBase + { + private readonly System.Type _type; + private IQuerySource _querySource; - private QuerySourceLocator(System.Type type) - { - _type = type; - } + private QuerySourceLocator(System.Type type) + { + _type = type; + } - public static IQuerySource FindQuerySource(QueryModel queryModel, System.Type type) - { - var finder = new QuerySourceLocator(type); + public static IQuerySource FindQuerySource(QueryModel queryModel, System.Type type) + { + var finder = new QuerySourceLocator(type); - finder.VisitQueryModel(queryModel); + finder.VisitQueryModel(queryModel); - return finder._querySource; - } + return finder._querySource; + } public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { @@ -37,16 +36,16 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, base.VisitAdditionalFromClause(fromClause, queryModel, index); } - public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) - { - if (_type.IsAssignableFrom(fromClause.ItemType)) - { - _querySource = fromClause; - } - else - { - base.VisitMainFromClause(fromClause, queryModel); - } - } - } + public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) + { + if (_type.IsAssignableFrom(fromClause.ItemType)) + { + _querySource = fromClause; + } + else + { + base.VisitMainFromClause(fromClause, queryModel); + } + } + } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs index 7a2b5a819e3..574533820b7 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs @@ -3,7 +3,7 @@ using NHibernate.Util; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -15,7 +15,7 @@ public void Process(AggregateResultOperator resultOperator, QueryModelVisitor qu var inputType = inputExpr.Type; var paramExpr = Expression.Parameter(inputType, "item"); var accumulatorFunc = Expression.Lambda( - ReplacingExpressionTreeVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), + ReplacingExpressionVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), resultOperator.Func.Parameters[0], paramExpr); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs index 4e55fe6382c..d369066a575 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs @@ -3,7 +3,7 @@ using NHibernate.Util; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -15,7 +15,7 @@ public void Process(AggregateFromSeedResultOperator resultOperator, QueryModelVi var inputType = inputExpr.Type; var paramExpr = Expression.Parameter(inputType, "item"); var accumulatorFunc = Expression.Lambda( - ReplacingExpressionTreeVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), + ReplacingExpressionVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), resultOperator.Func.Parameters[0], paramExpr); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs index 692365a5d89..d66baf2fb8a 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs @@ -12,7 +12,7 @@ public class ProcessAll : IResultOperatorProcessor public void Process(AllResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { tree.AddWhereClause(tree.TreeBuilder.BooleanNot( - HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.Predicate, queryModelVisitor.VisitorParameters). + HqlGeneratorExpressionVisitor.Visit(resultOperator.Predicate, queryModelVisitor.VisitorParameters). ToBooleanExpression())); if (tree.IsRoot) diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs index 01928d78999..17fc7850425 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs @@ -10,7 +10,7 @@ public class ProcessContains : IResultOperatorProcessor public void Process(ContainsResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { var itemExpression = - HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.Item, queryModelVisitor.VisitorParameters) + HqlGeneratorExpressionVisitor.Visit(resultOperator.Item, queryModelVisitor.VisitorParameters) .AsExpression(); var from = GetFromRangeClause(tree.Root); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs index 0c4e19db231..2ee37b5387b 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs @@ -16,7 +16,7 @@ public void Process(GroupResultOperator resultOperator, QueryModelVisitor queryM else groupByKeys = new[] {resultOperator.KeySelector}; - IEnumerable hqlGroupByKeys = groupByKeys.Select(k => HqlGeneratorExpressionTreeVisitor.Visit(k, queryModelVisitor.VisitorParameters).AsExpression()); + IEnumerable hqlGroupByKeys = groupByKeys.Select(k => HqlGeneratorExpressionVisitor.Visit(k, queryModelVisitor.VisitorParameters).AsExpression()); tree.AddGroupByClause(tree.TreeBuilder.GroupBy(hqlGroupByKeys.ToArray())); } diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs index abaf20e1471..1ce5acaf0a3 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs @@ -1,9 +1,8 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq.Expressions; using NHibernate.Linq.ResultOperators; using NHibernate.Util; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -22,9 +21,9 @@ public void Process(NonAggregatingGroupBy resultOperator, QueryModelVisitor quer // Stuff in the group by that doesn't map to HQL. Run it client-side var listParameter = Expression.Parameter(typeof(IEnumerable), "list"); - var keySelectorExpr = ReverseResolvingExpressionTreeVisitor.ReverseResolve(selector, keySelector); + var keySelectorExpr = ReverseResolvingExpressionVisitor.ReverseResolve(selector, keySelector); - var elementSelectorExpr = ReverseResolvingExpressionTreeVisitor.ReverseResolve(selector, elementSelector); + var elementSelectorExpr = ReverseResolvingExpressionVisitor.ReverseResolve(selector, elementSelector); var groupByMethod = ReflectionCache.EnumerableMethods.GroupByWithElementSelectorDefinition .MakeGenericMethod(new[] { sourceType, keyType, elementType }); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs index 4da0271fee2..f6137239277 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs @@ -8,7 +8,7 @@ public void Process(OfTypeResultOperator resultOperator, QueryModelVisitor query { var source = queryModelVisitor.Model.SelectClause.GetOutputDataInfo().ItemExpression; - var expression = new HqlGeneratorExpressionTreeVisitor(queryModelVisitor.VisitorParameters) + var expression = new HqlGeneratorExpressionVisitor(queryModelVisitor.VisitorParameters) .BuildOfType(source, resultOperator.SearchedItemType); tree.AddWhereClause(expression); diff --git a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs index 4dae4f2aa9e..43607e447a2 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs @@ -1,9 +1,7 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq.Expressions; using NHibernate.Linq.Functions; using NHibernate.Linq.Expressions; -using NHibernate.Util; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -12,7 +10,7 @@ namespace NHibernate.Linq.Visitors /// Analyze the select clause to determine what parts can be translated /// fully to HQL, and some other properties of the clause. /// - class SelectClauseHqlNominator : ExpressionTreeVisitor + class SelectClauseHqlNominator : RelinqExpressionVisitor { private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; @@ -35,28 +33,22 @@ class SelectClauseHqlNominator : ExpressionTreeVisitor public SelectClauseHqlNominator(VisitorParameters parameters) { _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; - } - - internal Expression Visit(Expression expression) - { HqlCandidates = new HashSet(); ContainsUntranslatedMethodCalls = false; _canBeCandidate = true; _stateStack = new Stack(); _stateStack.Push(false); - - return VisitExpression(expression); } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) return null; - if (expression.NodeType == (ExpressionType)NhExpressionType.Nominator) + if (expression is NhNominatedExpression nominatedExpression) { // Add the nominated clause and strip the nominator wrapper from the select expression - var innerExpression = ((NhNominatedExpression)expression).Expression; + var innerExpression = nominatedExpression.Expression; HqlCandidates.Add(innerExpression); return innerExpression; } @@ -86,7 +78,7 @@ public override Expression VisitExpression(Expression expression) return expression; } - expression = base.VisitExpression(expression); + expression = base.Visit(expression); if (_canBeCandidate) { @@ -113,21 +105,24 @@ private bool IsRegisteredFunction(Expression expression) { if (expression.NodeType == ExpressionType.Call) { - var methodCallExpression = (MethodCallExpression) expression; - IHqlGeneratorForMethod methodGenerator; - if (_functionRegistry.TryGetGenerator(methodCallExpression.Method, out methodGenerator)) + var methodCallExpression = (MethodCallExpression)expression; + if (_functionRegistry.TryGetGenerator(methodCallExpression.Method, out IHqlGeneratorForMethod methodGenerator)) { return methodCallExpression.Object == null || // is static or extension method methodCallExpression.Object.NodeType != ExpressionType.Constant; // does not belong to parameter } } - else if (expression.NodeType == (ExpressionType)NhExpressionType.Sum || - expression.NodeType == (ExpressionType)NhExpressionType.Count || - expression.NodeType == (ExpressionType)NhExpressionType.Average || - expression.NodeType == (ExpressionType)NhExpressionType.Max || - expression.NodeType == (ExpressionType)NhExpressionType.Min) + else if (expression is NhExpression nhExpression) { - return true; + switch (nhExpression.NhNodeType) + { + case NhExpressionType.Sum: + case NhExpressionType.Count: + case NhExpressionType.Average: + case NhExpressionType.Max: + case NhExpressionType.Min: + return true; + } } return false; @@ -136,8 +131,8 @@ private bool IsRegisteredFunction(Expression expression) private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool projectConstantsInHql) { // HQL can't do New or Member Init - if (expression.NodeType == ExpressionType.MemberInit || - expression.NodeType == ExpressionType.New || + if (expression.NodeType == ExpressionType.MemberInit || + expression.NodeType == ExpressionType.New || expression.NodeType == ExpressionType.NewArrayInit || expression.NodeType == ExpressionType.NewArrayBounds) { @@ -171,8 +166,6 @@ private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool proj } private static bool CanBeEvaluatedInHqlStatementShortcut(Expression expression) - { - return ((NhExpressionType)expression.NodeType) == NhExpressionType.Count; - } + => expression is NhExpression nhExpression && nhExpression.NhNodeType == NhExpressionType.Count; } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs index 4b2254e53e3..bd3e6817863 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs @@ -8,7 +8,7 @@ namespace NHibernate.Linq.Visitors { - public class SelectClauseVisitor : ExpressionTreeVisitor + public class SelectClauseVisitor : RelinqExpressionVisitor { private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private HashSet _hqlNodes; @@ -16,23 +16,21 @@ public class SelectClauseVisitor : ExpressionTreeVisitor private readonly VisitorParameters _parameters; private int _iColumn; private List _hqlTreeNodes = new List(); - private readonly HqlGeneratorExpressionTreeVisitor _hqlVisitor; + private readonly HqlGeneratorExpressionVisitor _hqlVisitor; public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters) { _inputParameter = Expression.Parameter(inputType, "input"); _parameters = parameters; - _hqlVisitor = new HqlGeneratorExpressionTreeVisitor(_parameters); + _hqlVisitor = new HqlGeneratorExpressionVisitor(_parameters); } public LambdaExpression ProjectionExpression { get; private set; } public IEnumerable GetHqlNodes() - { - return _hqlTreeNodes; - } + => _hqlTreeNodes; - public void Visit(Expression expression) + public void VisitRoot(Expression expression) { var distinct = expression as NhDistinctExpression; if (distinct != null) @@ -53,7 +51,7 @@ public void Visit(Expression expression) throw new NotSupportedException("Cannot use distinct on result that depends on methods for which no SQL equivalent exist."); // Now visit the tree - var projection = VisitExpression(expression); + var projection = Visit(expression); if ((projection != expression) && !_hqlNodes.Contains(expression)) { @@ -65,13 +63,13 @@ public void Visit(Expression expression) if (distinct != null) { - var treeNodes = new List(_hqlTreeNodes.Count + 1) {_hqlTreeBuilder.Distinct()}; + var treeNodes = new List(_hqlTreeNodes.Count + 1) { _hqlTreeBuilder.Distinct() }; treeNodes.AddRange(_hqlTreeNodes); - _hqlTreeNodes = new List(1) {_hqlTreeBuilder.ExpressionSubTreeHolder(treeNodes)}; + _hqlTreeNodes = new List(1) { _hqlTreeBuilder.ExpressionSubTreeHolder(treeNodes) }; } } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) { @@ -85,12 +83,15 @@ public override Expression VisitExpression(Expression expression) return Expression.Convert(Expression.ArrayIndex(_inputParameter, Expression.Constant(_iColumn++)), expression.Type); } - - // Can't handle this node with HQL. Just recurse down, and emit the expression - return base.VisitExpression(expression); + else + { + // Can't handle this node with HQL. Just recurse down, and emit the expression + return base.Visit(expression); + } } } + // To be removed in v6.0 (or 5.0?) [Obsolete] public static class BooleanToCaseConvertor { diff --git a/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs b/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs index 3263fb4d235..ffc497b0bc5 100644 --- a/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs @@ -1,5 +1,4 @@ using System.Linq.Expressions; -using NHibernate.Util; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -7,26 +6,24 @@ namespace NHibernate.Linq.Visitors /// /// Some conditional expressions can be reduced to just their IfTrue or IfFalse part. /// - internal class SimplifyConditionalVisitor :ExpressionTreeVisitor + internal class SimplifyConditionalVisitor : RelinqExpressionVisitor { - protected override Expression VisitConditionalExpression(ConditionalExpression expression) + protected override Expression VisitConditional(ConditionalExpression expression) { - var testExpression = VisitExpression(expression.Test); + var testExpression = Visit(expression.Test); - bool testExprResult; - if (VisitorUtil.IsBooleanConstant(testExpression, out testExprResult)) + if (VisitorUtil.IsBooleanConstant(testExpression, out bool testExprResult)) { if (testExprResult) - return VisitExpression(expression.IfTrue); + return Visit(expression.IfTrue); - return VisitExpression(expression.IfFalse); + return Visit(expression.IfFalse); } - return base.VisitConditionalExpression(expression); + return base.VisitConditional(expression); } - - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { // See NH-3423. Conditional expression where the test expression is a comparison // of a construction expression and null will happen in WCF DS. @@ -42,24 +39,18 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) return Expression.Constant(true); } - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); } - private static bool IsConstruction(Expression expression) - { - return expression is NewExpression || expression is MemberInitExpression; - } - + => expression is NewExpression || expression is MemberInitExpression; private static bool IsConstructionToNullComparison(Expression expression) { - var testExpression = expression as BinaryExpression; - - if (testExpression != null) + if (expression is BinaryExpression testExpression) { if ((IsConstruction(testExpression.Left) && VisitorUtil.IsNullConstant(testExpression.Right)) - || (IsConstruction(testExpression.Right) && VisitorUtil.IsNullConstant(testExpression.Left))) + || (IsConstruction(testExpression.Right) && VisitorUtil.IsNullConstant(testExpression.Left))) { return true; } diff --git a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs index 56eace76c69..06d2ad460f7 100644 --- a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs +++ b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs @@ -3,7 +3,7 @@ using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; using Remotion.Linq.EagerFetching; namespace NHibernate.Linq.Visitors @@ -23,28 +23,26 @@ public static void ReWrite(QueryModel queryModel) public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { - var subQueryExpression = fromClause.FromExpression as SubQueryExpression; - if (subQueryExpression != null) + if (fromClause.FromExpression is SubQueryExpression subQueryExpression) FlattenSubQuery(subQueryExpression, fromClause, queryModel, index + 1); base.VisitAdditionalFromClause(fromClause, queryModel, index); } public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { - var subQueryExpression = fromClause.FromExpression as SubQueryExpression; - if (subQueryExpression != null) + if (fromClause.FromExpression is SubQueryExpression subQueryExpression) FlattenSubQuery(subQueryExpression, fromClause, queryModel, 0); base.VisitMainFromClause(fromClause, queryModel); } private static bool CheckFlattenable(QueryModel subQueryModel) { - if (subQueryModel.BodyClauses.OfType().Any()) + if (subQueryModel.BodyClauses.OfType().Any()) return false; - if (subQueryModel.ResultOperators.Count == 0) + if (subQueryModel.ResultOperators.Count == 0) return true; - + return HasJustAllFlattenableOperator(subQueryModel.ResultOperators); } @@ -70,14 +68,14 @@ private static void FlattenSubQuery(SubQueryExpression subQueryExpression, FromC var innerSelectorMapping = new QuerySourceMapping(); innerSelectorMapping.AddMapping(fromClause, subQueryExpression.QueryModel.SelectClause.Selector); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); InsertBodyClauses(subQueryExpression.QueryModel.BodyClauses, queryModel, destinationIndex); InsertResultOperators(subQueryExpression.QueryModel.ResultOperators, queryModel); var innerBodyClauseMapping = new QuerySourceMapping(); innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(fromClause)); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); } internal static void InsertResultOperators(IEnumerable resultOperators, QueryModel queryModel) @@ -92,7 +90,7 @@ internal static void InsertResultOperators(IEnumerable resul private static void InsertBodyClauses(IEnumerable bodyClauses, QueryModel queryModel, int destinationIndex) { - foreach (var bodyClause in bodyClauses) + foreach (var bodyClause in bodyClauses) { queryModel.BodyClauses.Insert(destinationIndex, bodyClause); ++destinationIndex; diff --git a/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs b/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs index 70181397bc0..5dcaa069ec0 100644 --- a/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs @@ -5,7 +5,7 @@ namespace NHibernate.Linq.Visitors { - public class SwapQuerySourceVisitor : ExpressionTreeVisitor + public class SwapQuerySourceVisitor : RelinqExpressionVisitor { private readonly IQuerySource _oldClause; private readonly IQuerySource _newClause; @@ -17,11 +17,9 @@ public SwapQuerySourceVisitor(IQuerySource oldClause, IQuerySource newClause) } public Expression Swap(Expression expression) - { - return VisitExpression(expression); - } + => Visit(expression); - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { if (expression.ReferencedQuerySource == _oldClause) { @@ -29,20 +27,18 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef } // TODO - really don't like this drill down approach. Feels fragile - var mainFromClause = expression.ReferencedQuerySource as MainFromClause; - - if (mainFromClause != null) + if (expression.ReferencedQuerySource is MainFromClause mainFromClause) { - mainFromClause.FromExpression = VisitExpression(mainFromClause.FromExpression); + mainFromClause.FromExpression = Visit(mainFromClause.FromExpression); } return expression; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(VisitExpression); - return base.VisitSubQueryExpression(expression); + expression.QueryModel.TransformExpressions(Visit); + return base.VisitSubQuery(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/VisitorParameters.cs b/src/NHibernate/Linq/Visitors/VisitorParameters.cs index 5e676a21c22..a79ffc28369 100644 --- a/src/NHibernate/Linq/Visitors/VisitorParameters.cs +++ b/src/NHibernate/Linq/Visitors/VisitorParameters.cs @@ -17,6 +17,8 @@ public class VisitorParameters public QuerySourceNamer QuerySourceNamer { get; } + public NhLinqExpressionReturnType RootReturnType { get; } + private readonly HashSet _havingClauses = new HashSet(); private readonly HashSet _leftJoins = new HashSet(); private readonly HashSet _withClauses = new HashSet(); @@ -26,12 +28,14 @@ public VisitorParameters( ISessionFactoryImplementor sessionFactory, IDictionary constantToParameterMap, List requiredHqlParameters, - QuerySourceNamer querySourceNamer) + QuerySourceNamer querySourceNamer, + NhLinqExpressionReturnType rootReturnType) { SessionFactory = sessionFactory; ConstantToParameterMap = constantToParameterMap; RequiredHqlParameters = requiredHqlParameters; QuerySourceNamer = querySourceNamer; + RootReturnType = rootReturnType; } /// @@ -40,9 +44,7 @@ public VisitorParameters( /// The clause to test. /// true if the clause needs to be converted to a HQL having clause, false otherwise. public bool IsHavingClause(WhereClause clause) - { - return _havingClauses.Contains(clause); - } + => _havingClauses.Contains(clause); /// /// Indicates if a Linq where clause needs to be converted to a HQL with clause. @@ -50,9 +52,7 @@ public bool IsHavingClause(WhereClause clause) /// The clause to test. /// true if the clause needs to be converted to a HQL with clause, false otherwise. public bool IsWithClause(WhereClause clause) - { - return _withClauses.Contains(clause); - } + => _withClauses.Contains(clause); /// /// Indicates if a Linq join clause needs to be converted to a HQL left join. @@ -60,9 +60,7 @@ public bool IsWithClause(WhereClause clause) /// The join to test. /// true if the clause needs to be converted to a HQL left join, false otherwise. public bool IsLeftJoin(AdditionalFromClause join) - { - return _leftJoins.Contains(join); - } + => _leftJoins.Contains(join); /// /// Get the clauses to apply to the join as HQL with clauses. diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 5b559b2eb32..1a8c8e7baa7 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -10,7 +10,8 @@ namespace NHibernate.Linq.Visitors { public static class VisitorUtil { - public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Expression targetObject, IEnumerable arguments, ISessionFactory sessionFactory, out string memberName) + public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Expression targetObject, + IEnumerable arguments, ISessionFactory sessionFactory, out string memberName) { memberName = null; @@ -37,23 +38,22 @@ public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Express //Walk backwards if the owning member is not a mapped class (i.e a possible Component) targetObject = member.Expression; while (metaData == null && targetObject != null && - (targetObject.NodeType == ExpressionType.MemberAccess || targetObject.NodeType == ExpressionType.Parameter || - targetObject.NodeType == QuerySourceReferenceExpression.ExpressionType)) + (targetObject.NodeType == ExpressionType.MemberAccess || targetObject.NodeType == ExpressionType.Parameter || + targetObject is QuerySourceReferenceExpression)) { System.Type memberType; - if (targetObject.NodeType == QuerySourceReferenceExpression.ExpressionType) + if (targetObject is QuerySourceReferenceExpression querySourceExpression) { - var querySourceExpression = (QuerySourceReferenceExpression) targetObject; memberType = querySourceExpression.Type; } else if (targetObject.NodeType == ExpressionType.Parameter) { - var parameterExpression = (ParameterExpression) targetObject; + var parameterExpression = (ParameterExpression)targetObject; memberType = parameterExpression.Type; } else //targetObject.NodeType == ExpressionType.MemberAccess { - var memberExpression = ((MemberExpression) targetObject); + var memberExpression = ((MemberExpression)targetObject); memberPath = memberExpression.Member.Name + "." + memberPath; memberType = memberExpression.Type; targetObject = memberExpression.Expression; @@ -70,16 +70,10 @@ public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Express } public static bool IsDynamicComponentDictionaryGetter(MethodCallExpression expression, ISessionFactory sessionFactory, out string memberName) - { - return IsDynamicComponentDictionaryGetter(expression.Method, expression.Object, expression.Arguments, sessionFactory, out memberName); - } + => IsDynamicComponentDictionaryGetter(expression.Method, expression.Object, expression.Arguments, sessionFactory, out memberName); public static bool IsDynamicComponentDictionaryGetter(MethodCallExpression expression, ISessionFactory sessionFactory) - { - string memberName; - return IsDynamicComponentDictionaryGetter(expression, sessionFactory, out memberName); - } - + => IsDynamicComponentDictionaryGetter(expression, sessionFactory, out string memberName); public static bool IsNullConstant(Expression expression) { @@ -90,13 +84,11 @@ public static bool IsNullConstant(Expression expression) constantExpression.Value == null; } - public static bool IsBooleanConstant(Expression expression, out bool value) { - var constantExpr = expression as ConstantExpression; - if (constantExpr != null && constantExpr.Type == typeof (bool)) + if (expression is ConstantExpression constantExpr && constantExpr.Type == typeof(bool)) { - value = (bool) constantExpr.Value; + value = (bool)constantExpr.Value; return true; } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index dcb5e57c617..a703eed8651 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -34,7 +34,7 @@ namespace NHibernate.Linq.Visitors /// a.B.C == 1 && a.D.E == 1 can be inner joined. /// a.B.C == 1 || a.D.E == 1 must be outer joined. /// - /// By default we outer join via the code in VisitExpression. The use of inner joins is only + /// By default we outer join via the code in Visit. The use of inner joins is only /// an optimization hint to the database. /// /// More examples: @@ -56,14 +56,14 @@ namespace NHibernate.Linq.Visitors /// /// The code here is based on the excellent work started by Harald Mueller. /// - internal class WhereJoinDetector : ExpressionTreeVisitor + internal class WhereJoinDetector : RelinqExpressionVisitor { // TODO: There are a number of types of expressions that we didn't handle here due to time constraints. For example, the ?: operator could be checked easily. private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; private readonly Stack _handled = new Stack(); - + // Stack of result values of each expression. After an expression has processed itself, it adds itself to the stack. private readonly Stack _values = new Stack(); @@ -78,7 +78,7 @@ internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) public void Transform(WhereClause whereClause) { - whereClause.TransformExpressions(VisitExpression); + whereClause.TransformExpressions(Visit); var values = _values.Pop(); @@ -92,7 +92,7 @@ public void Transform(WhereClause whereClause) } } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) return null; @@ -104,13 +104,13 @@ public override Expression VisitExpression(Expression expression) _handled.Push(false); int originalCount = _values.Count; - Expression result = base.VisitExpression(expression); + Expression result = base.Visit(expression); if (!_handled.Pop()) { // If this expression was not handled in a known way, we throw away any values that might // have been returned and we return "all values" for this expression, since we don't know - // what the expresson might result in. + // what the expression might result in. while (_values.Count > originalCount) _values.Pop(); _values.Push(new ExpressionValues(PossibleValueSet.CreateAllValues(expression.Type))); @@ -119,174 +119,164 @@ public override Expression VisitExpression(Expression expression) return result; } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { - var result = base.VisitBinaryExpression(expression); - - if (expression.NodeType == ExpressionType.AndAlso) - { - HandleBinaryOperation((a, b) => a.AndAlso(b)); - } - else if (expression.NodeType == ExpressionType.OrElse) - { - HandleBinaryOperation((a, b) => a.OrElse(b)); - } - else if (expression.NodeType == ExpressionType.NotEqual && VisitorUtil.IsNullConstant(expression.Right)) - { - // Discard result from right null. Left is visited first, so it's below right on the stack. - _values.Pop(); - - HandleUnaryOperation(pvs => pvs.IsNotNull()); - } - else if (expression.NodeType == ExpressionType.NotEqual && VisitorUtil.IsNullConstant(expression.Left)) - { - // Discard result from left null. - var right = _values.Pop(); - _values.Pop(); // Discard left. - _values.Push(right); - - HandleUnaryOperation(pvs => pvs.IsNotNull()); - } - else if (expression.NodeType == ExpressionType.Equal && VisitorUtil.IsNullConstant(expression.Right)) - { - // Discard result from right null. Left is visited first, so it's below right on the stack. - _values.Pop(); - - HandleUnaryOperation(pvs => pvs.IsNull()); - } - else if (expression.NodeType == ExpressionType.Equal && VisitorUtil.IsNullConstant(expression.Left)) - { - // Discard result from left null. - var right = _values.Pop(); - _values.Pop(); // Discard left. - _values.Push(right); - - HandleUnaryOperation(pvs => pvs.IsNull()); - } - else if (expression.NodeType == ExpressionType.Coalesce) - { - HandleBinaryOperation((a, b) => a.Coalesce(b)); - } - else if (expression.NodeType == ExpressionType.Add || expression.NodeType == ExpressionType.AddChecked) - { - HandleBinaryOperation((a, b) => a.Add(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Divide) - { - HandleBinaryOperation((a, b) => a.Divide(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Modulo) - { - HandleBinaryOperation((a, b) => a.Modulo(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Multiply || expression.NodeType == ExpressionType.MultiplyChecked) - { - HandleBinaryOperation((a, b) => a.Multiply(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Power) - { - HandleBinaryOperation((a, b) => a.Power(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Subtract || expression.NodeType == ExpressionType.SubtractChecked) - { - HandleBinaryOperation((a, b) => a.Subtract(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.And) - { - HandleBinaryOperation((a, b) => a.And(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Or) - { - HandleBinaryOperation((a, b) => a.Or(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.ExclusiveOr) - { - HandleBinaryOperation((a, b) => a.ExclusiveOr(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.LeftShift) - { - HandleBinaryOperation((a, b) => a.LeftShift(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.RightShift) - { - HandleBinaryOperation((a, b) => a.RightShift(b, expression.Type)); - } - else if (expression.NodeType == ExpressionType.Equal) - { - HandleBinaryOperation((a, b) => a.Equal(b)); - } - else if (expression.NodeType == ExpressionType.NotEqual) - { - HandleBinaryOperation((a, b) => a.NotEqual(b)); - } - else if (expression.NodeType == ExpressionType.GreaterThanOrEqual) - { - HandleBinaryOperation((a, b) => a.GreaterThanOrEqual(b)); - } - else if (expression.NodeType == ExpressionType.GreaterThan) - { - HandleBinaryOperation((a, b) => a.GreaterThan(b)); - } - else if (expression.NodeType == ExpressionType.LessThan) - { - HandleBinaryOperation((a, b) => a.LessThan(b)); - } - else if (expression.NodeType == ExpressionType.LessThanOrEqual) - { - HandleBinaryOperation((a, b) => a.LessThanOrEqual(b)); + var result = base.VisitBinary(expression); + + switch (expression.NodeType) + { + case ExpressionType.AndAlso: + HandleBinaryOperation((a, b) => a.AndAlso(b)); + break; + case ExpressionType.OrElse: + HandleBinaryOperation((a, b) => a.OrElse(b)); + break; + + case ExpressionType.NotEqual: + if (VisitorUtil.IsNullConstant(expression.Right)) + { + // Discard result from right null. Left is visited first, so it's below right on the stack. + _values.Pop(); + HandleUnaryOperation(pvs => pvs.IsNotNull()); + } + else if (VisitorUtil.IsNullConstant(expression.Left)) + { + // Discard result from left null. + var right = _values.Pop(); + _values.Pop(); // Discard left. + _values.Push(right); + HandleUnaryOperation(pvs => pvs.IsNotNull()); + } + else + { + HandleBinaryOperation((a, b) => a.NotEqual(b)); + } + break; + + case ExpressionType.Equal: + if (VisitorUtil.IsNullConstant(expression.Right)) + { + // Discard result from right null. Left is visited first, so it's below right on the stack. + _values.Pop(); + HandleUnaryOperation(pvs => pvs.IsNull()); + } + else if (VisitorUtil.IsNullConstant(expression.Left)) + { + // Discard result from left null. + var right = _values.Pop(); + _values.Pop(); // Discard left. + _values.Push(right); + HandleUnaryOperation(pvs => pvs.IsNull()); + } + else + { + HandleBinaryOperation((a, b) => a.Equal(b)); + } + break; + + case ExpressionType.Coalesce: + HandleBinaryOperation((a, b) => a.Coalesce(b)); + break; + case ExpressionType.Add: + case ExpressionType.AddChecked: + HandleBinaryOperation((a, b) => a.Add(b, expression.Type)); + break; + case ExpressionType.Divide: + HandleBinaryOperation((a, b) => a.Divide(b, expression.Type)); + break; + case ExpressionType.Modulo: + HandleBinaryOperation((a, b) => a.Modulo(b, expression.Type)); + break; + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + HandleBinaryOperation((a, b) => a.Multiply(b, expression.Type)); + break; + case ExpressionType.Power: + HandleBinaryOperation((a, b) => a.Power(b, expression.Type)); + break; + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + HandleBinaryOperation((a, b) => a.Subtract(b, expression.Type)); + break; + case ExpressionType.And: + HandleBinaryOperation((a, b) => a.And(b, expression.Type)); + break; + case ExpressionType.Or: + HandleBinaryOperation((a, b) => a.Or(b, expression.Type)); + break; + case ExpressionType.ExclusiveOr: + HandleBinaryOperation((a, b) => a.ExclusiveOr(b, expression.Type)); + break; + case ExpressionType.LeftShift: + HandleBinaryOperation((a, b) => a.LeftShift(b, expression.Type)); + break; + case ExpressionType.RightShift: + HandleBinaryOperation((a, b) => a.RightShift(b, expression.Type)); + break; + case ExpressionType.GreaterThanOrEqual: + HandleBinaryOperation((a, b) => a.GreaterThanOrEqual(b)); + break; + case ExpressionType.GreaterThan: + HandleBinaryOperation((a, b) => a.GreaterThan(b)); + break; + case ExpressionType.LessThan: + HandleBinaryOperation((a, b) => a.LessThan(b)); + break; + case ExpressionType.LessThanOrEqual: + HandleBinaryOperation((a, b) => a.LessThanOrEqual(b)); + break; } return result; } - protected override Expression VisitUnaryExpression(UnaryExpression expression) + protected override Expression VisitUnary(UnaryExpression expression) { - Expression result = base.VisitUnaryExpression(expression); - - if (expression.NodeType == ExpressionType.Not && expression.Type == typeof(bool)) - { - HandleUnaryOperation(pvs => pvs.Not()); - } - else if (expression.NodeType == ExpressionType.Not) - { - HandleUnaryOperation(pvs => pvs.BitwiseNot(expression.Type)); - } - else if (expression.NodeType == ExpressionType.ArrayLength) - { - HandleUnaryOperation(pvs => pvs.ArrayLength(expression.Type)); - } - else if (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked) - { - HandleUnaryOperation(pvs => pvs.Convert(expression.Type)); + var result = base.VisitUnary(expression); + + switch (expression.NodeType) + { + case ExpressionType.Not: + if (expression.Type == typeof(bool)) + HandleUnaryOperation(pvs => pvs.Not()); + else + HandleUnaryOperation(pvs => pvs.BitwiseNot(expression.Type)); + break; + case ExpressionType.ArrayLength: + HandleUnaryOperation(pvs => pvs.ArrayLength(expression.Type)); + break; + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + HandleUnaryOperation(pvs => pvs.Convert(expression.Type)); + break; + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + HandleUnaryOperation(pvs => pvs.Negate(expression.Type)); + break; + case ExpressionType.UnaryPlus: + HandleUnaryOperation(pvs => pvs.UnaryPlus(expression.Type)); + break; } - else if (expression.NodeType == ExpressionType.Negate || expression.NodeType == ExpressionType.NegateChecked) - { - HandleUnaryOperation(pvs => pvs.Negate(expression.Type)); - } - else if (expression.NodeType == ExpressionType.UnaryPlus) - { - HandleUnaryOperation(pvs => pvs.UnaryPlus(expression.Type)); - } - + return result; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(VisitExpression); + expression.QueryModel.TransformExpressions(Visit); return expression; } - // We would usually get NULL if one of our inner member expresions was null. + // We would usually get NULL if one of our inner member expressions was null. // However, it's possible a method call will convert the null value from the failed join into a non-null value. // This could be optimized by actually checking what the method does. For example StartsWith("s") would leave null as null and would still allow us to inner join. - //protected override Expression VisitMethodCallExpression(MethodCallExpression expression) + //protected override Expression VisitMethodCall(MethodCallExpression expression) //{ - // Expression result = base.VisitMethodCallExpression(expression); + // Expression result = base.VisitMethodCall(expression); // return result; //} - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { // The member expression we're visiting might be on the end of a variety of things, such as: // a.B @@ -300,7 +290,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) if (!isIdentifier) _memberExpressionDepth++; - var result = base.VisitMemberExpression(expression); + var result = base.VisitMember(expression); if (!isIdentifier) _memberExpressionDepth--; @@ -320,7 +310,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) values.MemberExpressionValuesIfEmptyOuterJoined[key] = PossibleValueSet.CreateNull(expression.Type); } SetResultValues(values); - + return result; } @@ -356,7 +346,7 @@ public ExpressionValues(PossibleValueSet valuesIfUnknownMemberExpression) /// For example, if we have an expression "3" and we request the state for "a.B.C", we'll /// use "3" from Values since it won't exist in MemberExpressionValuesIfEmptyOuterJoined. /// - private PossibleValueSet Values { get; set; } + private PossibleValueSet Values { get; } /// /// Stores the possible values of an expression that would result if the given member expression @@ -365,20 +355,16 @@ public ExpressionValues(PossibleValueSet valuesIfUnknownMemberExpression) /// member expression, it may not appear in this list. In that case, the emptily outer joined /// value set for that member expression will be whatever's in Values instead. /// - public Dictionary MemberExpressionValuesIfEmptyOuterJoined { get; private set; } + public Dictionary MemberExpressionValuesIfEmptyOuterJoined { get; } public PossibleValueSet GetValues(string memberExpression) { - PossibleValueSet value; - if (MemberExpressionValuesIfEmptyOuterJoined.TryGetValue(memberExpression, out value)) + if (MemberExpressionValuesIfEmptyOuterJoined.TryGetValue(memberExpression, out PossibleValueSet value)) return value; return Values; } - public IEnumerable MemberExpressions - { - get { return MemberExpressionValuesIfEmptyOuterJoined.Keys; } - } + public IEnumerable MemberExpressions => MemberExpressionValuesIfEmptyOuterJoined.Keys; public ExpressionValues Operation(ExpressionValues mergeWith, Func operation) { diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 3d80f7fcab3..8047b027c15 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -79,10 +79,14 @@ ..\packages\Iesi.Collections.4.0.1.4000\lib\net40\Iesi.Collections.dll True - - ..\packages\Remotion.Linq.1.15.15.0\lib\portable-net45+wp80+wpa81+win\Remotion.Linq.dll + + ..\packages\Remotion.Linq.2.1.1\lib\net45\Remotion.Linq.dll + + + ..\packages\Remotion.Linq.EagerFetching.2.0.1\lib\net45\Remotion.Linq.EagerFetching.dll + @@ -144,6 +148,14 @@ + + + + + + + + @@ -321,7 +333,6 @@ - @@ -976,11 +987,9 @@ - - @@ -1040,11 +1049,9 @@ - - @@ -1796,6 +1803,7 @@ + diff --git a/src/NHibernate/app.config b/src/NHibernate/app.config new file mode 100644 index 00000000000..246cc8bf759 --- /dev/null +++ b/src/NHibernate/app.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/src/NHibernate/packages.config b/src/NHibernate/packages.config index b3056b09770..880c0387d89 100644 --- a/src/NHibernate/packages.config +++ b/src/NHibernate/packages.config @@ -3,5 +3,17 @@ - + + + + + + + + + + + + + \ No newline at end of file From f7b7601041aae6f35821b513ee1f49786e82eb62 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Mon, 20 Mar 2017 17:07:13 +1300 Subject: [PATCH 3/3] NH-3944 - Fix nuspec template for Relinq v2 --- src/NHibernate/NHibernate.nuspec.template | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/NHibernate/NHibernate.nuspec.template b/src/NHibernate/NHibernate.nuspec.template index 7ed3e54f5d2..1cf635ce4f3 100644 --- a/src/NHibernate/NHibernate.nuspec.template +++ b/src/NHibernate/NHibernate.nuspec.template @@ -15,7 +15,8 @@ - + + http://nhibernate.info @@ -34,4 +35,4 @@ - \ No newline at end of file +