diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs index 818d7593722..7b4e4cf3c50 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs @@ -135,5 +135,23 @@ public void SingleKeyGroupAndCountWithHavingClause() var hornRow = orderCounts.Single(row => row.CompanyName == "Around the Horn"); Assert.That(hornRow.OrderCount, Is.EqualTo(13)); } + + [Test, Explicit("Demonstrate an unsupported case for PagingRewriter")] + public void SingleKeyGroupAndCountWithHavingClausePagingAndOuterWhere() + { + var orderCounts = db.Orders + .GroupBy(o => o.Customer.CompanyName) + .Where(g => g.Count() > 10) + .Select(g => new { CompanyName = g.Key, OrderCount = g.Count() }) + .OrderBy(oc => oc.CompanyName) + .Skip(5) + .Take(10) + .Where(oc => oc.CompanyName.Contains("F")) + .ToList(); + + Assert.That(orderCounts, Has.Count.EqualTo(3)); + var frankRow = orderCounts.Single(row => row.CompanyName == "Frankenversand"); + Assert.That(frankRow.OrderCount, Is.EqualTo(15)); + } } } diff --git a/src/NHibernate/Linq/Clauses/NhHavingClause.cs b/src/NHibernate/Linq/Clauses/NhHavingClause.cs deleted file mode 100644 index 131043ecc92..00000000000 --- a/src/NHibernate/Linq/Clauses/NhHavingClause.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; -using System.Linq.Expressions; - -namespace NHibernate.Linq.Clauses -{ - public class NhHavingClause : WhereClause - { - public NhHavingClause(Expression predicate) - : base(predicate) - { - } - - public override string ToString() - { - return "having " + FormattingExpressionTreeVisitor.Format(Predicate); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs deleted file mode 100644 index 944a8248d4d..00000000000 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ /dev/null @@ -1,62 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Linq.Expressions; -using NHibernate.Linq.Visitors; -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.Expressions; - -namespace NHibernate.Linq.Clauses -{ - /// - /// All joins are created as outer joins. An optimization in finds - /// joins that may be inner joined and calls on them. - /// 's will - /// then emit the correct HQL join. - /// - public class NhJoinClause : AdditionalFromClause - { - public NhJoinClause(string itemName, System.Type itemType, Expression fromExpression) - : this(itemName, itemType, fromExpression, new NhWithClause[0]) - { - } - - public NhJoinClause(string itemName, System.Type itemType, Expression fromExpression, IEnumerable restrictions) - : base(itemName, itemType, fromExpression) - { - Restrictions = new ObservableCollection(); - foreach (var withClause in restrictions) - Restrictions.Add(withClause); - IsInner = false; - } - - public ObservableCollection Restrictions { get; private set; } - - public bool IsInner { get; private set; } - - public override AdditionalFromClause Clone(CloneContext cloneContext) - { - var joinClause = new NhJoinClause(ItemName, ItemType, FromExpression); - foreach (var withClause in Restrictions) - { - var withClause2 = new NhWithClause(withClause.Predicate); - joinClause.Restrictions.Add(withClause2); - } - - cloneContext.QuerySourceMapping.AddMapping(this, new QuerySourceReferenceExpression(joinClause)); - return base.Clone(cloneContext); - } - - public void MakeInner() - { - IsInner = true; - } - - public override void TransformExpressions(Func transformation) - { - foreach (var withClause in Restrictions) - withClause.TransformExpressions(transformation); - base.TransformExpressions(transformation); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Clauses/NhWithClause.cs b/src/NHibernate/Linq/Clauses/NhWithClause.cs deleted file mode 100644 index ae21bc0f4a6..00000000000 --- a/src/NHibernate/Linq/Clauses/NhWithClause.cs +++ /dev/null @@ -1,19 +0,0 @@ -using System.Linq.Expressions; -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; - -namespace NHibernate.Linq.Clauses -{ - public class NhWithClause : WhereClause - { - public NhWithClause(Expression predicate) - : base(predicate) - { - } - - public override string ToString() - { - return "with " + FormattingExpressionTreeVisitor.Format(Predicate); - } - } -} diff --git a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs index 8150d0ae5df..316ba170add 100644 --- a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using NHibernate.Linq.Clauses; using NHibernate.Linq.ReWriters; using NHibernate.Linq.Visitors; using Remotion.Linq; @@ -43,11 +42,9 @@ public static class AggregatingGroupByRewriter typeof (CacheableResultOperator) }; - public static void ReWrite(QueryModel queryModel) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if (subQueryExpression != null) + if (queryModel.MainFromClause.FromExpression is SubQueryExpression subQueryExpression) { var operators = subQueryExpression.QueryModel.ResultOperators .Where(x => !QueryReferenceExpressionFlattener.FlattenableResultOperators.Contains(x.GetType())) @@ -55,17 +52,16 @@ public static void ReWrite(QueryModel queryModel) if (operators.Length == 1) { - var groupBy = operators[0] as GroupResultOperator; - if (groupBy != null) + if (operators[0] is GroupResultOperator groupBy) { - FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy); + FlattenSubQuery(queryModel, subQueryExpression.QueryModel, groupBy, parameters); RemoveCostantGroupByKeys(queryModel, groupBy); } } } } - private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryModel, GroupResultOperator groupBy) + private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryModel, GroupResultOperator groupBy, VisitorParameters parameters) { foreach (var resultOperator in queryModel.ResultOperators.Where(resultOperator => !AcceptableOuterResultOperators.Contains(resultOperator.GetType()))) { @@ -81,11 +77,9 @@ private static void FlattenSubQuery(QueryModel queryModel, QueryModel subQueryMo clause.TransformExpressions(s => GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryModel)); //all outer where clauses actually are having clauses - var whereClause = clause as WhereClause; - if (whereClause != null) + if (clause is WhereClause whereClause) { - queryModel.BodyClauses.RemoveAt(i); - queryModel.BodyClauses.Insert(i, new NhHavingClause(whereClause.Predicate)); + parameters.AddHavingClause(whereClause); } } diff --git a/src/NHibernate/Linq/GroupBy/PagingRewriter.cs b/src/NHibernate/Linq/GroupBy/PagingRewriter.cs index 922f1790a10..87d89e4f2ee 100644 --- a/src/NHibernate/Linq/GroupBy/PagingRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/PagingRewriter.cs @@ -9,17 +9,16 @@ namespace NHibernate.Linq.GroupBy { internal static class PagingRewriter { - private static readonly System.Type[] PagingResultOperators = new[] - { - typeof (SkipResultOperator), - typeof (TakeResultOperator), - }; + private static readonly System.Type[] PagingResultOperators = + new[] + { + typeof (SkipResultOperator), + typeof (TakeResultOperator), + }; public static void ReWrite(QueryModel queryModel) { - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if (subQueryExpression != null && + if (queryModel.MainFromClause.FromExpression is SubQueryExpression subQueryExpression && subQueryExpression.QueryModel.ResultOperators.All(x => PagingResultOperators.Contains(x.GetType()))) { FlattenSubQuery(subQueryExpression, queryModel); @@ -28,7 +27,7 @@ public static void ReWrite(QueryModel queryModel) private static void FlattenSubQuery(SubQueryExpression subQueryExpression, QueryModel queryModel) { - // we can not flattern subquery if outer query has body clauses. + // we can not flatten subquery if outer query has body clauses. var subQueryModel = subQueryExpression.QueryModel; var subQueryMainFromClause = subQueryModel.MainFromClause; if (queryModel.BodyClauses.Count == 0) @@ -46,9 +45,13 @@ private static void FlattenSubQuery(SubQueryExpression subQueryExpression, Query { var cro = new ContainsResultOperator(new QuerySourceReferenceExpression(subQueryMainFromClause)); + // Cloning may cause having/join/with clauses listed in VisitorParameters to no more be matched. + // Not a problem for now, because those clauses imply a projection, which is not supported + // by the "new WhereClause(new SubQueryExpression(newSubQueryModel))" below. See + // SingleKeyGroupAndCountWithHavingClausePagingAndOuterWhere test by example. var newSubQueryModel = subQueryModel.Clone(); newSubQueryModel.ResultOperators.Add(cro); - newSubQueryModel.ResultTypeOverride = typeof (bool); + newSubQueryModel.ResultTypeOverride = typeof(bool); var where = new WhereClause(new SubQueryExpression(newSubQueryModel)); queryModel.BodyClauses.Add(where); diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs index 37ab62de54b..61a057d97e6 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs @@ -3,9 +3,9 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; -using NHibernate.Linq.Clauses; using NHibernate.Linq.GroupBy; using NHibernate.Linq.Visitors; +using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.Clauses; @@ -26,20 +26,20 @@ static class NestedSelectRewriter private static readonly PropertyInfo IGroupingKeyProperty = (PropertyInfo) ReflectHelper.GetProperty, Tuple>(g => g.Key); - public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - var nsqmv = new NestedSelectDetector(sessionFactory); + var nsqmv = new NestedSelectDetector(parameters.SessionFactory); nsqmv.VisitExpression(queryModel.SelectClause.Selector); if (!nsqmv.HasSubqueries) return; var elementExpression = new List(); - var group = Expression.Parameter(typeof (IGrouping), "g"); - + var group = Expression.Parameter(typeof(IGrouping), "g"); + var replacements = new Dictionary(); foreach (var expression in nsqmv.Expressions) { - var processed = ProcessExpression(queryModel, sessionFactory, expression, elementExpression, group); + var processed = ProcessExpression(queryModel, expression, elementExpression, group, parameters); if (processed != null) replacements.Add(expression, processed); } @@ -48,7 +48,7 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory var expressions = new List(); - var identifier = GetIdentifier(sessionFactory, new QuerySourceReferenceExpression(queryModel.MainFromClause)); + var identifier = GetIdentifier(parameters.SessionFactory, new QuerySourceReferenceExpression(queryModel.MainFromClause)); var rewriter = new SelectClauseRewriter(key, expressions, identifier, replacements); @@ -59,98 +59,104 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory var keySelector = CreateSelector(elementExpression, 0); var elementSelector = CreateSelector(elementExpression, 1); - - var input = Expression.Parameter(typeof (IEnumerable), "input"); + + var input = Expression.Parameter(typeof(IEnumerable), "input"); var lambda = Expression.Lambda( Expression.Call(GroupByMethod, - Expression.Call(CastMethod, input), - keySelector, - elementSelector), + Expression.Call(CastMethod, input), + keySelector, + elementSelector), input); queryModel.ResultOperators.Add(new ClientSideSelect2(lambda)); - queryModel.ResultOperators.Add(new ClientSideSelect(Expression.Lambda(resultSelector, @group))); + queryModel.ResultOperators.Add(new ClientSideSelect(Expression.Lambda(resultSelector, group))); var initializers = elementExpression.Select(e => ConvertToObject(e.Expression)); - queryModel.SelectClause.Selector = Expression.NewArrayInit(typeof (object), initializers); + queryModel.SelectClause.Selector = Expression.NewArrayInit(typeof(object), initializers); } - private static Expression ProcessExpression(QueryModel queryModel, ISessionFactory sessionFactory, Expression expression, List elementExpression, ParameterExpression @group) + private static Expression ProcessExpression(QueryModel queryModel, Expression expression, List elementExpression, + ParameterExpression group, VisitorParameters parameters) { - var memberExpression = expression as MemberExpression; - if (memberExpression != null) - return ProcessMemberExpression(sessionFactory, elementExpression, queryModel, @group, memberExpression); - - var subQueryExpression = expression as SubQueryExpression; - if (subQueryExpression != null) - return ProcessSubquery(sessionFactory, elementExpression, queryModel, @group, subQueryExpression.QueryModel); - + if (expression is MemberExpression memberExpression) + return ProcessMemberExpression(elementExpression, queryModel, group, memberExpression, parameters); + + if (expression is SubQueryExpression subQueryExpression) + return ProcessSubquery(elementExpression, queryModel, group, subQueryExpression.QueryModel, parameters); + return null; } - private static Expression ProcessSubquery(ISessionFactory sessionFactory, ICollection elementExpression, QueryModel queryModel, Expression @group, QueryModel subQueryModel) + private static Expression ProcessSubquery(ICollection elementExpression, QueryModel queryModel, Expression group, + QueryModel subQueryModel, VisitorParameters parameters) { var subQueryMainFromClause = subQueryModel.MainFromClause; - var restrictions = subQueryModel.BodyClauses - .OfType() - .Select(w => new NhWithClause(w.Predicate)); - - var join = new NhJoinClause(subQueryMainFromClause.ItemName, - subQueryMainFromClause.ItemType, - subQueryMainFromClause.FromExpression, - restrictions); + var restrictions = subQueryModel.BodyClauses.OfType().ToList(); + var join = new AdditionalFromClause( + subQueryMainFromClause.ItemName, + subQueryMainFromClause.ItemType, + subQueryMainFromClause.FromExpression); - queryModel.BodyClauses.Add(@join); + parameters.AddLeftJoin(join, restrictions); + queryModel.BodyClauses.Add(join); + foreach (var withClause in restrictions) + { + queryModel.BodyClauses.Add(withClause); + } - var visitor = new SwapQuerySourceVisitor(subQueryMainFromClause, @join); + var visitor = new SwapQuerySourceVisitor(subQueryMainFromClause, join); queryModel.TransformExpressions(visitor.Swap); var selector = subQueryModel.SelectClause.Selector; var collectionType = subQueryModel.GetResultType(); - + var elementType = selector.Type; - var source = new QuerySourceReferenceExpression(@join); + var source = new QuerySourceReferenceExpression(join); - return BuildSubCollectionQuery(sessionFactory, elementExpression, @group, source, selector, elementType, collectionType); + return BuildSubCollectionQuery(parameters.SessionFactory, elementExpression, group, source, selector, elementType, collectionType); } - private static Expression ProcessMemberExpression(ISessionFactory sessionFactory, ICollection elementExpression, QueryModel queryModel, Expression @group, Expression memberExpression) + private static Expression ProcessMemberExpression(ICollection elementExpression, QueryModel queryModel, Expression group, + Expression memberExpression, VisitorParameters parameters) { - var join = new NhJoinClause(new NameGenerator(queryModel).GetNewName(), - GetElementType(memberExpression.Type), - memberExpression); + var join = new AdditionalFromClause( + new NameGenerator(queryModel).GetNewName(), + GetElementType(memberExpression.Type), + memberExpression); + parameters.AddLeftJoin(join, null); - queryModel.BodyClauses.Add(@join); + queryModel.BodyClauses.Add(join); - var source = new QuerySourceReferenceExpression(@join); + var source = new QuerySourceReferenceExpression(join); - return BuildSubCollectionQuery(sessionFactory, elementExpression, @group, source, source, source.Type, memberExpression.Type); + return BuildSubCollectionQuery(parameters.SessionFactory, elementExpression, group, source, source, source.Type, memberExpression.Type); } - private static Expression BuildSubCollectionQuery(ISessionFactory sessionFactory, ICollection expressions, Expression @group, Expression source, Expression select, System.Type elementType, System.Type collectionType) + private static Expression BuildSubCollectionQuery(ISessionFactory sessionFactory, ICollection expressions, Expression group, + Expression source, Expression select, System.Type elementType, System.Type collectionType) { var predicate = MakePredicate(expressions.Count); var identifier = GetIdentifier(sessionFactory, source); - var selector = MakeSelector(expressions, @select, identifier); + var selector = MakeSelector(expressions, select, identifier); - return SubCollectionQuery(collectionType, elementType, @group, predicate, selector); + return SubCollectionQuery(collectionType, elementType, group, predicate, selector); } - private static LambdaExpression MakeSelector(ICollection elementExpression, Expression @select, Expression identifier) + private static LambdaExpression MakeSelector(ICollection elementExpression, Expression select, Expression identifier) { - var parameter = Expression.Parameter(typeof (Tuple), "value"); + var parameter = Expression.Parameter(typeof(Tuple), "value"); var rewriter = new SelectClauseRewriter(parameter, elementExpression, identifier, 1, new Dictionary()); - var selectorBody = rewriter.VisitExpression(@select); + var selectorBody = rewriter.VisitExpression(select); return Expression.Lambda(selectorBody, parameter); } @@ -161,27 +167,29 @@ private static Expression SubCollectionQuery(System.Type collectionType, System. var selectMethod = ReflectionCache.EnumerableMethods.SelectDefinition.MakeGenericMethod(new[] { typeof(Tuple), elementType }); - var select = Expression.Call(selectMethod, - Expression.Call(WhereMethod, source, predicate), - selector); + var select = Expression.Call( + selectMethod, + Expression.Call(WhereMethod, source, predicate), + selector); if (collectionType.IsArray) { var toArrayMethod = ReflectionCache.EnumerableMethods.ToArrayDefinition.MakeGenericMethod(new[] { elementType }); - var array = Expression.Call(toArrayMethod, @select); + var array = Expression.Call(toArrayMethod, select); return array; } var constructor = GetCollectionConstructor(collectionType, elementType); if (constructor != null) - return Expression.New(constructor, (Expression) @select); + return Expression.New(constructor, select); var toListMethod = ReflectionCache.EnumerableMethods.ToListDefinition.MakeGenericMethod(new[] { elementType }); - return Expression.Call(Expression.Call(toListMethod, @select), - "AsReadonly", - System.Type.EmptyTypes); + return Expression.Call( + Expression.Call(toListMethod, select), + "AsReadonly", + System.Type.EmptyTypes); } private static ConstructorInfo GetCollectionConstructor(System.Type collectionType, System.Type elementType) @@ -201,12 +209,13 @@ private static ConstructorInfo GetCollectionConstructor(System.Type collectionTy private static LambdaExpression MakePredicate(int index) { // t => Not(ReferenceEquals(t.Items[index], null)) - var t = Expression.Parameter(typeof (Tuple), "t"); + var t = Expression.Parameter(typeof(Tuple), "t"); return Expression.Lambda( Expression.Not( - Expression.Call(ObjectReferenceEquals, - ArrayIndex(Expression.Property(t, Tuple.ItemsProperty), index), - Expression.Constant(null))), + Expression.Call( + ObjectReferenceEquals, + ArrayIndex(Expression.Property(t, Tuple.ItemsProperty), index), + Expression.Constant(null))), t); } @@ -219,26 +228,26 @@ private static Expression GetIdentifier(ISessionFactory sessionFactory, Expressi if (classMetadata == null) return Expression.Constant(null); - var propertyName=classMetadata.IdentifierPropertyName; - NHibernate.Type.EmbeddedComponentType componentType; - if (propertyName == null && (componentType=classMetadata.IdentifierType as NHibernate.Type.EmbeddedComponentType)!=null) - { - //The identifier is an embedded composite key. We only need one property from it for a null check - propertyName = componentType.PropertyNames.First(); - } + var propertyName = classMetadata.IdentifierPropertyName; + EmbeddedComponentType componentType; + if (propertyName == null && (componentType = classMetadata.IdentifierType as EmbeddedComponentType) != null) + { + //The identifier is an embedded composite key. We only need one property from it for a null check + propertyName = componentType.PropertyNames.First(); + } - return ConvertToObject(Expression.PropertyOrField(expression, propertyName)); + return ConvertToObject(Expression.PropertyOrField(expression, propertyName)); } private static LambdaExpression CreateSelector(IEnumerable expressions, int tuple) { - var parameter = Expression.Parameter(typeof (object[]), "x"); + var parameter = Expression.Parameter(typeof(object[]), "x"); - var initializers = expressions.Select((x, index) => new { x.Tuple, index}) + var initializers = expressions.Select((x, index) => new { x.Tuple, index }) .Where(x => x.Tuple == tuple) .Select(x => ArrayIndex(parameter, x.index)); - var newArrayInit = Expression.NewArrayInit(typeof (object), initializers); + var newArrayInit = Expression.NewArrayInit(typeof(object), initializers); return Expression.Lambda( Expression.New(Tuple.Constructor, newArrayInit), @@ -252,7 +261,7 @@ private static Expression ArrayIndex(Expression param, int value) private static Expression ConvertToObject(Expression expression) { - return Expression.Convert(expression, typeof (object)); + return Expression.Convert(expression, typeof(object)); } private static System.Type GetElementType(System.Type type) @@ -261,6 +270,6 @@ private static System.Type GetElementType(System.Type type) if (elementType == null) throw new NotSupportedException("Unknown collection type " + type.FullName); return elementType; - } + } } } diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 0dc17e068b1..1670f152227 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -18,17 +18,17 @@ public class AddJoinsReWriter : QueryModelVisitorBase, IIsEntityDecider private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; - private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) + private AddJoinsReWriter(QueryModel queryModel, VisitorParameters parameters) { - _sessionFactory = sessionFactory; - var joiner = new Joiner(queryModel); + _sessionFactory = parameters.SessionFactory; + var joiner = new Joiner(queryModel, parameters); _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); _whereJoinDetector = new WhereJoinDetector(this, joiner); } public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - var visitor = new AddJoinsReWriter(parameters.SessionFactory, queryModel); + var visitor = new AddJoinsReWriter(queryModel, parameters); visitor.VisitQueryModel(queryModel); } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index f41ccea5022..b785d6e18d5 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.Linq.Expressions; -using NHibernate.Linq.Clauses; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -17,43 +16,38 @@ public interface IJoiner public class Joiner : IJoiner { - private readonly Dictionary _joins = new Dictionary(); + private readonly Dictionary _joins = new Dictionary(); private readonly NameGenerator _nameGenerator; + private readonly VisitorParameters _parameters; private readonly QueryModel _queryModel; - internal Joiner(QueryModel queryModel) + internal Joiner(QueryModel queryModel, VisitorParameters parameters) { _nameGenerator = new NameGenerator(queryModel); + _parameters = parameters; _queryModel = queryModel; } - public IEnumerable Joins - { - get { return _joins.Values; } - } - public Expression AddJoin(Expression expression, string key) { - NhJoinClause join; - - if (!_joins.TryGetValue(key, out join)) + if (!_joins.TryGetValue(key, out AdditionalFromClause join)) { - join = new NhJoinClause(_nameGenerator.GetNewName(), expression.Type, expression); + join = new AdditionalFromClause(_nameGenerator.GetNewName(), expression.Type, expression); + _parameters.AddLeftJoin(join, null); _queryModel.BodyClauses.Add(join); _joins.Add(key, join); } - return new QuerySourceReferenceExpression(@join); + return new QuerySourceReferenceExpression(join); } public void MakeInnerIfJoined(string key) { // key is not joined if it occurs only at tails of expressions, e.g. // a.B == null, a.B != null, a.B == c.D etc. - NhJoinClause nhJoinClause; - if (_joins.TryGetValue(key, out nhJoinClause)) + if (_joins.TryGetValue(key, out AdditionalFromClause join)) { - nhJoinClause.MakeInner(); + _parameters.MakeInnerJoin(join); } } @@ -64,8 +58,7 @@ public bool CanAddJoin(Expression expression) if (_queryModel.MainFromClause == source) return true; - var bodyClause = source as IBodyClause; - if (bodyClause != null && _queryModel.BodyClauses.Contains(bodyClause)) + if (source is IBodyClause bodyClause && _queryModel.BodyClauses.Contains(bodyClause)) return true; var resultOperatorBase = source as ResultOperatorBase; diff --git a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs index 83a2facc0b5..e8cf6299208 100644 --- a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs +++ b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.Linq; -using NHibernate.Linq.Clauses; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.ExpressionTreeVisitors; @@ -11,9 +10,16 @@ namespace NHibernate.Linq.Visitors { public class LeftJoinRewriter : QueryModelVisitorBase { - public static void ReWrite(QueryModel queryModel) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) { - new LeftJoinRewriter().VisitQueryModel(queryModel); + new LeftJoinRewriter(parameters).VisitQueryModel(queryModel); + } + + private readonly VisitorParameters _parameters; + + public LeftJoinRewriter(VisitorParameters parameters) + { + _parameters = parameters; } public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) @@ -28,14 +34,9 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, var mainFromClause = subQueryModel.MainFromClause; - var restrictions = subQueryModel.BodyClauses - .OfType() - .Select(w => new NhWithClause(w.Predicate)); - - var join = new NhJoinClause(mainFromClause.ItemName, - mainFromClause.ItemType, - mainFromClause.FromExpression, - restrictions); + var join = new AdditionalFromClause(mainFromClause.ItemName, mainFromClause.ItemType, mainFromClause.FromExpression); + var restrictions = subQueryModel.BodyClauses.OfType().ToList(); + _parameters.AddLeftJoin(join, restrictions); var innerSelectorMapping = new QuerySourceMapping(); innerSelectorMapping.AddMapping(fromClause, subQueryModel.SelectClause.Selector); @@ -43,11 +44,11 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); queryModel.BodyClauses.RemoveAt(index); - queryModel.BodyClauses.Insert(index, @join); - InsertBodyClauses(subQueryModel.BodyClauses.Where(b => !(b is WhereClause)), queryModel, index + 1); + queryModel.BodyClauses.Insert(index, join); + InsertBodyClauses(subQueryModel.BodyClauses, queryModel, index + 1); var innerBodyClauseMapping = new QuerySourceMapping(); - innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(@join)); + innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(join)); queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); } @@ -57,14 +58,14 @@ private static void InsertBodyClauses(IEnumerable bodyClauses, Quer foreach (var bodyClause in bodyClauses) { destinationQueryModel.BodyClauses.Insert(destinationIndex, bodyClause); - ++destinationIndex; + destinationIndex++; } } private static bool IsLeftJoin(QueryModel subQueryModel) { return subQueryModel.ResultOperators.Count == 1 && - subQueryModel.ResultOperators[0] is DefaultIfEmptyResultOperator; + subQueryModel.ResultOperators[0] is DefaultIfEmptyResultOperator; } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 3e4a1c863cf..ee5bcceb8d0 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -3,7 +3,6 @@ using System.Linq.Expressions; using System.Reflection; using NHibernate.Hql.Ast; -using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.GroupBy; using NHibernate.Linq.GroupJoin; @@ -25,7 +24,7 @@ public class QueryModelVisitor : QueryModelVisitorBase public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root, NhLinqExpressionReturnType? rootReturnType) { - NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory); + NestedSelectRewriter.ReWrite(queryModel, parameters); // Remove unnecessary body operators RemoveUnnecessaryBodyOperators.ReWrite(queryModel); @@ -37,7 +36,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer NonAggregatingGroupByRewriter.ReWrite(queryModel); // Rewrite aggregate group-by statements - AggregatingGroupByRewriter.ReWrite(queryModel); + AggregatingGroupByRewriter.ReWrite(queryModel, parameters); // Rewrite aggregating group-joins AggregatingGroupJoinRewriter.ReWrite(queryModel); @@ -48,7 +47,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer SubQueryFromClauseFlattener.ReWrite(queryModel); // Rewrite left-joins - LeftJoinRewriter.ReWrite(queryModel); + LeftJoinRewriter.ReWrite(queryModel, parameters); // Rewrite paging PagingRewriter.ReWrite(queryModel); @@ -280,10 +279,7 @@ private MethodCallExpression GetAggregateMethodCall(MethodInfo aggregateMethodTe public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - var hqlExpressionTree = HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters); - - _hqlTree.AddFromClause(_hqlTree.TreeBuilder.Range(hqlExpressionTree, _hqlTree.TreeBuilder.Alias(querySourceName))); + AddFromClause(fromClause); // apply any result operators that were rewritten if (RewrittenOperatorResult != null) @@ -300,56 +296,39 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - - var joinClause = fromClause as NhJoinClause; - if (joinClause != null) - { - VisitNhJoinClause(querySourceName, joinClause); - } - else if (fromClause.FromExpression is MemberExpression) + if (fromClause.FromExpression is MemberExpression) { // It's a join - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Join( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), - _hqlTree.TreeBuilder.Alias(querySourceName))); + var expression = HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(); + var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); + var alias = _hqlTree.TreeBuilder.Alias(querySourceName); + var hqlJoin = VisitorParameters.IsLeftJoin(fromClause) ? + (HqlTreeNode)_hqlTree.TreeBuilder.LeftJoin(expression, alias) : + _hqlTree.TreeBuilder.Join(expression, alias); + + foreach (var withClause in VisitorParameters.GetRestrictions(fromClause)) + { + var booleanExpression = HqlGeneratorExpressionTreeVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); + hqlJoin.AddChild(_hqlTree.TreeBuilder.With(booleanExpression)); + } + + _hqlTree.AddFromClause(hqlJoin); } else { - // TODO - exact same code as in MainFromClause; refactor this out - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); - + AddFromClause(fromClause); } base.VisitAdditionalFromClause(fromClause, queryModel, index); } - private void VisitNhJoinClause(string querySourceName, NhJoinClause joinClause) + private void AddFromClause(FromClauseBase fromClause) { - var expression = HqlGeneratorExpressionTreeVisitor.Visit(joinClause.FromExpression, VisitorParameters).AsExpression(); - var alias = _hqlTree.TreeBuilder.Alias(querySourceName); - - HqlTreeNode hqlJoin; - if (joinClause.IsInner) - { - hqlJoin = _hqlTree.TreeBuilder.Join(expression, @alias); - } - else - { - hqlJoin = _hqlTree.TreeBuilder.LeftJoin(expression, @alias); - } - - foreach (var withClause in joinClause.Restrictions) - { - var booleanExpression = HqlGeneratorExpressionTreeVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); - hqlJoin.AddChild(_hqlTree.TreeBuilder.With(booleanExpression)); - } - - _hqlTree.AddFromClause(hqlJoin); + var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); + _hqlTree.AddFromClause( + _hqlTree.TreeBuilder.Range( + HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), + _hqlTree.TreeBuilder.Alias(querySourceName))); } public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) @@ -397,11 +376,12 @@ public override void VisitWhereClause(WhereClause whereClause, QueryModel queryM // Visit the predicate to build the query var expression = HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); - if (whereClause is NhHavingClause) + if (VisitorParameters.IsHavingClause(whereClause)) { _hqlTree.AddHavingClause(expression); } - else + // With clauses are handled with joins. + else if (!VisitorParameters.IsWithClause(whereClause)) { _hqlTree.AddWhereClause(expression); } diff --git a/src/NHibernate/Linq/Visitors/VisitorParameters.cs b/src/NHibernate/Linq/Visitors/VisitorParameters.cs index 27ef5de7029..5e676a21c22 100644 --- a/src/NHibernate/Linq/Visitors/VisitorParameters.cs +++ b/src/NHibernate/Linq/Visitors/VisitorParameters.cs @@ -3,23 +3,29 @@ using NHibernate.Engine; using NHibernate.Engine.Query; using NHibernate.Param; +using Remotion.Linq.Clauses; namespace NHibernate.Linq.Visitors { public class VisitorParameters { - public ISessionFactoryImplementor SessionFactory { get; private set; } + public ISessionFactoryImplementor SessionFactory { get; } - public IDictionary ConstantToParameterMap { get; private set; } + public IDictionary ConstantToParameterMap { get; } - public List RequiredHqlParameters { get; private set; } + public List RequiredHqlParameters { get; } - public QuerySourceNamer QuerySourceNamer { get; set; } + public QuerySourceNamer QuerySourceNamer { get; } + + private readonly HashSet _havingClauses = new HashSet(); + private readonly HashSet _leftJoins = new HashSet(); + private readonly HashSet _withClauses = new HashSet(); + private readonly Dictionary> _joinRestrictions = new Dictionary>(); public VisitorParameters( - ISessionFactoryImplementor sessionFactory, - IDictionary constantToParameterMap, - List requiredHqlParameters, + ISessionFactoryImplementor sessionFactory, + IDictionary constantToParameterMap, + List requiredHqlParameters, QuerySourceNamer querySourceNamer) { SessionFactory = sessionFactory; @@ -27,5 +33,83 @@ public VisitorParameters( RequiredHqlParameters = requiredHqlParameters; QuerySourceNamer = querySourceNamer; } + + /// + /// Indicates if a Linq where clause needs to be converted to a HQL having clause. + /// + /// The clause to test. + /// true if the clause needs to be converted to a HQL having clause, false otherwise. + public bool IsHavingClause(WhereClause clause) + { + return _havingClauses.Contains(clause); + } + + /// + /// Indicates if a Linq where clause needs to be converted to a HQL with clause. + /// + /// The clause to test. + /// true if the clause needs to be converted to a HQL with clause, false otherwise. + public bool IsWithClause(WhereClause clause) + { + return _withClauses.Contains(clause); + } + + /// + /// Indicates if a Linq join clause needs to be converted to a HQL left join. + /// + /// The join to test. + /// true if the clause needs to be converted to a HQL left join, false otherwise. + public bool IsLeftJoin(AdditionalFromClause join) + { + return _leftJoins.Contains(join); + } + + /// + /// Get the clauses to apply to the join as HQL with clauses. + /// + /// The join. + /// A list of where clauses to apply as HQL with to the join. + public IEnumerable GetRestrictions(AdditionalFromClause join) + { + if (_joinRestrictions.TryGetValue(join, out var restrictions)) + return restrictions; + return new List(); + } + + /// + /// Add a detected having clause. + /// + /// The clause to add. + public void AddHavingClause(WhereClause clause) + { + _havingClauses.Add(clause); + } + + /// + /// Add a detected join. + /// + /// The join to add. + /// Its restrictions if any. + public void AddLeftJoin(AdditionalFromClause join, IEnumerable restrictions) + { + _leftJoins.Add(join); + if (restrictions != null) + { + _joinRestrictions.Add(join, restrictions); + foreach (var with in restrictions) + { + _withClauses.Add(with); + } + } + } + + /// + /// Remove a join clause from left join clauses if it was one. + /// + /// The join clause to handle as inner. + public void MakeInnerJoin(AdditionalFromClause join) + { + _leftJoins.Remove(join); + } } } \ No newline at end of file diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 19b538de969..3d1e8f8c578 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -295,9 +295,6 @@ - - -