From 5cb7b3f9d67eb84d97984bcf7c86e639e3449d40 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Tue, 27 Jun 2017 09:48:20 +1200 Subject: [PATCH] Update Remotion.Linq to 2.1 --- .../Linq/CustomQueryModelRewriterTests.cs | 10 +- src/NHibernate.Test/NHibernate.Test.csproj | 11 +- src/NHibernate.Test/packages.config | 3 +- src/NHibernate/Linq/Clauses/NhClauseBase.cs | 29 ++++ src/NHibernate/Linq/Clauses/NhHavingClause.cs | 73 ++++++++- src/NHibernate/Linq/Clauses/NhJoinClause.cs | 146 +++++++++++++++--- src/NHibernate/Linq/Clauses/NhWithClause.cs | 66 +++++++- .../RemoveCharToIntConversion.cs | 2 +- .../RemoveRedundantCast.cs | 2 +- .../SimplifyCompareTransformer.cs | 2 +- .../Expressions/NhAggregatedExpression.cs | 35 +++-- .../Linq/Expressions/NhAverageExpression.cs | 14 +- .../Linq/Expressions/NhCountExpression.cs | 20 ++- .../Linq/Expressions/NhDistinctExpression.cs | 10 +- .../Linq/Expressions/NhExpression.cs | 21 +++ .../Linq/Expressions/NhExpressionType.cs | 15 -- .../Linq/Expressions/NhMaxExpression.cs | 10 +- .../Linq/Expressions/NhMinExpression.cs | 10 +- .../Linq/Expressions/NhNewExpression.cs | 56 ++----- .../Linq/Expressions/NhNominatedExpression.cs | 34 ++-- .../Linq/Expressions/NhStarExpression.cs | 31 ++++ .../Linq/Expressions/NhSumExpression.cs | 10 +- .../GroupBy/GroupBySelectClauseRewriter.cs | 36 +++-- .../Linq/GroupBy/GroupKeyNominator.cs | 30 ++-- ...IsNonAggregatingGroupByDetectionVisitor.cs | 14 +- .../Linq/GroupBy/KeySelectorVisitor.cs | 6 +- .../GroupBy/NonAggregatingGroupByRewriter.cs | 4 +- .../GroupJoinAggregateDetectionVisitor.cs | 22 +-- .../GroupJoinSelectClauseRewriter.cs | 6 +- .../GroupJoin/LocateGroupJoinQuerySource.cs | 8 +- .../NonAggregatingGroupJoinRewriter.cs | 6 +- src/NHibernate/Linq/INhQueryModelVisitor.cs | 14 ++ src/NHibernate/Linq/LinqExtensionMethods.cs | 4 +- src/NHibernate/Linq/LinqLogging.cs | 12 +- .../NestedSelects/NestedSelectDetector.cs | 10 +- .../NestedSelects/NestedSelectRewriter.cs | 6 +- .../NestedSelects/SelectClauseRewriter.cs | 10 +- src/NHibernate/Linq/NhRelinqQueryParser.cs | 9 +- .../Linq/ReWriters/AddJoinsReWriter.cs | 9 +- .../ArrayIndexExpressionFlattener.cs | 12 +- .../MergeAggregatingResultsRewriter.cs | 28 ++-- .../QueryReferenceExpressionFlattener.cs | 14 +- .../RemoveUnnecessaryBodyOperators.cs | 3 +- .../Linq/ReWriters/ResultOperatorRewriter.cs | 11 +- .../Linq/Visitors/EqualityHqlGenerator.cs | 4 +- .../Linq/Visitors/ExpressionKeyVisitor.cs | 57 +++---- .../Visitors/ExpressionParameterVisitor.cs | 16 +- ...or.cs => HqlGeneratorExpressionVisitor.cs} | 79 +++++----- src/NHibernate/Linq/Visitors/JoinBuilder.cs | 8 +- .../Linq/Visitors/LeftJoinRewriter.cs | 9 +- .../Visitors/MemberExpressionJoinDetector.cs | 32 ++-- .../Linq/Visitors/NhExpressionTreeVisitor.cs | 98 ------------ .../Linq/Visitors/NhExpressionVisitor.cs | 60 +++++++ ...hPartialEvaluatingExpressionTreeVisitor.cs | 32 ---- .../NhPartialEvaluatingExpressionVisitor.cs | 37 +++++ .../Linq/Visitors/NhQueryModelVisitorBase.cs | 20 +++ .../PagingRewriterSelectClauseVisitor.cs | 12 +- .../QueryExpressionSourceIdentifer.cs | 6 +- .../Linq/Visitors/QueryModelVisitor.cs | 80 +++++----- .../Linq/Visitors/QuerySourceIdentifier.cs | 10 +- .../Linq/Visitors/QuerySourceLocator.cs | 72 +++++---- .../ProcessAggregate.cs | 4 +- .../ProcessAggregateFromSeed.cs | 4 +- .../ResultOperatorProcessors/ProcessAll.cs | 2 +- .../ProcessContains.cs | 2 +- .../ProcessGroupBy.cs | 2 +- .../ProcessNonAggregatingGroupBy.cs | 6 +- .../ResultOperatorProcessors/ProcessOfType.cs | 2 +- .../Linq/Visitors/SelectClauseNominator.cs | 28 ++-- .../Linq/Visitors/SelectClauseVisitor.cs | 16 +- .../Visitors/SimplifyConditionalVisitor.cs | 16 +- .../Visitors/SubQueryFromClauseFlattener.cs | 8 +- .../Linq/Visitors/SwapQuerySourceVisitor.cs | 14 +- src/NHibernate/Linq/Visitors/VisitorUtil.cs | 4 +- .../Linq/Visitors/WhereJoinDetector.cs | 33 ++-- src/NHibernate/NHibernate.csproj | 25 ++- src/NHibernate/NHibernate.nuspec.template | 5 +- src/NHibernate/packages.config | 3 +- 78 files changed, 1034 insertions(+), 656 deletions(-) create mode 100644 src/NHibernate/Linq/Clauses/NhClauseBase.cs create mode 100644 src/NHibernate/Linq/Expressions/NhExpression.cs delete mode 100644 src/NHibernate/Linq/Expressions/NhExpressionType.cs create mode 100644 src/NHibernate/Linq/Expressions/NhStarExpression.cs create mode 100644 src/NHibernate/Linq/INhQueryModelVisitor.cs rename src/NHibernate/Linq/Visitors/{HqlGeneratorExpressionTreeVisitor.cs => HqlGeneratorExpressionVisitor.cs} (89%) delete mode 100644 src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs create mode 100644 src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs delete mode 100644 src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs create mode 100644 src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs create mode 100644 src/NHibernate/Linq/Visitors/NhQueryModelVisitorBase.cs diff --git a/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs b/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs index a22305a9966..28d1834d580 100644 --- a/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs +++ b/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs @@ -40,16 +40,16 @@ public QueryModelVisitorBase CreateVisitor(VisitorParameters parameters) } } - public class CustomVisitor : QueryModelVisitorBase + public class CustomVisitor : NhQueryModelVisitorBase { public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { - whereClause.TransformExpressions(new Visitor().VisitExpression); + whereClause.TransformExpressions(new Visitor().Visit); } - private class Visitor : ExpressionTreeVisitor + private class Visitor : RelinqExpressionVisitor { - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { if ( expression.NodeType == ExpressionType.Equal || @@ -82,7 +82,7 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) } } - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); } } } diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 198851b6e86..d2791f23aa2 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -83,6 +83,12 @@ ..\packages\NUnit.3.6.0\lib\net45\nunit.framework.dll True + + ..\packages\Remotion.Linq.2.1.2\lib\net45\Remotion.Linq.dll + + + ..\packages\Remotion.Linq.EagerFetching.2.1.0\lib\net45\Remotion.Linq.EagerFetching.dll + @@ -103,9 +109,6 @@ 3.5 - - ..\packages\Remotion.Linq.1.15.15.0\lib\portable-net45+wp80+wpa81+win\Remotion.Linq.dll - @@ -3899,4 +3902,4 @@ --> - \ No newline at end of file + diff --git a/src/NHibernate.Test/packages.config b/src/NHibernate.Test/packages.config index 838eb801d3a..4bc6669eb17 100644 --- a/src/NHibernate.Test/packages.config +++ b/src/NHibernate.Test/packages.config @@ -4,7 +4,8 @@ - + + diff --git a/src/NHibernate/Linq/Clauses/NhClauseBase.cs b/src/NHibernate/Linq/Clauses/NhClauseBase.cs new file mode 100644 index 00000000000..d3d78864487 --- /dev/null +++ b/src/NHibernate/Linq/Clauses/NhClauseBase.cs @@ -0,0 +1,29 @@ +using System; +using Remotion.Linq; + +namespace NHibernate.Linq.Clauses +{ + public abstract class NhClauseBase + { + /// + /// Accepts the specified visitor. + /// + /// The visitor to accept. + /// The query model in whose context this clause is visited. + /// + /// The index of this clause in the 's + /// collection. + /// + public void Accept(IQueryModelVisitor visitor, QueryModel queryModel, int index) + { + if (visitor == null) throw new ArgumentNullException(nameof(visitor)); + if (queryModel == null) throw new ArgumentNullException(nameof(queryModel)); + if (!(visitor is INhQueryModelVisitor nhVisitor)) + throw new ArgumentException("Expect visitor to implement INhQueryModelVisitor", nameof(visitor)); + + Accept(nhVisitor, queryModel, index); + } + + protected abstract void Accept(INhQueryModelVisitor visitor, QueryModel queryModel, int index); + } +} diff --git a/src/NHibernate/Linq/Clauses/NhHavingClause.cs b/src/NHibernate/Linq/Clauses/NhHavingClause.cs index 131043ecc92..5a6249d172d 100644 --- a/src/NHibernate/Linq/Clauses/NhHavingClause.cs +++ b/src/NHibernate/Linq/Clauses/NhHavingClause.cs @@ -1,19 +1,78 @@ -using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using System; using System.Linq.Expressions; +using Remotion.Linq; +using Remotion.Linq.Clauses; namespace NHibernate.Linq.Clauses { - public class NhHavingClause : WhereClause + public class NhHavingClause : NhClauseBase, IBodyClause { - public NhHavingClause(Expression predicate) - : base(predicate) + Expression _predicate; + + /// + /// Initializes a new instance of the class. + /// + /// The predicate used to filter data items. + public NhHavingClause(Expression predicate) + { + if (predicate == null) throw new ArgumentNullException(nameof(predicate)); + _predicate = predicate; + } + + /// + /// Gets the predicate, the expression representing the where condition by which the data items are filtered + /// + public Expression Predicate + { + get { return _predicate; } + set + { + if (value == null) throw new ArgumentNullException(nameof(value)); + _predicate = value; + } + } + + protected override void Accept(INhQueryModelVisitor visitor, QueryModel queryModel, int index) { + visitor.VisitNhHavingClause(this, queryModel, index); + } + + /// + IBodyClause IBodyClause.Clone(CloneContext cloneContext) + { + return Clone(cloneContext); + } + + /// + /// Transforms all the expressions in this clause and its child objects via the given + /// delegate. + /// + /// + /// The transformation object. This delegate is called for each + /// within this + /// clause, and those expressions will be replaced with what the delegate returns. + /// + public void TransformExpressions(Func transformation) + { + if (transformation == null) throw new ArgumentNullException(nameof(transformation)); + Predicate = transformation(Predicate); } public override string ToString() { - return "having " + FormattingExpressionTreeVisitor.Format(Predicate); + return "having " + Predicate; + } + + /// Clones this clause. + /// + /// The clones of all query source clauses are registered with this + /// . + /// + /// + public NhHavingClause Clone(CloneContext cloneContext) + { + if (cloneContext == null) throw new ArgumentNullException(nameof(cloneContext)); + return new NhHavingClause(Predicate); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs index 944a8248d4d..9257c15e2a2 100644 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ b/src/NHibernate/Linq/Clauses/NhJoinClause.cs @@ -3,48 +3,156 @@ using System.Collections.ObjectModel; using System.Linq.Expressions; using NHibernate.Linq.Visitors; +using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; namespace NHibernate.Linq.Clauses { /// - /// All joins are created as outer joins. An optimization in finds - /// joins that may be inner joined and calls on them. - /// 's will - /// then emit the correct HQL join. + /// All joins are created as outer joins. An optimization in finds + /// joins that may be inner joined and calls on them. + /// 's will + /// then emit the correct HQL join. /// - public class NhJoinClause : AdditionalFromClause + public class NhJoinClause : NhClauseBase, IFromClause, IBodyClause { + Expression _fromExpression; + string _itemName; + System.Type _itemType; + public NhJoinClause(string itemName, System.Type itemType, Expression fromExpression) : this(itemName, itemType, fromExpression, new NhWithClause[0]) { } + /// + /// Initializes a new instance of the class. + /// + /// A name describing the items generated by the from clause. + /// The type of the items generated by the from clause. + /// + /// The generating data items for this + /// from clause. + /// + /// public NhJoinClause(string itemName, System.Type itemType, Expression fromExpression, IEnumerable restrictions) - : base(itemName, itemType, fromExpression) { - Restrictions = new ObservableCollection(); - foreach (var withClause in restrictions) - Restrictions.Add(withClause); + if (string.IsNullOrEmpty(itemName)) throw new ArgumentException("Value cannot be null or empty.", nameof(itemName)); + if (itemType == null) throw new ArgumentNullException(nameof(itemType)); + if (fromExpression == null) throw new ArgumentNullException(nameof(fromExpression)); + + _itemName = itemName; + _itemType = itemType; + _fromExpression = fromExpression; + + Restrictions = new ObservableCollection(restrictions); IsInner = false; } - public ObservableCollection Restrictions { get; private set; } + public ObservableCollection Restrictions { get; } public bool IsInner { get; private set; } - public override AdditionalFromClause Clone(CloneContext cloneContext) + public void TransformExpressions(Func transformation) { - var joinClause = new NhJoinClause(ItemName, ItemType, FromExpression); + if (transformation == null) throw new ArgumentNullException(nameof(transformation)); foreach (var withClause in Restrictions) + withClause.TransformExpressions(transformation); + FromExpression = transformation(FromExpression); + } + + /// + /// Accepts the specified visitor by calling its + /// + /// method. + /// + /// The visitor to accept. + /// The query model in whose context this clause is visited. + /// + /// The index of this clause in the 's + /// collection. + /// + protected override void Accept(INhQueryModelVisitor visitor, QueryModel queryModel, int index) + { + visitor.VisitNhJoinClause(this, queryModel, index); + } + + IBodyClause IBodyClause.Clone(CloneContext cloneContext) + { + return Clone(cloneContext); + } + + /// + /// Gets or sets a name describing the items generated by this from clause. + /// + /// + /// Item names are inferred when a query expression is parsed, and they usually correspond to the variable names + /// present in that expression. + /// However, note that names are not necessarily unique within a . Use names + /// only for readability and debugging, not for + /// uniquely identifying objects. To match an + /// with its references, use the + /// property + /// rather than the . + /// + public string ItemName + { + get { return _itemName; } + set + { + if (string.IsNullOrEmpty(value)) throw new ArgumentException("Value cannot be null or empty.", nameof(value)); + _itemName = value; + } + } + + /// + /// Gets or sets the type of the items generated by this from clause. + /// + /// + /// Changing the of a + /// can make all + /// objects that + /// point to that invalid, so the property setter should be used + /// with care. + /// + public System.Type ItemType + { + get { return _itemType; } + set { - var withClause2 = new NhWithClause(withClause.Predicate); - joinClause.Restrictions.Add(withClause2); + if (value == null) throw new ArgumentNullException(nameof(value)); + _itemType = value; } + } + /// + /// The expression generating the data items for this from clause. + /// + public Expression FromExpression + { + get { return _fromExpression; } + set + { + if (value == null) throw new ArgumentNullException(nameof(value)); + _fromExpression = value; + } + } + + public void CopyFromSource(IFromClause source) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + FromExpression = source.FromExpression; + ItemName = source.ItemName; + ItemType = source.ItemType; + } + + public NhJoinClause Clone(CloneContext cloneContext) + { + var joinClause = new NhJoinClause(ItemName, ItemType, FromExpression, Restrictions); cloneContext.QuerySourceMapping.AddMapping(this, new QuerySourceReferenceExpression(joinClause)); - return base.Clone(cloneContext); + return joinClause; } public void MakeInner() @@ -52,11 +160,9 @@ public void MakeInner() IsInner = true; } - public override void TransformExpressions(Func transformation) + public override string ToString() { - foreach (var withClause in Restrictions) - withClause.TransformExpressions(transformation); - base.TransformExpressions(transformation); + return string.Format("join {0} {1} in {2}", ItemType.Name, ItemName, FromExpression); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Clauses/NhWithClause.cs b/src/NHibernate/Linq/Clauses/NhWithClause.cs index ae21bc0f4a6..376b04b9125 100644 --- a/src/NHibernate/Linq/Clauses/NhWithClause.cs +++ b/src/NHibernate/Linq/Clauses/NhWithClause.cs @@ -1,19 +1,77 @@ +using System; using System.Linq.Expressions; +using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; namespace NHibernate.Linq.Clauses { - public class NhWithClause : WhereClause + public class NhWithClause : NhClauseBase, IBodyClause { + Expression _predicate; + + /// + /// Initializes a new instance of the class. + /// + /// The predicate used to filter data items. public NhWithClause(Expression predicate) - : base(predicate) { + if (predicate == null) throw new ArgumentNullException(nameof(predicate)); + _predicate = predicate; + } + + /// + /// Gets the predicate, the expression representing the where condition by which the data items are filtered + /// + public Expression Predicate + { + get { return _predicate; } + set + { + if (value == null) throw new ArgumentNullException(nameof(value)); + _predicate = value; + } } public override string ToString() { - return "with " + FormattingExpressionTreeVisitor.Format(Predicate); + return "with " + Predicate; + } + + protected override void Accept(INhQueryModelVisitor visitor, QueryModel queryModel, int index) + { + visitor.VisitNhWithClause(this, queryModel, index); + } + + IBodyClause IBodyClause.Clone(CloneContext cloneContext) + { + return Clone(cloneContext); + } + + /// Clones this clause. + /// + /// The clones of all query source clauses are registered with this + /// . + /// + /// + public NhWithClause Clone(CloneContext cloneContext) + { + if (cloneContext == null) throw new ArgumentNullException("cloneContext"); + return new NhWithClause(Predicate); + } + + /// + /// Transforms all the expressions in this clause and its child objects via the given + /// delegate. + /// + /// + /// The transformation object. This delegate is called for each + /// within this + /// clause, and those expressions will be replaced with what the delegate returns. + /// + public void TransformExpressions(Func transformation) + { + if (transformation == null) throw new ArgumentNullException("transformation"); + Predicate = transformation(Predicate); } } } diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs index e411da5d7ab..3e0c19d4ca1 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs @@ -1,6 +1,6 @@ using System; using System.Linq.Expressions; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers { diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs index d5a87a97eb9..538e46cb828 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs @@ -1,5 +1,5 @@ using System.Linq.Expressions; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers { diff --git a/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs b/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs index 73ee8d974f6..5dc00b7fe4b 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/SimplifyCompareTransformer.cs @@ -5,7 +5,7 @@ using System.Reflection; using NHibernate.Linq.Functions; using NHibernate.Util; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; namespace NHibernate.Linq.ExpressionTransformers { diff --git a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs index e564884e771..6b705d3d92e 100644 --- a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs @@ -1,34 +1,39 @@ using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { - public abstract class NhAggregatedExpression : ExtensionExpression + public abstract class NhAggregatedExpression : NhExpression { - public Expression Expression { get; set; } - - protected NhAggregatedExpression(Expression expression, NhExpressionType type) - : base(expression.Type, (ExpressionType)type) + protected NhAggregatedExpression(Expression expression) + : this(expression, expression.Type) { - Expression = expression; } - protected NhAggregatedExpression(Expression expression, System.Type expressionType, NhExpressionType type) - : base(expressionType, (ExpressionType)type) + protected NhAggregatedExpression(Expression expression, System.Type type) { Expression = expression; + Type = type; } - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) + public sealed override System.Type Type { get; } + + public Expression Expression { get; } + + protected override Expression VisitChildren(ExpressionVisitor visitor) { - var newExpression = visitor.VisitExpression(Expression); + var newExpression = visitor.Visit(Expression); return newExpression != Expression - ? CreateNew(newExpression) - : this; + ? CreateNew(newExpression) + : this; } public abstract Expression CreateNew(Expression expression); + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhAggregated(this); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhAverageExpression.cs b/src/NHibernate/Linq/Expressions/NhAverageExpression.cs index 9dffa5f67fe..f12ac947e4a 100644 --- a/src/NHibernate/Linq/Expressions/NhAverageExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhAverageExpression.cs @@ -1,16 +1,17 @@ using System; using System.Linq.Expressions; +using NHibernate.Linq.Visitors; using NHibernate.Util; namespace NHibernate.Linq.Expressions { public class NhAverageExpression : NhAggregatedExpression { - public NhAverageExpression(Expression expression) : base(expression, CalculateAverageType(expression.Type), NhExpressionType.Average) + public NhAverageExpression(Expression expression) : base(expression, CalculateAverageType(expression.Type)) { } - private static System.Type CalculateAverageType(System.Type inputType) + static System.Type CalculateAverageType(System.Type inputType) { var isNullable = false; @@ -27,7 +28,7 @@ private static System.Type CalculateAverageType(System.Type inputType) case TypeCode.Int64: case TypeCode.Single: case TypeCode.Double: - return isNullable ? typeof(double?) : typeof (double); + return isNullable ? typeof(double?) : typeof(double); case TypeCode.Decimal: return isNullable ? typeof(decimal?) : typeof(decimal); } @@ -39,5 +40,10 @@ public override Expression CreateNew(Expression expression) { return new NhAverageExpression(expression); } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhAverage(this); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhCountExpression.cs b/src/NHibernate/Linq/Expressions/NhCountExpression.cs index 8c9024e0280..6dc698add5c 100644 --- a/src/NHibernate/Linq/Expressions/NhCountExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhCountExpression.cs @@ -1,17 +1,27 @@ using System.Linq.Expressions; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { public abstract class NhCountExpression : NhAggregatedExpression { protected NhCountExpression(Expression expression, System.Type type) - : base(expression, type, NhExpressionType.Count) {} + : base(expression, type) + { + } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhCount(this); + } } public class NhShortCountExpression : NhCountExpression { public NhShortCountExpression(Expression expression) - : base(expression, typeof (int)) {} + : base(expression, typeof(int)) + { + } public override Expression CreateNew(Expression expression) { @@ -22,11 +32,13 @@ public override Expression CreateNew(Expression expression) public class NhLongCountExpression : NhCountExpression { public NhLongCountExpression(Expression expression) - : base(expression, typeof (long)) {} + : base(expression, typeof(long)) + { + } public override Expression CreateNew(Expression expression) { return new NhLongCountExpression(expression); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs b/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs index 4ddc970adfd..bcf3aeaca3f 100644 --- a/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs @@ -1,11 +1,12 @@ using System.Linq.Expressions; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { public class NhDistinctExpression : NhAggregatedExpression { public NhDistinctExpression(Expression expression) - : base(expression, NhExpressionType.Distinct) + : base(expression) { } @@ -13,5 +14,10 @@ public override Expression CreateNew(Expression expression) { return new NhDistinctExpression(expression); } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhDistinct(this); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhExpression.cs b/src/NHibernate/Linq/Expressions/NhExpression.cs new file mode 100644 index 00000000000..880e2e48b0a --- /dev/null +++ b/src/NHibernate/Linq/Expressions/NhExpression.cs @@ -0,0 +1,21 @@ +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; + +namespace NHibernate.Linq.Expressions +{ + public abstract class NhExpression : Expression + { + public sealed override ExpressionType NodeType => ExpressionType.Extension; + + protected sealed override Expression Accept(ExpressionVisitor visitor) + { + var nhVisitor = visitor as NhExpressionVisitor; + if (nhVisitor != null) + return Accept(nhVisitor); + + return base.Accept(visitor); + } + + protected abstract Expression Accept(NhExpressionVisitor visitor); + } +} diff --git a/src/NHibernate/Linq/Expressions/NhExpressionType.cs b/src/NHibernate/Linq/Expressions/NhExpressionType.cs deleted file mode 100644 index 27463c5904f..00000000000 --- a/src/NHibernate/Linq/Expressions/NhExpressionType.cs +++ /dev/null @@ -1,15 +0,0 @@ -namespace NHibernate.Linq.Expressions -{ - public enum NhExpressionType - { - Average = 10000, - Min, - Max, - Sum, - Count, - Distinct, - New, - Star, - Nominator - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Expressions/NhMaxExpression.cs b/src/NHibernate/Linq/Expressions/NhMaxExpression.cs index b4b536fabd0..5060a41a2ec 100644 --- a/src/NHibernate/Linq/Expressions/NhMaxExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhMaxExpression.cs @@ -1,11 +1,12 @@ using System.Linq.Expressions; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { public class NhMaxExpression : NhAggregatedExpression { public NhMaxExpression(Expression expression) - : base(expression, NhExpressionType.Max) + : base(expression) { } @@ -13,5 +14,10 @@ public override Expression CreateNew(Expression expression) { return new NhMaxExpression(expression); } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhMax(this); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhMinExpression.cs b/src/NHibernate/Linq/Expressions/NhMinExpression.cs index e8eb33570dc..b975e83a3e7 100644 --- a/src/NHibernate/Linq/Expressions/NhMinExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhMinExpression.cs @@ -1,11 +1,12 @@ using System.Linq.Expressions; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { public class NhMinExpression : NhAggregatedExpression { public NhMinExpression(Expression expression) - : base(expression, NhExpressionType.Min) + : base(expression) { } @@ -13,5 +14,10 @@ public override Expression CreateNew(Expression expression) { return new NhMinExpression(expression); } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhMin(this); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhNewExpression.cs b/src/NHibernate/Linq/Expressions/NhNewExpression.cs index 5c55e020f17..11e49f63e60 100644 --- a/src/NHibernate/Linq/Expressions/NhNewExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhNewExpression.cs @@ -1,64 +1,36 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { - public class NhNewExpression : ExtensionExpression + public class NhNewExpression : NhExpression { - private readonly ReadOnlyCollection _members; - private readonly ReadOnlyCollection _arguments; - public NhNewExpression(IList members, IList arguments) - : base(typeof(object), (ExpressionType)NhExpressionType.New) { - _members = new ReadOnlyCollection(members); - _arguments = new ReadOnlyCollection(arguments); + Members = new ReadOnlyCollection(members); + Arguments = new ReadOnlyCollection(arguments); } - public ReadOnlyCollection Arguments - { - get { return _arguments; } - } + public override System.Type Type => typeof(object); - public ReadOnlyCollection Members - { - get { return _members; } - } + public ReadOnlyCollection Arguments { get; } + + public ReadOnlyCollection Members { get; } - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) + protected override Expression VisitChildren(ExpressionVisitor visitor) { var arguments = visitor.VisitAndConvert(Arguments, "VisitNhNew"); return arguments != Arguments - ? new NhNewExpression(Members, arguments) - : this; + ? new NhNewExpression(Members, arguments) + : this; } - } - public class NhStarExpression : ExtensionExpression - { - public NhStarExpression(Expression expression) - : base(expression.Type, (ExpressionType)NhExpressionType.Star) + protected override Expression Accept(NhExpressionVisitor visitor) { - Expression = expression; - } - - public Expression Expression - { - get; - private set; - } - - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) - { - var newExpression = visitor.VisitExpression(Expression); - - return newExpression != Expression - ? new NhStarExpression(newExpression) - : this; + return visitor.VisitNhNew(this); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs b/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs index 15997729997..c8af6774f7d 100644 --- a/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhNominatedExpression.cs @@ -1,34 +1,40 @@ using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { /// - /// Represents an expression that has been nominated for direct inclusion in the SELECT clause. - /// This bypasses the standard nomination process and assumes that the expression can be converted - /// directly to SQL. + /// Represents an expression that has been nominated for direct inclusion in the SELECT clause. + /// This bypasses the standard nomination process and assumes that the expression can be converted + /// directly to SQL. /// /// - /// Used in the nomination of GroupBy key expressions to ensure that matching select clauses - /// are generated the same way. + /// Used in the nomination of GroupBy key expressions to ensure that matching select clauses + /// are generated the same way. /// - internal class NhNominatedExpression : ExtensionExpression + public class NhNominatedExpression : NhExpression { - public Expression Expression { get; private set; } - - public NhNominatedExpression(Expression expression) : base(expression.Type, (ExpressionType)NhExpressionType.Nominator) + public NhNominatedExpression(Expression expression) { Expression = expression; } - protected override Expression VisitChildren(ExpressionTreeVisitor visitor) + public override System.Type Type => Expression.Type; + + public Expression Expression { get; } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhNominated(this); + } + + protected override Expression VisitChildren(ExpressionVisitor visitor) { - var newExpression = visitor.VisitExpression(Expression); + var newExpression = visitor.Visit(Expression); return newExpression != Expression ? new NhNominatedExpression(newExpression) : this; } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Expressions/NhStarExpression.cs b/src/NHibernate/Linq/Expressions/NhStarExpression.cs new file mode 100644 index 00000000000..adffa3e1642 --- /dev/null +++ b/src/NHibernate/Linq/Expressions/NhStarExpression.cs @@ -0,0 +1,31 @@ +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; + +namespace NHibernate.Linq.Expressions +{ + public class NhStarExpression : NhExpression + { + public NhStarExpression(Expression expression) + { + Expression = expression; + } + + public Expression Expression { get; } + + public override System.Type Type => Expression.Type; + + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var newExpression = visitor.Visit(Expression); + + return newExpression != Expression + ? new NhStarExpression(newExpression) + : this; + } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhStar(this); + } + } +} diff --git a/src/NHibernate/Linq/Expressions/NhSumExpression.cs b/src/NHibernate/Linq/Expressions/NhSumExpression.cs index d8e7326eda0..e7a0e263beb 100644 --- a/src/NHibernate/Linq/Expressions/NhSumExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhSumExpression.cs @@ -1,11 +1,12 @@ using System.Linq.Expressions; +using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Expressions { public class NhSumExpression : NhAggregatedExpression { public NhSumExpression(Expression expression) - : base(expression, NhExpressionType.Sum) + : base(expression) { } @@ -13,5 +14,10 @@ public override Expression CreateNew(Expression expression) { return new NhSumExpression(expression); } + + protected override Expression Accept(NhExpressionVisitor visitor) + { + return visitor.VisitNhSum(this); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs index 5290f3189dd..c031d305396 100644 --- a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs @@ -7,17 +7,17 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.GroupBy { //This should be renamed. It handles entire querymodels, not just select clauses - internal class GroupBySelectClauseRewriter : ExpressionTreeVisitor + internal class GroupBySelectClauseRewriter : RelinqExpressionVisitor { public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model) { var visitor = new GroupBySelectClauseRewriter(groupBy, model); - return TransparentIdentifierRemovingExpressionTreeVisitor.ReplaceTransparentIdentifiers(visitor.VisitExpression(expression)); + return TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers(visitor.Visit(expression)); } private readonly GroupResultOperator _groupBy; @@ -31,11 +31,11 @@ private GroupBySelectClauseRewriter(GroupResultOperator groupBy, QueryModel mode _nominatedKeySelector = GroupKeyNominator.Visit(groupBy); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { if (!IsMemberOfModel(expression)) { - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } if (expression.IsGroupingElementOf(_groupBy)) @@ -43,14 +43,14 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef return _groupBy.ElementSelector; } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { if (!IsMemberOfModel(expression)) { - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } if (expression.IsGroupingKeyOf(_groupBy)) @@ -64,7 +64,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) if ((elementSelector is MemberExpression) || (elementSelector is QuerySourceReferenceExpression)) { // If ElementSelector is MemberExpression, just return - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } if ((elementSelector is NewExpression || elementSelector.NodeType == ExpressionType.Convert) @@ -120,21 +120,23 @@ private bool IsMemberOfModel(QuerySourceReferenceExpression expression) return subQuery2 != null && subQuery2.QueryModel == _model; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { //If the subquery is a Count(*) aggregate with a condition if (expression.QueryModel.MainFromClause.FromExpression.Type == _groupBy.ItemType) { var where = expression.QueryModel.BodyClauses.OfType().FirstOrDefault(); - NhCountExpression countExpression; - if (where != null && (countExpression = expression.QueryModel.SelectClause.Selector as NhCountExpression) != - null && countExpression.Expression.NodeType == (ExpressionType)NhExpressionType.Star) + if (where != null && + expression.QueryModel.SelectClause.Selector is NhCountExpression countExpression && + countExpression.Expression is NhStarExpression) { //return it as a CASE [column] WHEN [predicate] THEN 1 ELSE NULL END return - countExpression.CreateNew(Expression.Condition(where.Predicate, Expression.Constant(1, typeof(int?)), + countExpression.CreateNew( + Expression.Condition( + where.Predicate, + Expression.Constant(1, typeof(int?)), Expression.Constant(null, typeof(int?)))); - } } @@ -144,9 +146,9 @@ protected override Expression VisitSubQueryExpression(SubQueryExpression express { foreach (var bodyClause in expression.QueryModel.BodyClauses) { - bodyClause.TransformExpressions((e) => new KeySelectorVisitor(_groupBy).VisitExpression(e)); + bodyClause.TransformExpressions((e) => new KeySelectorVisitor(_groupBy).Visit(e)); } - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } diff --git a/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs b/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs index c550f9aa0e6..b2ecae95243 100644 --- a/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs +++ b/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs @@ -12,7 +12,7 @@ namespace NHibernate.Linq.GroupBy /// This class nominates sub-expression trees on the GroupBy Key expression /// for inclusion in the Select clause. /// - internal class GroupKeyNominator : ExpressionTreeVisitor + internal class GroupKeyNominator : RelinqExpressionVisitor { private GroupKeyNominator() { } @@ -27,13 +27,13 @@ public static Expression Visit(GroupResultOperator groupBy) private static Expression VisitInternal(Expression expr) { - return new GroupKeyNominator().VisitExpression(expr); + return new GroupKeyNominator().Visit(expr); } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { _depth++; - var expr = base.VisitExpression(expression); + var expr = base.Visit(expression); _depth--; // At the root expression, wrap it in the nominator expression if needed @@ -44,45 +44,45 @@ public override Expression VisitExpression(Expression expression) return expr; } - protected override Expression VisitNewArrayExpression(NewArrayExpression expression) + protected override Expression VisitNewArray(NewArrayExpression expression) { _transformed = true; // Transform each initializer recursively (to allow for nested initializers) return Expression.NewArrayInit(expression.Type.GetElementType(), expression.Expressions.Select(VisitInternal)); } - protected override Expression VisitNewExpression(NewExpression expression) + protected override Expression VisitNew(NewExpression expression) { _transformed = true; // Transform each initializer recursively (to allow for nested initializers) return Expression.New(expression.Constructor, expression.Arguments.Select(VisitInternal), expression.Members); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { // If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated _requiresRootNomination = true; - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { // If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated _requiresRootNomination = true; - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { if (expression.NodeType != ExpressionType.ArrayIndex) - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); // If we encounter an array index then we need to attempt to flatten it before nomination - var flattenedExpression = new ArrayIndexExpressionFlattener().VisitExpression(expression); + var flattenedExpression = new ArrayIndexExpressionFlattener().Visit(expression); if (flattenedExpression != expression) - return base.VisitExpression(flattenedExpression); + return base.Visit(flattenedExpression); - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs b/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs index 97478d29fbf..7ce24eb898e 100644 --- a/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs @@ -8,7 +8,7 @@ namespace NHibernate.Linq.GroupBy /// /// Detects if an expression tree contains naked QuerySourceReferenceExpression /// - internal class IsNonAggregatingGroupByDetectionVisitor : NhExpressionTreeVisitor + internal class IsNonAggregatingGroupByDetectionVisitor : NhExpressionVisitor { private bool _containsNakedQuerySourceReferenceExpression; @@ -16,27 +16,27 @@ public bool IsNonAggregatingGroupBy(Expression expression) { _containsNakedQuerySourceReferenceExpression = false; - VisitExpression(expression); + Visit(expression); return _containsNakedQuerySourceReferenceExpression; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { return expression.IsGroupingKey() ? expression - : base.VisitMemberExpression(expression); + : base.VisitMember(expression); } - protected override Expression VisitNhAggregate(NhAggregatedExpression expression) + protected internal override Expression VisitNhAggregated(NhAggregatedExpression expression) { return expression; } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { _containsNakedQuerySourceReferenceExpression = true; return expression; } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs b/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs index 7de5756ae72..db9c4d098bf 100644 --- a/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs +++ b/src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs @@ -4,7 +4,7 @@ namespace NHibernate.Linq.GroupBy { - internal class KeySelectorVisitor : ExpressionTreeVisitor + internal class KeySelectorVisitor : RelinqExpressionVisitor { private readonly GroupResultOperator _groupBy; @@ -13,13 +13,13 @@ public KeySelectorVisitor(GroupResultOperator groupBy) _groupBy = groupBy; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { if (expression.IsGroupingKeyOf(_groupBy)) { return _groupBy.KeySelector; } - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs index 120d332a7f4..90bc4c1884c 100644 --- a/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs @@ -3,7 +3,7 @@ using NHibernate.Linq.ResultOperators; using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; @@ -71,7 +71,7 @@ private static ClientSideSelect CreateClientSideSelect(Expression expression, Qu var mapping = new QuerySourceMapping(); mapping.AddMapping(queryModel.MainFromClause, parameter); - var body = ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(queryModel.SelectClause.Selector, mapping, false); + var body = ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(queryModel.SelectClause.Selector, mapping, false); var lambda = Expression.Lambda(body, parameter); diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index 27f0bf8805f..78ee78c35eb 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -8,7 +8,7 @@ namespace NHibernate.Linq.GroupJoin { - internal class GroupJoinAggregateDetectionVisitor : NhExpressionTreeVisitor + internal class GroupJoinAggregateDetectionVisitor : NhExpressionVisitor { private readonly HashSet _groupJoinClauses; private readonly StackFlag _inAggregate = new StackFlag(); @@ -27,26 +27,26 @@ public static IsAggregatingResults Visit(IEnumerable groupJoinC { var visitor = new GroupJoinAggregateDetectionVisitor(groupJoinClause); - visitor.VisitExpression(selectExpression); + visitor.Visit(selectExpression); return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - VisitExpression(expression.QueryModel.SelectClause.Selector); + Visit(expression.QueryModel.SelectClause.Selector); return expression; } - protected override Expression VisitNhAggregate(NhAggregatedExpression expression) + protected internal override Expression VisitNhAggregated(NhAggregatedExpression expression) { using (_inAggregate.SetFlag()) { - return base.VisitNhAggregate(expression); + return base.VisitNhAggregated(expression); } } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { if (_inAggregate.FlagIsFalse && _parentExpressionProcessed.FlagIsFalse) { @@ -55,11 +55,11 @@ protected override Expression VisitMemberExpression(MemberExpression expression) using (_parentExpressionProcessed.SetFlag()) { - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { var fromClause = (FromClauseBase) expression.ReferencedQuerySource; @@ -80,7 +80,7 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef } } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } internal class StackFlag @@ -113,4 +113,4 @@ public void Dispose() } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs index ad7ac347467..546e7cd514e 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs @@ -7,13 +7,13 @@ namespace NHibernate.Linq.GroupJoin { - public class GroupJoinSelectClauseRewriter : ExpressionTreeVisitor + public class GroupJoinSelectClauseRewriter : RelinqExpressionVisitor { private readonly IsAggregatingResults _results; public static Expression ReWrite(Expression expression, IsAggregatingResults results) { - return new GroupJoinSelectClauseRewriter(results).VisitExpression(expression); + return new GroupJoinSelectClauseRewriter(results).Visit(expression); } private GroupJoinSelectClauseRewriter(IsAggregatingResults results) @@ -21,7 +21,7 @@ private GroupJoinSelectClauseRewriter(IsAggregatingResults results) _results = results; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { // If the sub query's main (and only) from clause is one of our aggregating group bys, then swap it GroupJoinClause groupJoin = LocateGroupJoinQuerySource(expression.QueryModel); diff --git a/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs b/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs index dc55aae5812..c61b5f7a466 100644 --- a/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs +++ b/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs @@ -5,7 +5,7 @@ namespace NHibernate.Linq.GroupJoin { - public class LocateGroupJoinQuerySource : ExpressionTreeVisitor + public class LocateGroupJoinQuerySource : RelinqExpressionVisitor { private readonly IsAggregatingResults _results; private GroupJoinClause _groupJoin; @@ -17,11 +17,11 @@ public LocateGroupJoinQuerySource(IsAggregatingResults results) public GroupJoinClause Detect(Expression expression) { - VisitExpression(expression); + Visit(expression); return _groupJoin; } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { var groupJoinClause = expression.ReferencedQuerySource as GroupJoinClause; if (groupJoinClause != null && _results.AggregatingClauses.Contains(groupJoinClause)) @@ -29,7 +29,7 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef _groupJoin = groupJoinClause; } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index 73989092dcb..94da56879db 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -154,7 +154,7 @@ private IsAggregatingResults GetGroupJoinInformation(IEnumerable ToFutureValue(this IQuerya throw new NotSupportedException($"Source {nameof(source.Provider)} must be a {nameof(INhQueryProvider)}"); } - var expression = ReplacingExpressionTreeVisitor + var expression = ReplacingExpressionVisitor .Replace(selector.Parameters.Single(), source.Expression, selector.Body); return provider.ExecuteFutureValue(expression); diff --git a/src/NHibernate/Linq/LinqLogging.cs b/src/NHibernate/Linq/LinqLogging.cs index 942fc007340..7770875e378 100644 --- a/src/NHibernate/Linq/LinqLogging.cs +++ b/src/NHibernate/Linq/LinqLogging.cs @@ -21,8 +21,8 @@ internal static void LogExpression(string msg, Expression expression) // generated by a class internal to System.Linq.Expression, so we cannot // actually override that logic. Circumvent it by replacing such ConstantExpressions // with ParameterExpression, having their name set to the string we wish to display. - var visitor = new ProxyReplacingExpressionTreeVisitor(); - var preparedExpression = visitor.VisitExpression(expression); + var visitor = new ProxyReplacingExpressionVisitor(); + var preparedExpression = visitor.Visit(expression); Log.DebugFormat("{0}: {1}", msg, preparedExpression.ToString()); } @@ -34,17 +34,17 @@ internal static void LogExpression(string msg, Expression expression) /// proxy with a ParameterExpression. The name of the parameter will be a string /// representing the proxied entity, without initializing it. /// - private class ProxyReplacingExpressionTreeVisitor : NhExpressionTreeVisitor + private class ProxyReplacingExpressionVisitor : NhExpressionVisitor { - // See also e.g. Remotion.Linq.Clauses.ExpressionTreeVisitors.FormattingExpressionTreeVisitor + // See also e.g. Remotion.Linq.Clauses.ExpressionVisitors.FormattingExpressionTreeVisitor // for another example of this technique. - protected override Expression VisitConstantExpression(ConstantExpression expression) + protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value.IsProxy()) return Expression.Parameter(expression.Type, ObjectHelpers.IdentityToString(expression.Value)); - return base.VisitConstantExpression(expression); + return base.VisitConstant(expression); } } } diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs index 0d6639a6ef5..e8877673edf 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectDetector.cs @@ -7,7 +7,7 @@ namespace NHibernate.Linq.NestedSelects { - internal class NestedSelectDetector : ExpressionTreeVisitor + internal class NestedSelectDetector : RelinqExpressionVisitor { private readonly ISessionFactory sessionFactory; private readonly ICollection _expressions = new List(); @@ -27,14 +27,14 @@ public bool HasSubqueries get { return Expressions.Count > 0; } } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { if (expression.QueryModel.ResultOperators.Count == 0) Expressions.Add(expression); - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { var memberType = ReflectHelper.GetPropertyOrFieldType(expression.Member); @@ -45,7 +45,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) Expressions.Add(expression); } - return base.VisitMemberExpression(expression); + return base.VisitMember(expression); } private bool IsMappedCollection(MemberInfo memberInfo) diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs index 37ab62de54b..4e4ee3220fc 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs @@ -29,7 +29,7 @@ static class NestedSelectRewriter public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory) { var nsqmv = new NestedSelectDetector(sessionFactory); - nsqmv.VisitExpression(queryModel.SelectClause.Selector); + nsqmv.Visit(queryModel.SelectClause.Selector); if (!nsqmv.HasSubqueries) return; @@ -52,7 +52,7 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory var rewriter = new SelectClauseRewriter(key, expressions, identifier, replacements); - var resultSelector = rewriter.VisitExpression(queryModel.SelectClause.Selector); + var resultSelector = rewriter.Visit(queryModel.SelectClause.Selector); elementExpression.AddRange(expressions); @@ -150,7 +150,7 @@ private static LambdaExpression MakeSelector(ICollection eleme var rewriter = new SelectClauseRewriter(parameter, elementExpression, identifier, 1, new Dictionary()); - var selectorBody = rewriter.VisitExpression(@select); + var selectorBody = rewriter.Visit(@select); return Expression.Lambda(selectorBody, parameter); } diff --git a/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs b/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs index 78337496bf0..aa1a869c611 100644 --- a/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs @@ -5,7 +5,7 @@ namespace NHibernate.Linq.NestedSelects { - class SelectClauseRewriter : ExpressionTreeVisitor + class SelectClauseRewriter : RelinqExpressionVisitor { private readonly Dictionary _dictionary; @@ -27,7 +27,7 @@ public SelectClauseRewriter(Expression parameter, ICollection _dictionary = dictionary; } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) return null; @@ -35,15 +35,15 @@ public override Expression VisitExpression(Expression expression) if (_dictionary.TryGetValue(expression, out replacement)) return replacement; - return base.VisitExpression(expression); + return base.Visit(expression); } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { return AddAndConvertExpression(expression); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { return AddAndConvertExpression(expression); } diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index 4de1dd3d9d1..78babe38c1a 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -10,7 +10,7 @@ using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.StreamedData; using Remotion.Linq.EagerFetching.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors.Transformation; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; using Remotion.Linq.Parsing.Structure; using Remotion.Linq.Parsing.Structure.ExpressionTreeProcessors; using Remotion.Linq.Parsing.Structure.IntermediateModel; @@ -55,7 +55,7 @@ static NhRelinqQueryParser() /// The transformed expression. public static Expression PreTransform(Expression expression) { - var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionTreeVisitor.EvaluateIndependentSubtrees(expression); + var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression); return PreProcessor.Process(partiallyEvaluatedExpression); } @@ -140,9 +140,8 @@ public override Expression Resolve(ParameterExpression inputParameter, Expressio return Source.Resolve(inputParameter, expressionToBeResolved, clauseGenerationContext); } - protected override QueryModel ApplyNodeSpecificSemantics(QueryModel queryModel, ClauseGenerationContext clauseGenerationContext) + protected override void ApplyNodeSpecificSemantics(QueryModel queryModel, ClauseGenerationContext clauseGenerationContext) { - return queryModel; } } @@ -253,4 +252,4 @@ public override void TransformExpressions(Func transform { } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 0dc17e068b1..576bb65a9ea 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -1,5 +1,7 @@ +using System; using System.Linq; using NHibernate.Engine; +using NHibernate.Linq.Clauses; using NHibernate.Linq.Visitors; using Remotion.Linq; using Remotion.Linq.Clauses; @@ -12,7 +14,7 @@ internal interface IIsEntityDecider bool IsIdentifier(System.Type type, string propertyName); } - public class AddJoinsReWriter : QueryModelVisitorBase, IIsEntityDecider + public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider { private readonly ISessionFactoryImplementor _sessionFactory; private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; @@ -52,6 +54,11 @@ public override void VisitWhereClause(WhereClause whereClause, QueryModel queryM _whereJoinDetector.Transform(whereClause); } + public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + _whereJoinDetector.Transform(havingClause); + } + public bool IsEntity(System.Type type) { return _sessionFactory.GetImplementors(type.FullName).Any(); diff --git a/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs index 0b739b078e3..ed4476b72af 100644 --- a/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/ArrayIndexExpressionFlattener.cs @@ -5,17 +5,17 @@ namespace NHibernate.Linq.ReWriters { - public class ArrayIndexExpressionFlattener : ExpressionTreeVisitor + public class ArrayIndexExpressionFlattener : RelinqExpressionVisitor { public static void ReWrite(QueryModel model) { var visitor = new ArrayIndexExpressionFlattener(); - model.TransformExpressions(visitor.VisitExpression); + model.TransformExpressions(visitor.Visit); } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { - var visitedExpression = base.VisitBinaryExpression(expression); + var visitedExpression = base.VisitBinary(expression); if (visitedExpression.NodeType != ExpressionType.ArrayIndex) return visitedExpression; @@ -28,10 +28,10 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) if (expressionList == null || expressionList.NodeType != ExpressionType.NewArrayInit) return visitedExpression; - return VisitExpression(expressionList.Expressions[(int)index.Value]); + return Visit(expressionList.Expressions[(int)index.Value]); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { ReWrite(expression.QueryModel); return expression; // Note that we modifiy the (mutable) QueryModel, we return an unchanged expression diff --git a/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs b/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs index 2e2b3c52405..907671dfc28 100644 --- a/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs @@ -8,11 +8,11 @@ using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.ReWriters { - public class MergeAggregatingResultsRewriter : QueryModelVisitorBase + public class MergeAggregatingResultsRewriter : NhQueryModelVisitorBase { private MergeAggregatingResultsRewriter() { @@ -85,7 +85,7 @@ private static Expression TransformCountExpression(Expression expression) { if (expression.NodeType == ExpressionType.MemberInit || expression.NodeType == ExpressionType.New || - expression.NodeType == QuerySourceReferenceExpression.ExpressionType) + expression is QuerySourceReferenceExpression) { //Probably it should be done by CountResultOperatorProcessor return new NhStarExpression(expression); @@ -95,7 +95,7 @@ private static Expression TransformCountExpression(Expression expression) } } - internal class MergeAggregatingResultsInExpressionRewriter : ExpressionTreeVisitor + internal class MergeAggregatingResultsInExpressionRewriter : RelinqExpressionVisitor { private readonly NameGenerator _nameGenerator; @@ -108,16 +108,16 @@ public static Expression Rewrite(Expression expression, NameGenerator nameGenera { var visitor = new MergeAggregatingResultsInExpressionRewriter(nameGenerator); - return visitor.VisitExpression(expression); + return visitor.Visit(expression); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { MergeAggregatingResultsRewriter.ReWrite(expression.QueryModel); return expression; } - protected override Expression VisitMethodCallExpression(MethodCallExpression m) + protected override Expression VisitMethodCall(MethodCallExpression m) { if (m.Method.DeclaringType == typeof(Queryable) || m.Method.DeclaringType == typeof(Enumerable)) @@ -152,18 +152,20 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression m) } } - return base.VisitMethodCallExpression(m); + return base.VisitMethodCall(m); } private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func aggregateFactory, Func resultOperatorFactory) { var fromClause = new MainFromClause(_nameGenerator.GetNewName(), body.Parameters[0].Type, fromClauseExpression); var selectClause = body.Body; - selectClause = ReplacingExpressionTreeVisitor.Replace(body.Parameters[0], - new QuerySourceReferenceExpression( - fromClause), selectClause); - var queryModel = new QueryModel(fromClause, - new SelectClause(aggregateFactory(selectClause))); + selectClause = ReplacingExpressionVisitor.Replace( + body.Parameters[0], + new QuerySourceReferenceExpression(fromClause), + selectClause); + var queryModel = new QueryModel( + fromClause, + new SelectClause(aggregateFactory(selectClause))); // TODO - this sucks, but we use it to get the Type of the SubQueryExpression correct queryModel.ResultOperators.Add(resultOperatorFactory()); diff --git a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs index 412f1c8fabe..8024560e321 100644 --- a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs @@ -9,7 +9,7 @@ namespace NHibernate.Linq.ReWriters { - public class QueryReferenceExpressionFlattener : ExpressionTreeVisitor + public class QueryReferenceExpressionFlattener : RelinqExpressionVisitor { private readonly QueryModel _model; @@ -29,16 +29,16 @@ private QueryReferenceExpressionFlattener(QueryModel model) public static void ReWrite(QueryModel model) { var visitor = new QueryReferenceExpressionFlattener(model); - model.TransformExpressions(visitor.VisitExpression); + model.TransformExpressions(visitor.Visit); } - protected override Expression VisitSubQueryExpression(SubQueryExpression subQuery) + protected override Expression VisitSubQuery(SubQueryExpression subQuery) { var subQueryModel = subQuery.QueryModel; var hasBodyClauses = subQueryModel.BodyClauses.Count > 0; if (hasBodyClauses) { - return base.VisitSubQueryExpression(subQuery); + return base.VisitSubQuery(subQuery); } var resultOperators = subQueryModel.ResultOperators; @@ -57,7 +57,7 @@ protected override Expression VisitSubQueryExpression(SubQueryExpression subQuer } } - return base.VisitSubQueryExpression(subQuery); + return base.VisitSubQuery(subQuery); } private static bool HasJustAllFlattenableOperator(IEnumerable resultOperators) @@ -65,7 +65,7 @@ private static bool HasJustAllFlattenableOperator(IEnumerable FlattenableResultOperators.Contains(x.GetType())); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { var fromClauseBase = expression.ReferencedQuerySource as FromClauseBase; @@ -76,7 +76,7 @@ fromClauseBase.FromExpression is QuerySourceReferenceExpression && return fromClauseBase.FromExpression; } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs b/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs index fc42a14eeff..eed32ba581d 100644 --- a/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs +++ b/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using NHibernate.Linq.Visitors; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.ResultOperators; @@ -7,7 +8,7 @@ namespace NHibernate.Linq.ReWriters { - public class RemoveUnnecessaryBodyOperators : QueryModelVisitorBase + public class RemoveUnnecessaryBodyOperators : NhQueryModelVisitorBase { private RemoveUnnecessaryBodyOperators() {} diff --git a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs index fe1e50f7c40..581c3d8f2c3 100644 --- a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs @@ -1,3 +1,4 @@ +using NHibernate.Linq.Visitors; using Remotion.Linq.Parsing; namespace NHibernate.Linq.ReWriters @@ -19,7 +20,7 @@ namespace NHibernate.Linq.ReWriters /// Removes various result operators from a query so that they can be processed at the same /// tree level as the query itself. /// - public class ResultOperatorRewriter : QueryModelVisitorBase + public class ResultOperatorRewriter : NhQueryModelVisitorBase { private readonly List resultOperators = new List(); private IStreamedDataInfo evaluationType; @@ -59,7 +60,7 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q /// /// Rewrites expressions so that they sit in the outermost portion of the query. /// - private class ResultOperatorExpressionRewriter : ExpressionTreeVisitor + private class ResultOperatorExpressionRewriter : RelinqExpressionVisitor { private static readonly System.Type[] rewrittenTypes = new[] { @@ -91,10 +92,10 @@ public IStreamedDataInfo EvaluationType public Expression Rewrite(Expression expression) { - return VisitExpression(expression); + return Visit(expression); } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { resultOperators.AddRange( expression.QueryModel.ResultOperators @@ -108,7 +109,7 @@ protected override Expression VisitSubQueryExpression(SubQueryExpression express return expression.QueryModel.MainFromClause.FromExpression; } - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } } } diff --git a/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs b/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs index 13b36cb2ef6..aad5a031c7f 100644 --- a/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs +++ b/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs @@ -24,7 +24,7 @@ public HqlBooleanExpression Visit(Expression innerKeySelector, Expression outerK var outerNewExpression = outerKeySelector as NewExpression; return innerNewExpression != null && outerNewExpression != null ? VisitNew(innerNewExpression, outerNewExpression) - : GenerateEqualityNode(innerKeySelector, outerKeySelector, new HqlGeneratorExpressionTreeVisitor(_parameters)); + : GenerateEqualityNode(innerKeySelector, outerKeySelector, new HqlGeneratorExpressionVisitor(_parameters)); } private HqlBooleanExpression VisitNew(NewExpression innerKeySelector, NewExpression outerKeySelector) @@ -46,7 +46,7 @@ private HqlBooleanExpression VisitNew(NewExpression innerKeySelector, NewExpress private HqlEquality GenerateEqualityNode(NewExpression innerKeySelector, NewExpression outerKeySelector, int index) { - return GenerateEqualityNode(innerKeySelector.Arguments[index], outerKeySelector.Arguments[index], new HqlGeneratorExpressionTreeVisitor(_parameters)); + return GenerateEqualityNode(innerKeySelector.Arguments[index], outerKeySelector.Arguments[index], new HqlGeneratorExpressionVisitor(_parameters)); } private HqlEquality GenerateEqualityNode(Expression leftExpr, Expression rightExpr, IHqlExpressionVisitor visitor) diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index ab4618feb71..95a40181595 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -17,7 +18,7 @@ namespace NHibernate.Linq.Visitors /// generate the same key as /// from c in Customers where c.City = "Madrid" /// - public class ExpressionKeyVisitor : ExpressionTreeVisitor + public class ExpressionKeyVisitor : RelinqExpressionVisitor { private readonly IDictionary _constantToParameterMap; readonly StringBuilder _string = new StringBuilder(); @@ -31,7 +32,7 @@ public static string Visit(Expression expression, IDictionary(T expression) where T : Expression { - VisitExpression(expression); + Visit(expression); _string.Append(", "); return expression; } - protected override Expression VisitLambdaExpression(LambdaExpression expression) + protected override Expression VisitLambda(Expression expression) { _string.Append('('); - VisitList(expression.Parameters, AppendCommas); + Visit(expression.Parameters, AppendCommas); _string.Append(") => ("); - VisitExpression(expression.Body); + Visit(expression.Body); _string.Append(')'); return expression; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { - base.VisitMemberExpression(expression); + base.VisitMember(expression); _string.Append('.'); _string.Append(expression.Member.Name); @@ -156,7 +157,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) } private bool insideSelectClause; - protected override Expression VisitMethodCallExpression(MethodCallExpression expression) + protected override Expression VisitMethodCall(MethodCallExpression expression) { var old = insideSelectClause; @@ -175,39 +176,39 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp break; } - VisitExpression(expression.Object); + Visit(expression.Object); _string.Append('.'); VisitMethod(expression.Method); _string.Append('('); - VisitList(expression.Arguments, AppendCommas); + ExpressionVisitor.Visit(expression.Arguments, AppendCommas); _string.Append(')'); insideSelectClause = old; return expression; } - protected override Expression VisitNewExpression(NewExpression expression) + protected override Expression VisitNew(NewExpression expression) { _string.Append("new "); _string.Append(expression.Constructor.DeclaringType.Name); _string.Append('('); - VisitList(expression.Arguments, AppendCommas); + Visit(expression.Arguments, AppendCommas); _string.Append(')'); return expression; } - protected override Expression VisitParameterExpression(ParameterExpression expression) + protected override Expression VisitParameter(ParameterExpression expression) { _string.Append(expression.Name); return expression; } - protected override Expression VisitTypeBinaryExpression(TypeBinaryExpression expression) + protected override Expression VisitTypeBinary(TypeBinaryExpression expression) { _string.Append("IsType("); - VisitExpression(expression.Expression); + Visit(expression.Expression); _string.Append(", "); _string.Append(expression.TypeOperand.FullName); _string.Append(")"); @@ -215,17 +216,17 @@ protected override Expression VisitTypeBinaryExpression(TypeBinaryExpression exp return expression; } - protected override Expression VisitUnaryExpression(UnaryExpression expression) + protected override Expression VisitUnary(UnaryExpression expression) { _string.Append(expression.NodeType); _string.Append('('); - VisitExpression(expression.Operand); + Visit(expression.Operand); _string.Append(')'); return expression; } - protected override Expression VisitQuerySourceReferenceExpression(Remotion.Linq.Clauses.Expressions.QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(Remotion.Linq.Clauses.Expressions.QuerySourceReferenceExpression expression) { _string.Append(expression.ReferencedQuerySource.ItemName); return expression; diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index cc3662fb149..11e52cdb23e 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -14,7 +14,7 @@ namespace NHibernate.Linq.Visitors /// /// Locates constants in the expression tree and generates parameters for each one /// - public class ExpressionParameterVisitor : ExpressionTreeVisitor + public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); private readonly ISessionFactoryImplementor _sessionFactory; @@ -48,16 +48,16 @@ internal static IDictionary Visit(ref Expres { var visitor = new ExpressionParameterVisitor(sessionFactory); - expression = visitor.VisitExpression(expression); + expression = visitor.Visit(expression); return visitor._parameters; } - protected override Expression VisitMethodCallExpression(MethodCallExpression expression) + protected override Expression VisitMethodCall(MethodCallExpression expression) { if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) { - var rawParameter = VisitExpression(expression.Arguments[0]); + var rawParameter = Visit(expression.Arguments[0]); var parameter = rawParameter as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) @@ -81,7 +81,7 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp if (_pagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) { //TODO: find a way to make this code cleaner - var query = VisitExpression(expression.Arguments[0]); + var query = Visit(expression.Arguments[0]); var arg = expression.Arguments[1]; if (query == expression.Arguments[0]) @@ -95,10 +95,10 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp return expression; } - return base.VisitMethodCallExpression(expression); + return base.VisitMethodCall(expression); } - protected override Expression VisitConstantExpression(ConstantExpression expression) + protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) { @@ -125,7 +125,7 @@ protected override Expression VisitConstantExpression(ConstantExpression express _parameters.Add(expression, new NamedParameter("p" + (_parameters.Count + 1), value, type)); } - return base.VisitConstantExpression(expression); + return base.VisitConstant(expression); } private static bool IsNullObject(ConstantExpression expression) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs similarity index 89% rename from src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs rename to src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index b009496ec56..cb4ec1cca66 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -12,7 +12,7 @@ namespace NHibernate.Linq.Visitors { - public class HqlGeneratorExpressionTreeVisitor : IHqlExpressionVisitor + public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor { private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private readonly VisitorParameters _parameters; @@ -20,10 +20,10 @@ public class HqlGeneratorExpressionTreeVisitor : IHqlExpressionVisitor public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { - return new HqlGeneratorExpressionTreeVisitor(parameters).VisitExpression(expression); + return new HqlGeneratorExpressionVisitor(parameters).VisitExpression(expression); } - public HqlGeneratorExpressionTreeVisitor(VisitorParameters parameters) + public HqlGeneratorExpressionVisitor(VisitorParameters parameters) { _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; _parameters = parameters; @@ -93,7 +93,7 @@ protected HqlTreeNode VisitExpression(Expression expression) case ExpressionType.Call: return VisitMethodCallExpression((MethodCallExpression) expression); //case ExpressionType.New: - // return VisitNewExpression((NewExpression)expression); + // return VisitNew((NewExpression)expression); //case ExpressionType.NewArrayBounds: case ExpressionType.NewArrayInit: return VisitNewArrayExpression((NewArrayExpression) expression); @@ -106,45 +106,43 @@ protected HqlTreeNode VisitExpression(Expression expression) case ExpressionType.TypeIs: return VisitTypeBinaryExpression((TypeBinaryExpression) expression); - default: - var subQueryExpression = expression as SubQueryExpression; - if (subQueryExpression != null) - return VisitSubQueryExpression(subQueryExpression); - - var querySourceReferenceExpression = expression as QuerySourceReferenceExpression; - if (querySourceReferenceExpression != null) - return VisitQuerySourceReferenceExpression(querySourceReferenceExpression); - - var vbStringComparisonExpression = expression as VBStringComparisonExpression; - if (vbStringComparisonExpression != null) - return VisitVBStringComparisonExpression(vbStringComparisonExpression); - - switch ((NhExpressionType) expression.NodeType) + case ExpressionType.Extension: + switch (expression) { - case NhExpressionType.Average: - return VisitNhAverage((NhAverageExpression) expression); - case NhExpressionType.Min: - return VisitNhMin((NhMinExpression) expression); - case NhExpressionType.Max: - return VisitNhMax((NhMaxExpression) expression); - case NhExpressionType.Sum: - return VisitNhSum((NhSumExpression) expression); - case NhExpressionType.Count: - return VisitNhCount((NhCountExpression) expression); - case NhExpressionType.Distinct: - return VisitNhDistinct((NhDistinctExpression) expression); - case NhExpressionType.Star: - return VisitNhStar((NhStarExpression) expression); - case NhExpressionType.Nominator: - return VisitExpression(((NhNominatedExpression) expression).Expression); - //case NhExpressionType.New: - // return VisitNhNew((NhNewExpression)expression); + case SubQueryExpression subQueryExpression: + return VisitSubQueryExpression(subQueryExpression); + case QuerySourceReferenceExpression querySourceReferenceExpression: + return VisitQuerySourceReferenceExpression(querySourceReferenceExpression); + case VBStringComparisonExpression vbStringComparisonExpression: + return VisitVBStringComparisonExpression(vbStringComparisonExpression); + case NhAverageExpression nhAverageExpression: + return VisitNhAverage(nhAverageExpression); + case NhMinExpression nhMinExpression: + return VisitNhMin(nhMinExpression); + case NhMaxExpression nhMaxExpression: + return VisitNhMax(nhMaxExpression); + case NhSumExpression nhSumExpression: + return VisitNhSum(nhSumExpression); + case NhCountExpression nhCountExpression: + return VisitNhCount(nhCountExpression); + case NhDistinctExpression nhDistinctExpression: + return VisitNhDistinct(nhDistinctExpression); + case NhStarExpression nhStarExpression: + return VisitNhStar(nhStarExpression); + case NhNominatedExpression nhNominatedExpression: + return VisitNhNominated(nhNominatedExpression); + //case NhNewExpression nhNewExpression: + // return VisitNhNew(nhNewExpression); + default: + throw new NotSupportedException(expression.ToString()); } - throw new NotSupportedException(expression.ToString()); + default: + throw new NotSupportedException(expression.NodeType.ToString()); } } + private HqlTreeNode VisitTypeBinaryExpression(TypeBinaryExpression expression) { return BuildOfType(expression.Expression, expression.TypeOperand); @@ -202,6 +200,11 @@ protected HqlTreeNode VisitNhStar(NhStarExpression expression) return _hqlTreeBuilder.Star(); } + private HqlTreeNode VisitNhNominated(NhNominatedExpression nhNominatedExpression) + { + return VisitExpression(nhNominatedExpression.Expression); + } + private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { return VisitExpression(expression.Expression); @@ -238,7 +241,7 @@ protected HqlTreeNode VisitNhSum(NhSumExpression expression) protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression) { - var visitor = new HqlGeneratorExpressionTreeVisitor(_parameters); + var visitor = new HqlGeneratorExpressionVisitor(_parameters); return _hqlTreeBuilder.ExpressionSubTreeHolder(_hqlTreeBuilder.Distinct(), visitor.VisitExpression(expression.Expression)); } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index f41ccea5022..b84a72c7ac3 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -72,21 +72,21 @@ public bool CanAddJoin(Expression expression) return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase); } - private class QuerySourceExtractor : ExpressionTreeVisitor + private class QuerySourceExtractor : RelinqExpressionVisitor { private IQuerySource _querySource; public static IQuerySource GetQuerySource(Expression expression) { var sourceExtractor = new QuerySourceExtractor(); - sourceExtractor.VisitExpression(expression); + sourceExtractor.Visit(expression); return sourceExtractor._querySource; } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { _querySource = expression.ReferencedQuerySource; - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } } } diff --git a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs index 83a2facc0b5..34d8b1ab6c6 100644 --- a/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs +++ b/src/NHibernate/Linq/Visitors/LeftJoinRewriter.cs @@ -1,15 +1,16 @@ using System.Collections.Generic; using System.Linq; using NHibernate.Linq.Clauses; +using NHibernate.Linq.ReWriters; using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; namespace NHibernate.Linq.Visitors { - public class LeftJoinRewriter : QueryModelVisitorBase + public class LeftJoinRewriter : NhQueryModelVisitorBase { public static void ReWrite(QueryModel queryModel) { @@ -40,7 +41,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, var innerSelectorMapping = new QuerySourceMapping(); innerSelectorMapping.AddMapping(fromClause, subQueryModel.SelectClause.Selector); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); queryModel.BodyClauses.RemoveAt(index); queryModel.BodyClauses.Insert(index, @join); @@ -49,7 +50,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, var innerBodyClauseMapping = new QuerySourceMapping(); innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(@join)); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); } private static void InsertBodyClauses(IEnumerable bodyClauses, QueryModel destinationQueryModel, int destinationIndex) diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 537b02c674d..624a9aae8ad 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -14,7 +14,7 @@ namespace NHibernate.Linq.Visitors /// Replaces them with appropriate joins, maintaining reference equality between different clauses. /// This allows extracted GroupBy key expression to also be replaced so that they can continue to match replaced Select expressions /// - internal class MemberExpressionJoinDetector : ExpressionTreeVisitor + internal class MemberExpressionJoinDetector : RelinqExpressionVisitor { private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; @@ -30,7 +30,7 @@ public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner jo _joiner = joiner; } - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { var isIdentifier = _isEntityDecider.IsIdentifier(expression.Expression.Type, expression.Member.Name); if (isIdentifier) @@ -38,7 +38,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) if (!isIdentifier) _memberExpressionDepth++; - var result = base.VisitMemberExpression(expression); + var result = base.VisitMember(expression); if (!isIdentifier) _memberExpressionDepth--; @@ -55,33 +55,33 @@ protected override Expression VisitMemberExpression(MemberExpression expression) return result; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(VisitExpression); + expression.QueryModel.TransformExpressions(Visit); return expression; } - protected override Expression VisitConditionalExpression(ConditionalExpression expression) + protected override Expression VisitConditional(ConditionalExpression expression) { var oldRequiresJoinForNonIdentifier = _requiresJoinForNonIdentifier; _requiresJoinForNonIdentifier = !_preventJoinsInConditionalTest && _requiresJoinForNonIdentifier; - var newTest = VisitExpression(expression.Test); + var newTest = Visit(expression.Test); _requiresJoinForNonIdentifier = oldRequiresJoinForNonIdentifier; - var newFalse = VisitExpression(expression.IfFalse); - var newTrue = VisitExpression(expression.IfTrue); + var newFalse = Visit(expression.IfFalse); + var newTrue = Visit(expression.IfTrue); if ((newTest != expression.Test) || (newFalse != expression.IfFalse) || (newTrue != expression.IfTrue)) return Expression.Condition(newTest, newTrue, newFalse); return expression; } - protected override Expression VisitExtensionExpression(ExtensionExpression expression) + protected override Expression VisitExtension(Expression expression) { // Nominated expressions need to prevent joins on non-Identifier member expressions // (for the test expression of conditional expressions only) // Otherwise an extra join is created and the GroupBy and Select clauses will not match var old = _preventJoinsInConditionalTest; - _preventJoinsInConditionalTest = (NhExpressionType)expression.NodeType == NhExpressionType.Nominator; - var expr = base.VisitExtensionExpression(expression); + _preventJoinsInConditionalTest = expression is NhNominatedExpression; + var expr = base.VisitExtension(expression); _preventJoinsInConditionalTest = old; return expr; } @@ -90,18 +90,18 @@ public void Transform(SelectClause selectClause) { // The select clause typically requires joins for non-Identifier member access _requiresJoinForNonIdentifier = true; - selectClause.TransformExpressions(VisitExpression); + selectClause.TransformExpressions(Visit); _requiresJoinForNonIdentifier = false; } public void Transform(ResultOperatorBase resultOperator) { - resultOperator.TransformExpressions(VisitExpression); + resultOperator.TransformExpressions(Visit); } public void Transform(Ordering ordering) { - ordering.TransformExpressions(VisitExpression); + ordering.TransformExpressions(Visit); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs b/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs deleted file mode 100644 index 3fb10519975..00000000000 --- a/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs +++ /dev/null @@ -1,98 +0,0 @@ -using System; -using System.Linq.Expressions; -using NHibernate.Linq.Expressions; -using Remotion.Linq.Parsing; - -namespace NHibernate.Linq.Visitors -{ - public class NhExpressionTreeVisitor : ExpressionTreeVisitor - { - public override Expression VisitExpression(Expression expression) - { - if (expression == null) - { - return null; - } - - switch ((NhExpressionType)expression.NodeType) - { - case NhExpressionType.Average: - case NhExpressionType.Min: - case NhExpressionType.Max: - case NhExpressionType.Sum: - case NhExpressionType.Count: - case NhExpressionType.Distinct: - return VisitNhAggregate((NhAggregatedExpression)expression); - case NhExpressionType.New: - return VisitNhNew((NhNewExpression)expression); - case NhExpressionType.Star: - return VisitNhStar((NhStarExpression)expression); - } - - // Keep this variable for easy examination during debug. - var expr = base.VisitExpression(expression); - return expr; - } - - protected virtual Expression VisitNhStar(NhStarExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhNew(NhNewExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhAggregate(NhAggregatedExpression expression) - { - switch ((NhExpressionType)expression.NodeType) - { - case NhExpressionType.Average: - return VisitNhAverage((NhAverageExpression)expression); - case NhExpressionType.Min: - return VisitNhMin((NhMinExpression)expression); - case NhExpressionType.Max: - return VisitNhMax((NhMaxExpression)expression); - case NhExpressionType.Sum: - return VisitNhSum((NhSumExpression)expression); - case NhExpressionType.Count: - return VisitNhCount((NhCountExpression)expression); - case NhExpressionType.Distinct: - return VisitNhDistinct((NhDistinctExpression)expression); - default: - throw new ArgumentException(); - } - } - - protected virtual Expression VisitNhDistinct(NhDistinctExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhCount(NhCountExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhSum(NhSumExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhMax(NhMaxExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhMin(NhMinExpression expression) - { - return expression.Accept(this); - } - - protected virtual Expression VisitNhAverage(NhAverageExpression expression) - { - return expression.Accept(this); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs new file mode 100644 index 00000000000..705dbb38121 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NhExpressionVisitor.cs @@ -0,0 +1,60 @@ +using System; +using System.Linq.Expressions; +using NHibernate.Linq.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + public class NhExpressionVisitor : RelinqExpressionVisitor + { + protected internal virtual Expression VisitNhStar(NhStarExpression expression) + { + return VisitExtension(expression); + } + + protected internal virtual Expression VisitNhNew(NhNewExpression expression) + { + return VisitExtension(expression); + } + + protected internal virtual Expression VisitNhAggregated(NhAggregatedExpression node) + { + return VisitExtension(node); + } + + protected internal virtual Expression VisitNhDistinct(NhDistinctExpression expression) + { + return VisitNhAggregated(expression); + } + + protected internal virtual Expression VisitNhCount(NhCountExpression expression) + { + return VisitNhAggregated(expression); + } + + protected internal virtual Expression VisitNhSum(NhSumExpression expression) + { + return VisitNhAggregated(expression); + } + + protected internal virtual Expression VisitNhMax(NhMaxExpression expression) + { + return VisitNhAggregated(expression); + } + + protected internal virtual Expression VisitNhMin(NhMinExpression expression) + { + return VisitNhAggregated(expression); + } + + protected internal virtual Expression VisitNhAverage(NhAverageExpression expression) + { + return VisitNhAggregated(expression); + } + + protected internal virtual Expression VisitNhNominated(NhNominatedExpression expression) + { + return VisitExtension(expression); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs deleted file mode 100644 index 462051d75b9..00000000000 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionTreeVisitor.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; - -namespace NHibernate.Linq.Visitors -{ - internal class NhPartialEvaluatingExpressionTreeVisitor : ExpressionTreeVisitor, IPartialEvaluationExceptionExpressionVisitor - { - protected override Expression VisitConstantExpression(ConstantExpression expression) - { - var value = expression.Value as Expression; - if (value == null) - { - return base.VisitConstantExpression(expression); - } - - return EvaluateIndependentSubtrees(value); - } - - public static Expression EvaluateIndependentSubtrees(Expression expression) - { - var evaluatedExpression = PartialEvaluatingExpressionTreeVisitor.EvaluateIndependentSubtrees(expression); - return new NhPartialEvaluatingExpressionTreeVisitor().VisitExpression(evaluatedExpression); - } - - public Expression VisitPartialEvaluationExceptionExpression(PartialEvaluationExceptionExpression expression) - { - return VisitExpression(expression.Reduce()); - } - } -} \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs new file mode 100644 index 00000000000..dec366d8dae --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -0,0 +1,37 @@ +using System.Linq.Expressions; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; +using Remotion.Linq.Parsing.ExpressionVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation; + +namespace NHibernate.Linq.Visitors +{ + internal class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor, IPartialEvaluationExceptionExpressionVisitor + { + protected override Expression VisitConstant(ConstantExpression expression) + { + var value = expression.Value as Expression; + if (value == null) + { + return base.VisitConstant(expression); + } + + return EvaluateIndependentSubtrees(value); + } + + public static Expression EvaluateIndependentSubtrees(Expression expression) + { + var evaluatedExpression = PartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression, new NullEvaluatableExpressionFilter()); + return new NhPartialEvaluatingExpressionVisitor().Visit(evaluatedExpression); + } + + public Expression VisitPartialEvaluationException(PartialEvaluationExceptionExpression expression) + { + return Visit(expression.Reduce()); + } + } + + internal class NullEvaluatableExpressionFilter : EvaluatableExpressionFilterBase + { + } +} diff --git a/src/NHibernate/Linq/Visitors/NhQueryModelVisitorBase.cs b/src/NHibernate/Linq/Visitors/NhQueryModelVisitorBase.cs new file mode 100644 index 00000000000..26e2046a6f3 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NhQueryModelVisitorBase.cs @@ -0,0 +1,20 @@ +using NHibernate.Linq.Clauses; +using Remotion.Linq; + +namespace NHibernate.Linq.Visitors +{ + public class NhQueryModelVisitorBase : QueryModelVisitorBase, INhQueryModelVisitor + { + public virtual void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + } + + public virtual void VisitNhJoinClause(NhJoinClause joinClause, QueryModel queryModel, int index) + { + } + + public virtual void VisitNhWithClause(NhWithClause nhWhereClause, QueryModel queryModel, int index) + { + } + } +} diff --git a/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs index 87c0f90359f..5e996b1f311 100644 --- a/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs @@ -2,11 +2,11 @@ using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors { - internal class PagingRewriterSelectClauseVisitor : ExpressionTreeVisitor + internal class PagingRewriterSelectClauseVisitor : RelinqExpressionVisitor { private readonly FromClauseBase querySource; @@ -17,18 +17,18 @@ public PagingRewriterSelectClauseVisitor(FromClauseBase querySource) public Expression Swap(Expression expression) { - return TransparentIdentifierRemovingExpressionTreeVisitor.ReplaceTransparentIdentifiers(VisitExpression(expression)); + return TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers(Visit(expression)); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { var innerSelector = GetSubQuerySelectorOrNull(expression); if (innerSelector != null) { - return VisitExpression(innerSelector); + return Visit(innerSelector); } - return base.VisitQuerySourceReferenceExpression(expression); + return base.VisitQuerySourceReference(expression); } /// diff --git a/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs b/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs index d0dcd2a0668..87003308674 100644 --- a/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs +++ b/src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs @@ -4,7 +4,7 @@ namespace NHibernate.Linq.Visitors { - public class QueryExpressionSourceIdentifer : ExpressionTreeVisitor + public class QueryExpressionSourceIdentifer : RelinqExpressionVisitor { private readonly QuerySourceIdentifier _identifier; @@ -13,10 +13,10 @@ public QueryExpressionSourceIdentifer(QuerySourceIdentifier identifier) _identifier = identifier; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { _identifier.VisitQueryModel(expression.QueryModel); - return base.VisitSubQueryExpression(expression); + return base.VisitSubQuery(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 3e4a1c863cf..aebe8f1ed83 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -20,7 +20,7 @@ namespace NHibernate.Linq.Visitors { - public class QueryModelVisitor : QueryModelVisitorBase + public class QueryModelVisitor : NhQueryModelVisitorBase, INhQueryModelVisitor { public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root, NhLinqExpressionReturnType? rootReturnType) @@ -152,23 +152,23 @@ private void AddAdditionalPostExecuteTransformer() if (_rootReturnType == NhLinqExpressionReturnType.Scalar && Model.ResultTypeOverride != null) { // NH-3850: handle polymorphic scalar results aggregation - switch ((NhExpressionType)Model.SelectClause.Selector.NodeType) + switch (Model.SelectClause.Selector) { - case NhExpressionType.Average: + case NhAverageExpression _: // Polymorphic case complex to handle and not implemented. (HQL query must be reshaped for adding // additional data to allow a meaningful overall average computation.) // Leaving it untouched for allowing non polymorphic cases to work. break; - case NhExpressionType.Count: + case NhCountExpression _: AddPostExecuteTransformerForCount(); break; - case NhExpressionType.Max: + case NhMaxExpression _: AddPostExecuteTransformerForResultAggregate(ReflectionCache.EnumerableMethods.MaxDefinition); break; - case NhExpressionType.Min: + case NhMinExpression _: AddPostExecuteTransformerForResultAggregate(ReflectionCache.EnumerableMethods.MinDefinition); break; - case NhExpressionType.Sum: + case NhSumExpression _: AddPostExecuteTransformerForSum(); break; } @@ -281,7 +281,7 @@ private MethodCallExpression GetAggregateMethodCall(MethodInfo aggregateMethodTe public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - var hqlExpressionTree = HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters); + var hqlExpressionTree = HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters); _hqlTree.AddFromClause(_hqlTree.TreeBuilder.Range(hqlExpressionTree, _hqlTree.TreeBuilder.Alias(querySourceName))); @@ -302,17 +302,12 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, { var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - var joinClause = fromClause as NhJoinClause; - if (joinClause != null) - { - VisitNhJoinClause(querySourceName, joinClause); - } - else if (fromClause.FromExpression is MemberExpression) + if (fromClause.FromExpression is MemberExpression) { // It's a join _hqlTree.AddFromClause( _hqlTree.TreeBuilder.Join( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), + HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), _hqlTree.TreeBuilder.Alias(querySourceName))); } else @@ -320,7 +315,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, // TODO - exact same code as in MainFromClause; refactor this out _hqlTree.AddFromClause( _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), + HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters), _hqlTree.TreeBuilder.Alias(querySourceName))); } @@ -328,24 +323,26 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, base.VisitAdditionalFromClause(fromClause, queryModel, index); } - private void VisitNhJoinClause(string querySourceName, NhJoinClause joinClause) + public override void VisitNhJoinClause(NhJoinClause joinClause, QueryModel queryModel, int index) { - var expression = HqlGeneratorExpressionTreeVisitor.Visit(joinClause.FromExpression, VisitorParameters).AsExpression(); + var querySourceName = VisitorParameters.QuerySourceNamer.GetName(joinClause); + + var expression = HqlGeneratorExpressionVisitor.Visit(joinClause.FromExpression, VisitorParameters).AsExpression(); var alias = _hqlTree.TreeBuilder.Alias(querySourceName); HqlTreeNode hqlJoin; if (joinClause.IsInner) { - hqlJoin = _hqlTree.TreeBuilder.Join(expression, @alias); + hqlJoin = _hqlTree.TreeBuilder.Join(expression, alias); } else { - hqlJoin = _hqlTree.TreeBuilder.LeftJoin(expression, @alias); + hqlJoin = _hqlTree.TreeBuilder.LeftJoin(expression, alias); } foreach (var withClause in joinClause.Restrictions) { - var booleanExpression = HqlGeneratorExpressionTreeVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); + var booleanExpression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); hqlJoin.AddChild(_hqlTree.TreeBuilder.With(booleanExpression)); } @@ -378,7 +375,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters); - visitor.Visit(selectClause.Selector); + visitor.VisitSelector(selectClause.Selector); if (visitor.ProjectionExpression != null) { @@ -393,25 +390,18 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { var visitor = new SimplifyConditionalVisitor(); - whereClause.Predicate = visitor.VisitExpression(whereClause.Predicate); + whereClause.Predicate = visitor.Visit(whereClause.Predicate); // Visit the predicate to build the query - var expression = HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); - if (whereClause is NhHavingClause) - { - _hqlTree.AddHavingClause(expression); - } - else - { - _hqlTree.AddWhereClause(expression); - } + var expression = HqlGeneratorExpressionVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); + _hqlTree.AddWhereClause(expression); } public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel queryModel, int index) { foreach (var clause in orderByClause.Orderings) { - _hqlTree.AddOrderByClause(HqlGeneratorExpressionTreeVisitor.Visit(clause.Expression, VisitorParameters).AsExpression(), + _hqlTree.AddOrderByClause(HqlGeneratorExpressionVisitor.Visit(clause.Expression, VisitorParameters).AsExpression(), clause.OrderingDirection == OrderingDirection.Asc ? _hqlTree.TreeBuilder.Ascending() : (HqlDirectionStatement)_hqlTree.TreeBuilder.Descending()); @@ -427,7 +417,7 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode _hqlTree.AddFromClause( _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(joinClause.InnerSequence, VisitorParameters), + HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters), _hqlTree.TreeBuilder.Alias(joinClause.ItemName))); } @@ -435,5 +425,25 @@ public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, Query { throw new NotImplementedException(); } + + public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + var visitor = new SimplifyConditionalVisitor(); + havingClause.Predicate = visitor.Visit(havingClause.Predicate); + + // Visit the predicate to build the query + var expression = HqlGeneratorExpressionVisitor.Visit(havingClause.Predicate, VisitorParameters).ToBooleanExpression(); + _hqlTree.AddHavingClause(expression); + } + + public override void VisitNhWithClause(NhWithClause withClause, QueryModel queryModel, int index) + { + var visitor = new SimplifyConditionalVisitor(); + withClause.Predicate = visitor.Visit(withClause.Predicate); + + // Visit the predicate to build the query + var expression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); + _hqlTree.AddWhereClause(expression); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs b/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs index 72a06822512..475a13050c8 100644 --- a/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs +++ b/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs @@ -1,3 +1,4 @@ +using NHibernate.Linq.Clauses; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.ResultOperators; @@ -13,7 +14,7 @@ namespace NHibernate.Linq.Visitors /// the HQL expression tree) means a query source may be referenced by a QuerySourceReference /// before it has been identified - and named. /// - public class QuerySourceIdentifier : QueryModelVisitorBase + public class QuerySourceIdentifier : NhQueryModelVisitorBase { private readonly QuerySourceNamer _namer; @@ -52,6 +53,11 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode _namer.Add(joinClause); } + public override void VisitNhJoinClause(NhJoinClause joinClause, QueryModel queryModel, int index) + { + _namer.Add(joinClause); + } + public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) { var groupBy = resultOperator as GroupResultOperator; @@ -62,7 +68,7 @@ public override void VisitResultOperator(ResultOperatorBase resultOperator, Quer public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { //Find nested query sources - new QueryExpressionSourceIdentifer(this).VisitExpression(selectClause.Selector); + new QueryExpressionSourceIdentifer(this).Visit(selectClause.Selector); } public QuerySourceNamer Namer { get { return _namer; } } diff --git a/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs b/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs index c6a80ddd380..54ad4e4c24a 100644 --- a/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs +++ b/src/NHibernate/Linq/Visitors/QuerySourceLocator.cs @@ -1,27 +1,27 @@ -using Remotion.Linq; +using NHibernate.Linq.Clauses; +using Remotion.Linq; using Remotion.Linq.Clauses; -using Remotion.Linq.Collections; namespace NHibernate.Linq.Visitors { - public class QuerySourceLocator : QueryModelVisitorBase - { - private readonly System.Type _type; - private IQuerySource _querySource; + public class QuerySourceLocator : NhQueryModelVisitorBase + { + readonly System.Type _type; + IQuerySource _querySource; - private QuerySourceLocator(System.Type type) - { - _type = type; - } + QuerySourceLocator(System.Type type) + { + _type = type; + } - public static IQuerySource FindQuerySource(QueryModel queryModel, System.Type type) - { - var finder = new QuerySourceLocator(type); + public static IQuerySource FindQuerySource(QueryModel queryModel, System.Type type) + { + var finder = new QuerySourceLocator(type); - finder.VisitQueryModel(queryModel); + finder.VisitQueryModel(queryModel); - return finder._querySource; - } + return finder._querySource; + } public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { @@ -37,16 +37,30 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, base.VisitAdditionalFromClause(fromClause, queryModel, index); } - public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) - { - if (_type.IsAssignableFrom(fromClause.ItemType)) - { - _querySource = fromClause; - } - else - { - base.VisitMainFromClause(fromClause, queryModel); - } - } - } -} \ No newline at end of file + public override void VisitNhJoinClause(NhJoinClause joinClause, QueryModel queryModel, int index) + { + if (_type.IsAssignableFrom(joinClause.ItemType)) + { + if (_querySource == null) + { + _querySource = joinClause; + return; + } + } + + base.VisitNhJoinClause(joinClause, queryModel, index); + } + + public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) + { + if (_type.IsAssignableFrom(fromClause.ItemType)) + { + _querySource = fromClause; + } + else + { + base.VisitMainFromClause(fromClause, queryModel); + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs index 7a2b5a819e3..574533820b7 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs @@ -3,7 +3,7 @@ using NHibernate.Util; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -15,7 +15,7 @@ public void Process(AggregateResultOperator resultOperator, QueryModelVisitor qu var inputType = inputExpr.Type; var paramExpr = Expression.Parameter(inputType, "item"); var accumulatorFunc = Expression.Lambda( - ReplacingExpressionTreeVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), + ReplacingExpressionVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), resultOperator.Func.Parameters[0], paramExpr); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs index 4e55fe6382c..d369066a575 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs @@ -3,7 +3,7 @@ using NHibernate.Util; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; -using Remotion.Linq.Parsing.ExpressionTreeVisitors; +using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -15,7 +15,7 @@ public void Process(AggregateFromSeedResultOperator resultOperator, QueryModelVi var inputType = inputExpr.Type; var paramExpr = Expression.Parameter(inputType, "item"); var accumulatorFunc = Expression.Lambda( - ReplacingExpressionTreeVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), + ReplacingExpressionVisitor.Replace(inputExpr, paramExpr, resultOperator.Func.Body), resultOperator.Func.Parameters[0], paramExpr); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs index 692365a5d89..b22237e6f58 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs @@ -12,7 +12,7 @@ public class ProcessAll : IResultOperatorProcessor public void Process(AllResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { tree.AddWhereClause(tree.TreeBuilder.BooleanNot( - HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.Predicate, queryModelVisitor.VisitorParameters). + HqlGeneratorExpressionVisitor.Visit(resultOperator.Predicate, queryModelVisitor.VisitorParameters). ToBooleanExpression())); if (tree.IsRoot) diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs index 01928d78999..17fc7850425 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs @@ -10,7 +10,7 @@ public class ProcessContains : IResultOperatorProcessor public void Process(ContainsResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { var itemExpression = - HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.Item, queryModelVisitor.VisitorParameters) + HqlGeneratorExpressionVisitor.Visit(resultOperator.Item, queryModelVisitor.VisitorParameters) .AsExpression(); var from = GetFromRangeClause(tree.Root); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs index 0c4e19db231..2ee37b5387b 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs @@ -16,7 +16,7 @@ public void Process(GroupResultOperator resultOperator, QueryModelVisitor queryM else groupByKeys = new[] {resultOperator.KeySelector}; - IEnumerable hqlGroupByKeys = groupByKeys.Select(k => HqlGeneratorExpressionTreeVisitor.Visit(k, queryModelVisitor.VisitorParameters).AsExpression()); + IEnumerable hqlGroupByKeys = groupByKeys.Select(k => HqlGeneratorExpressionVisitor.Visit(k, queryModelVisitor.VisitorParameters).AsExpression()); tree.AddGroupByClause(tree.TreeBuilder.GroupBy(hqlGroupByKeys.ToArray())); } diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs index abaf20e1471..21c82a87eb0 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs @@ -3,7 +3,7 @@ using System.Linq.Expressions; using NHibernate.Linq.ResultOperators; using NHibernate.Util; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -22,9 +22,9 @@ public void Process(NonAggregatingGroupBy resultOperator, QueryModelVisitor quer // Stuff in the group by that doesn't map to HQL. Run it client-side var listParameter = Expression.Parameter(typeof(IEnumerable), "list"); - var keySelectorExpr = ReverseResolvingExpressionTreeVisitor.ReverseResolve(selector, keySelector); + var keySelectorExpr = ReverseResolvingExpressionVisitor.ReverseResolve(selector, keySelector); - var elementSelectorExpr = ReverseResolvingExpressionTreeVisitor.ReverseResolve(selector, elementSelector); + var elementSelectorExpr = ReverseResolvingExpressionVisitor.ReverseResolve(selector, elementSelector); var groupByMethod = ReflectionCache.EnumerableMethods.GroupByWithElementSelectorDefinition .MakeGenericMethod(new[] { sourceType, keyType, elementType }); diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs index 4da0271fee2..f6137239277 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs @@ -8,7 +8,7 @@ public void Process(OfTypeResultOperator resultOperator, QueryModelVisitor query { var source = queryModelVisitor.Model.SelectClause.GetOutputDataInfo().ItemExpression; - var expression = new HqlGeneratorExpressionTreeVisitor(queryModelVisitor.VisitorParameters) + var expression = new HqlGeneratorExpressionVisitor(queryModelVisitor.VisitorParameters) .BuildOfType(source, resultOperator.SearchedItemType); tree.AddWhereClause(expression); diff --git a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs index 4dae4f2aa9e..2db121b9316 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs @@ -12,7 +12,7 @@ namespace NHibernate.Linq.Visitors /// Analyze the select clause to determine what parts can be translated /// fully to HQL, and some other properties of the clause. /// - class SelectClauseHqlNominator : ExpressionTreeVisitor + class SelectClauseHqlNominator : RelinqExpressionVisitor { private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; @@ -37,7 +37,7 @@ public SelectClauseHqlNominator(VisitorParameters parameters) _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; } - internal Expression Visit(Expression expression) + internal Expression Nominate(Expression expression) { HqlCandidates = new HashSet(); ContainsUntranslatedMethodCalls = false; @@ -45,18 +45,18 @@ internal Expression Visit(Expression expression) _stateStack = new Stack(); _stateStack.Push(false); - return VisitExpression(expression); + return Visit(expression); } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) return null; - if (expression.NodeType == (ExpressionType)NhExpressionType.Nominator) + if (expression is NhNominatedExpression nominatedExpression) { // Add the nominated clause and strip the nominator wrapper from the select expression - var innerExpression = ((NhNominatedExpression)expression).Expression; + var innerExpression = nominatedExpression.Expression; HqlCandidates.Add(innerExpression); return innerExpression; } @@ -86,7 +86,7 @@ public override Expression VisitExpression(Expression expression) return expression; } - expression = base.VisitExpression(expression); + expression = base.Visit(expression); if (_canBeCandidate) { @@ -118,14 +118,14 @@ private bool IsRegisteredFunction(Expression expression) if (_functionRegistry.TryGetGenerator(methodCallExpression.Method, out methodGenerator)) { return methodCallExpression.Object == null || // is static or extension method - methodCallExpression.Object.NodeType != ExpressionType.Constant; // does not belong to parameter + methodCallExpression.Object.NodeType != ExpressionType.Constant; // does not belong to parameter } } - else if (expression.NodeType == (ExpressionType)NhExpressionType.Sum || - expression.NodeType == (ExpressionType)NhExpressionType.Count || - expression.NodeType == (ExpressionType)NhExpressionType.Average || - expression.NodeType == (ExpressionType)NhExpressionType.Max || - expression.NodeType == (ExpressionType)NhExpressionType.Min) + else if (expression is NhSumExpression || + expression is NhCountExpression || + expression is NhAverageExpression || + expression is NhMaxExpression || + expression is NhMinExpression) { return true; } @@ -172,7 +172,7 @@ private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool proj private static bool CanBeEvaluatedInHqlStatementShortcut(Expression expression) { - return ((NhExpressionType)expression.NodeType) == NhExpressionType.Count; + return expression is NhCountExpression; } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs index 4b2254e53e3..ca1318260d2 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs @@ -8,7 +8,7 @@ namespace NHibernate.Linq.Visitors { - public class SelectClauseVisitor : ExpressionTreeVisitor + public class SelectClauseVisitor : RelinqExpressionVisitor { private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private HashSet _hqlNodes; @@ -16,13 +16,13 @@ public class SelectClauseVisitor : ExpressionTreeVisitor private readonly VisitorParameters _parameters; private int _iColumn; private List _hqlTreeNodes = new List(); - private readonly HqlGeneratorExpressionTreeVisitor _hqlVisitor; + private readonly HqlGeneratorExpressionVisitor _hqlVisitor; public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters) { _inputParameter = Expression.Parameter(inputType, "input"); _parameters = parameters; - _hqlVisitor = new HqlGeneratorExpressionTreeVisitor(_parameters); + _hqlVisitor = new HqlGeneratorExpressionVisitor(_parameters); } public LambdaExpression ProjectionExpression { get; private set; } @@ -32,7 +32,7 @@ public IEnumerable GetHqlNodes() return _hqlTreeNodes; } - public void Visit(Expression expression) + public void VisitSelector(Expression expression) { var distinct = expression as NhDistinctExpression; if (distinct != null) @@ -42,7 +42,7 @@ public void Visit(Expression expression) // Find the sub trees that can be expressed purely in HQL var nominator = new SelectClauseHqlNominator(_parameters); - expression = nominator.Visit(expression); + expression = nominator.Nominate(expression); _hqlNodes = nominator.HqlCandidates; // Linq2SQL ignores calls to local methods. Linq2EF seems to not support @@ -53,7 +53,7 @@ public void Visit(Expression expression) throw new NotSupportedException("Cannot use distinct on result that depends on methods for which no SQL equivalent exist."); // Now visit the tree - var projection = VisitExpression(expression); + var projection = Visit(expression); if ((projection != expression) && !_hqlNodes.Contains(expression)) { @@ -71,7 +71,7 @@ public void Visit(Expression expression) } } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) { @@ -87,7 +87,7 @@ public override Expression VisitExpression(Expression expression) } // Can't handle this node with HQL. Just recurse down, and emit the expression - return base.VisitExpression(expression); + return base.Visit(expression); } } diff --git a/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs b/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs index 3263fb4d235..20047942e40 100644 --- a/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SimplifyConditionalVisitor.cs @@ -7,26 +7,26 @@ namespace NHibernate.Linq.Visitors /// /// Some conditional expressions can be reduced to just their IfTrue or IfFalse part. /// - internal class SimplifyConditionalVisitor :ExpressionTreeVisitor + internal class SimplifyConditionalVisitor :RelinqExpressionVisitor { - protected override Expression VisitConditionalExpression(ConditionalExpression expression) + protected override Expression VisitConditional(ConditionalExpression expression) { - var testExpression = VisitExpression(expression.Test); + var testExpression = Visit(expression.Test); bool testExprResult; if (VisitorUtil.IsBooleanConstant(testExpression, out testExprResult)) { if (testExprResult) - return VisitExpression(expression.IfTrue); + return Visit(expression.IfTrue); - return VisitExpression(expression.IfFalse); + return Visit(expression.IfFalse); } - return base.VisitConditionalExpression(expression); + return base.VisitConditional(expression); } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { // See NH-3423. Conditional expression where the test expression is a comparison // of a construction expression and null will happen in WCF DS. @@ -42,7 +42,7 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) return Expression.Constant(true); } - return base.VisitBinaryExpression(expression); + return base.VisitBinary(expression); } diff --git a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs index 56eace76c69..d8d1f2ca58d 100644 --- a/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs +++ b/src/NHibernate/Linq/Visitors/SubQueryFromClauseFlattener.cs @@ -3,12 +3,12 @@ using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Clauses.ExpressionTreeVisitors; +using Remotion.Linq.Clauses.ExpressionVisitors; using Remotion.Linq.EagerFetching; namespace NHibernate.Linq.Visitors { - public class SubQueryFromClauseFlattener : QueryModelVisitorBase + public class SubQueryFromClauseFlattener : NhQueryModelVisitorBase { private static readonly System.Type[] FlattenableResultOperators = { @@ -70,14 +70,14 @@ private static void FlattenSubQuery(SubQueryExpression subQueryExpression, FromC var innerSelectorMapping = new QuerySourceMapping(); innerSelectorMapping.AddMapping(fromClause, subQueryExpression.QueryModel.SelectClause.Selector); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerSelectorMapping, false)); InsertBodyClauses(subQueryExpression.QueryModel.BodyClauses, queryModel, destinationIndex); InsertResultOperators(subQueryExpression.QueryModel.ResultOperators, queryModel); var innerBodyClauseMapping = new QuerySourceMapping(); innerBodyClauseMapping.AddMapping(mainFromClause, new QuerySourceReferenceExpression(fromClause)); - queryModel.TransformExpressions(ex => ReferenceReplacingExpressionTreeVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); + queryModel.TransformExpressions(ex => ReferenceReplacingExpressionVisitor.ReplaceClauseReferences(ex, innerBodyClauseMapping, false)); } internal static void InsertResultOperators(IEnumerable resultOperators, QueryModel queryModel) diff --git a/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs b/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs index 70181397bc0..97ae577e202 100644 --- a/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs @@ -5,7 +5,7 @@ namespace NHibernate.Linq.Visitors { - public class SwapQuerySourceVisitor : ExpressionTreeVisitor + public class SwapQuerySourceVisitor : RelinqExpressionVisitor { private readonly IQuerySource _oldClause; private readonly IQuerySource _newClause; @@ -18,10 +18,10 @@ public SwapQuerySourceVisitor(IQuerySource oldClause, IQuerySource newClause) public Expression Swap(Expression expression) { - return VisitExpression(expression); + return Visit(expression); } - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { if (expression.ReferencedQuerySource == _oldClause) { @@ -33,16 +33,16 @@ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceRef if (mainFromClause != null) { - mainFromClause.FromExpression = VisitExpression(mainFromClause.FromExpression); + mainFromClause.FromExpression = Visit(mainFromClause.FromExpression); } return expression; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(VisitExpression); - return base.VisitSubQueryExpression(expression); + expression.QueryModel.TransformExpressions(Visit); + return base.VisitSubQuery(expression); } } } \ No newline at end of file diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 5b559b2eb32..b9d531f6d95 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -38,10 +38,10 @@ public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Express targetObject = member.Expression; while (metaData == null && targetObject != null && (targetObject.NodeType == ExpressionType.MemberAccess || targetObject.NodeType == ExpressionType.Parameter || - targetObject.NodeType == QuerySourceReferenceExpression.ExpressionType)) + targetObject is QuerySourceReferenceExpression)) { System.Type memberType; - if (targetObject.NodeType == QuerySourceReferenceExpression.ExpressionType) + if (targetObject is QuerySourceReferenceExpression) { var querySourceExpression = (QuerySourceReferenceExpression) targetObject; memberType = querySourceExpression.Type; diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index dcb5e57c617..e86212d8e0f 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using NHibernate.Linq.Clauses; using NHibernate.Linq.ReWriters; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -34,7 +35,7 @@ namespace NHibernate.Linq.Visitors /// a.B.C == 1 && a.D.E == 1 can be inner joined. /// a.B.C == 1 || a.D.E == 1 must be outer joined. /// - /// By default we outer join via the code in VisitExpression. The use of inner joins is only + /// By default we outer join via the code in Visit. The use of inner joins is only /// an optimization hint to the database. /// /// More examples: @@ -56,7 +57,7 @@ namespace NHibernate.Linq.Visitors /// /// The code here is based on the excellent work started by Harald Mueller. /// - internal class WhereJoinDetector : ExpressionTreeVisitor + internal class WhereJoinDetector : RelinqExpressionVisitor { // TODO: There are a number of types of expressions that we didn't handle here due to time constraints. For example, the ?: operator could be checked easily. private readonly IIsEntityDecider _isEntityDecider; @@ -76,9 +77,9 @@ internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) _joiner = joiner; } - public void Transform(WhereClause whereClause) + public void Transform(IClause whereClause) { - whereClause.TransformExpressions(VisitExpression); + whereClause.TransformExpressions(Visit); var values = _values.Pop(); @@ -92,7 +93,7 @@ public void Transform(WhereClause whereClause) } } - public override Expression VisitExpression(Expression expression) + public override Expression Visit(Expression expression) { if (expression == null) return null; @@ -104,7 +105,7 @@ public override Expression VisitExpression(Expression expression) _handled.Push(false); int originalCount = _values.Count; - Expression result = base.VisitExpression(expression); + Expression result = base.Visit(expression); if (!_handled.Pop()) { @@ -119,9 +120,9 @@ public override Expression VisitExpression(Expression expression) return result; } - protected override Expression VisitBinaryExpression(BinaryExpression expression) + protected override Expression VisitBinary(BinaryExpression expression) { - var result = base.VisitBinaryExpression(expression); + var result = base.VisitBinary(expression); if (expression.NodeType == ExpressionType.AndAlso) { @@ -239,9 +240,9 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression) return result; } - protected override Expression VisitUnaryExpression(UnaryExpression expression) + protected override Expression VisitUnary(UnaryExpression expression) { - Expression result = base.VisitUnaryExpression(expression); + Expression result = base.VisitUnary(expression); if (expression.NodeType == ExpressionType.Not && expression.Type == typeof(bool)) { @@ -271,22 +272,22 @@ protected override Expression VisitUnaryExpression(UnaryExpression expression) return result; } - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(VisitExpression); + expression.QueryModel.TransformExpressions(Visit); return expression; } // We would usually get NULL if one of our inner member expresions was null. // However, it's possible a method call will convert the null value from the failed join into a non-null value. // This could be optimized by actually checking what the method does. For example StartsWith("s") would leave null as null and would still allow us to inner join. - //protected override Expression VisitMethodCallExpression(MethodCallExpression expression) + //protected override Expression VisitMethodCall(MethodCallExpression expression) //{ - // Expression result = base.VisitMethodCallExpression(expression); + // Expression result = base.VisitMethodCall(expression); // return result; //} - protected override Expression VisitMemberExpression(MemberExpression expression) + protected override Expression VisitMember(MemberExpression expression) { // The member expression we're visiting might be on the end of a variety of things, such as: // a.B @@ -300,7 +301,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression) if (!isIdentifier) _memberExpressionDepth++; - var result = base.VisitMemberExpression(expression); + var result = base.VisitMember(expression); if (!isIdentifier) _memberExpressionDepth--; diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index d1053596580..a1d77c746a9 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -79,8 +79,11 @@ ..\packages\Iesi.Collections.4.0.1.4000\lib\net40\Iesi.Collections.dll True - - ..\packages\Remotion.Linq.1.15.15.0\lib\portable-net45+wp80+wpa81+win\Remotion.Linq.dll + + ..\packages\Remotion.Linq.2.1.2\lib\net45\Remotion.Linq.dll + + + ..\packages\Remotion.Linq.EagerFetching.2.1.0\lib\net45\Remotion.Linq.EagerFetching.dll @@ -146,6 +149,10 @@ + + + + @@ -303,6 +310,7 @@ + @@ -331,7 +339,7 @@ - + @@ -990,7 +998,6 @@ - @@ -1050,11 +1057,11 @@ - + - + @@ -1808,7 +1815,9 @@ - + + Designer + @@ -1836,4 +1845,4 @@ - \ No newline at end of file + diff --git a/src/NHibernate/NHibernate.nuspec.template b/src/NHibernate/NHibernate.nuspec.template index 7ed3e54f5d2..6a9f9f49dba 100644 --- a/src/NHibernate/NHibernate.nuspec.template +++ b/src/NHibernate/NHibernate.nuspec.template @@ -15,7 +15,8 @@ - + + http://nhibernate.info @@ -34,4 +35,4 @@ - \ No newline at end of file + diff --git a/src/NHibernate/packages.config b/src/NHibernate/packages.config index b3056b09770..71e81725de5 100644 --- a/src/NHibernate/packages.config +++ b/src/NHibernate/packages.config @@ -3,5 +3,6 @@ - + + \ No newline at end of file