Skip to content

Commit 1ba7f9b

Browse files
PleasantDfredericDelaporte
authored andcommitted
Fix regressions introduced by Conditional and Coalesce expansion (#1880) (#1916)
1 parent fa978cf commit 1ba7f9b

File tree

6 files changed

+230
-26
lines changed

6 files changed

+230
-26
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Collections.Generic;
12+
using System.Collections.ObjectModel;
13+
using System.Linq;
14+
using System.Linq.Expressions;
15+
using System.Reflection;
16+
using NHibernate.Cfg;
17+
using NHibernate.Hql.Ast;
18+
using NHibernate.Linq.Functions;
19+
using NHibernate.Linq.Visitors;
20+
using NHibernate.Util;
21+
using NUnit.Framework;
22+
using NHibernate.Linq;
23+
24+
namespace NHibernate.Test.NHSpecificTest.GH1879
25+
{
26+
using System.Threading.Tasks;
27+
[TestFixture]
28+
public class ExpansionRegressionTestsAsync : GH1879BaseFixtureAsync<Invoice>
29+
{
30+
protected override void OnSetUp()
31+
{
32+
using (var session = OpenSession())
33+
using (var transaction = session.BeginTransaction())
34+
{
35+
session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false });
36+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true });
37+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false });
38+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true });
39+
40+
session.Flush();
41+
transaction.Commit();
42+
}
43+
}
44+
45+
protected override void Configure(Configuration configuration)
46+
{
47+
configuration.LinqToHqlGeneratorsRegistry<TestLinqToHqlGeneratorsRegistry>();
48+
}
49+
50+
private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry
51+
{
52+
public TestLinqToHqlGeneratorsRegistry()
53+
{
54+
this.Merge(new ObjectEquality());
55+
}
56+
}
57+
58+
private class ObjectEquality : IHqlGeneratorForMethod
59+
{
60+
public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
61+
{
62+
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression());
63+
}
64+
65+
public IEnumerable<MethodInfo> SupportedMethods
66+
{
67+
get
68+
{
69+
yield return ReflectHelper.GetMethodDefinition<object>(x => x.Equals(x));
70+
}
71+
}
72+
}
73+
74+
[Test]
75+
public async Task MethodShouldNotExpandForNonConditionalOrCoalesceAsync()
76+
{
77+
using (var session = OpenSession())
78+
{
79+
Assert.That(await (session.Query<Invoice>().CountAsync(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110))), Is.EqualTo(2));
80+
}
81+
}
82+
83+
[Test]
84+
public async Task MethodShouldNotExpandForConditionalWithPropertyAccessorAsync()
85+
{
86+
using (var session = OpenSession())
87+
{
88+
Assert.That(await (session.Query<Invoice>().CountAsync(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10))), Is.EqualTo(2));
89+
}
90+
}
91+
92+
[Test]
93+
public async Task MethodShouldNotExpandForCoalesceWithPropertyAccessorAsync()
94+
{
95+
using (var session = OpenSession())
96+
{
97+
Assert.That(await (session.Query<Invoice>().CountAsync(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100))), Is.EqualTo(2));
98+
}
99+
}
100+
}
101+
}

src/NHibernate.Test/Async/NHSpecificTest/GH1879/FixtureByCode.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ protected override HbmMapping GetMappings()
6262
{
6363
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
6464
rc.Property(x => x.InvoiceNumber);
65+
rc.Property(x => x.Amount);
66+
rc.Property(x => x.SpecialAmount);
67+
rc.Property(x => x.Paid);
6568
rc.ManyToOne(x => x.Project, m => m.Column("ProjectId"));
6669
rc.ManyToOne(x => x.Issue, m => m.Column("IssueId"));
6770
});
@@ -122,9 +125,9 @@ protected async Task AreEqualAsync<TResult>(
122125
expectedResult = await (expectedQuery(session.Query<T>()).ToListAsync(cancellationToken));
123126
}
124127
catch (OperationCanceledException) { throw; }
125-
catch
128+
catch (Exception e)
126129
{
127-
Assert.Ignore("Not currently supported query");
130+
Assert.Ignore($"Not currently supported query: {e}");
128131
}
129132

130133
var testResult = await (actualQuery(session.Query<T>()).ToListAsync(cancellationToken));

src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,8 @@ public class Invoice
5858
public virtual int InvoiceNumber { get; set; }
5959
public virtual Project Project { get; set; }
6060
public virtual Issue Issue { get; set; }
61+
public virtual int Amount { get; set; }
62+
public virtual int? SpecialAmount { get; set; }
63+
public virtual bool Paid { get; set; }
6164
}
6265
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
using System.Collections.Generic;
2+
using System.Collections.ObjectModel;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Reflection;
6+
using NHibernate.Cfg;
7+
using NHibernate.Hql.Ast;
8+
using NHibernate.Linq.Functions;
9+
using NHibernate.Linq.Visitors;
10+
using NHibernate.Util;
11+
using NUnit.Framework;
12+
13+
namespace NHibernate.Test.NHSpecificTest.GH1879
14+
{
15+
[TestFixture]
16+
public class ExpansionRegressionTests : GH1879BaseFixture<Invoice>
17+
{
18+
protected override void OnSetUp()
19+
{
20+
using (var session = OpenSession())
21+
using (var transaction = session.BeginTransaction())
22+
{
23+
session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false });
24+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true });
25+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false });
26+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true });
27+
28+
session.Flush();
29+
transaction.Commit();
30+
}
31+
}
32+
33+
protected override void Configure(Configuration configuration)
34+
{
35+
configuration.LinqToHqlGeneratorsRegistry<TestLinqToHqlGeneratorsRegistry>();
36+
}
37+
38+
private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry
39+
{
40+
public TestLinqToHqlGeneratorsRegistry()
41+
{
42+
this.Merge(new ObjectEquality());
43+
}
44+
}
45+
46+
private class ObjectEquality : IHqlGeneratorForMethod
47+
{
48+
public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
49+
{
50+
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression());
51+
}
52+
53+
public IEnumerable<MethodInfo> SupportedMethods
54+
{
55+
get
56+
{
57+
yield return ReflectHelper.GetMethodDefinition<object>(x => x.Equals(x));
58+
}
59+
}
60+
}
61+
62+
[Test]
63+
public void MethodShouldNotExpandForNonConditionalOrCoalesce()
64+
{
65+
using (var session = OpenSession())
66+
{
67+
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110)), Is.EqualTo(2));
68+
}
69+
}
70+
71+
[Test]
72+
public void MethodShouldNotExpandForConditionalWithPropertyAccessor()
73+
{
74+
using (var session = OpenSession())
75+
{
76+
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10)), Is.EqualTo(2));
77+
}
78+
}
79+
80+
[Test]
81+
public void MethodShouldNotExpandForCoalesceWithPropertyAccessor()
82+
{
83+
using (var session = OpenSession())
84+
{
85+
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100)), Is.EqualTo(2));
86+
}
87+
}
88+
}
89+
}

src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ protected override HbmMapping GetMappings()
4949
{
5050
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
5151
rc.Property(x => x.InvoiceNumber);
52+
rc.Property(x => x.Amount);
53+
rc.Property(x => x.SpecialAmount);
54+
rc.Property(x => x.Paid);
5255
rc.ManyToOne(x => x.Project, m => m.Column("ProjectId"));
5356
rc.ManyToOne(x => x.Issue, m => m.Column("IssueId"));
5457
});
@@ -108,9 +111,9 @@ protected void AreEqual<TResult>(
108111
{
109112
expectedResult = expectedQuery(session.Query<T>()).ToList();
110113
}
111-
catch
114+
catch (Exception e)
112115
{
113-
Assert.Ignore("Not currently supported query");
116+
Assert.Ignore($"Not currently supported query: {e}");
114117
}
115118

116119
var testResult = actualQuery(session.Query<T>()).ToList();

src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public void Transform(ResultOperatorBase resultOperator)
7777
protected override Expression VisitMember(MemberExpression node)
7878
{
7979
var result = (MemberExpression) base.VisitMember(node);
80-
if (QueryReferenceCounter.CountReferences(result.Expression) > 1)
80+
if (ShouldRewrite(result.Expression))
8181
{
8282
return ConditionalQueryReferenceMemberExpressionRewriter.Rewrite(result.Expression, node);
8383
}
@@ -90,39 +90,44 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
9090
var isExtension = node.Method.GetCustomAttributes<ExtensionAttribute>().Any();
9191
var methodObject = isExtension ? node.Arguments[0] : node.Object;
9292

93-
if (methodObject != null && QueryReferenceCounter.CountReferences(methodObject) > 1)
93+
if (ShouldRewrite(methodObject))
9494
{
9595
return ConditionalQueryReferenceMethodCallExpressionRewriter.Rewrite(methodObject, node);
9696
}
9797
return result;
9898
}
99-
}
100-
101-
private class QueryReferenceCounter : RelinqExpressionVisitor
102-
{
103-
private readonly System.Type _queryType;
104-
private int _queryReferenceCount;
10599

106-
private QueryReferenceCounter(System.Type queryType)
100+
private bool ShouldRewrite(Expression expr, System.Type queryType = null)
107101
{
108-
_queryType = queryType;
109-
}
102+
if (expr == null)
103+
{
104+
return false;
105+
}
106+
107+
// Strip Converts
108+
while (expr.NodeType == ExpressionType.Convert || expr.NodeType == ExpressionType.ConvertChecked)
109+
{
110+
expr = ((UnaryExpression)expr).Operand;
111+
}
110112

111-
protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression)
112-
{
113-
if (_queryType.IsAssignableFrom(expression.Type))
113+
if (expr is QuerySourceReferenceExpression && queryType?.IsAssignableFrom(expr.Type) == true)
114114
{
115-
_queryReferenceCount++;
115+
return true;
116116
}
117117

118-
return base.VisitQuerySourceReference(expression);
119-
}
118+
queryType = queryType ?? expr.Type;
120119

121-
public static int CountReferences(Expression node)
122-
{
123-
var visitor = new QueryReferenceCounter(node.Type);
124-
visitor.Visit(node);
125-
return visitor._queryReferenceCount;
120+
if (expr.NodeType == ExpressionType.Coalesce && expr is BinaryExpression coalesce)
121+
{
122+
return ShouldRewrite(coalesce.Left, queryType) && ShouldRewrite(coalesce.Right, queryType);
123+
}
124+
125+
if (expr.NodeType == ExpressionType.Conditional && expr is ConditionalExpression conditional)
126+
{
127+
return ShouldRewrite(conditional.IfFalse, queryType) && ShouldRewrite(conditional.IfTrue, queryType);
128+
}
129+
130+
return false;
126131
}
127132
}
128133

0 commit comments

Comments
 (0)