Skip to content

Fix regressions introduced by Conditional and Coalesce expansion (#1880) #1916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by AsyncGenerator.
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------


using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Cfg;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
using NHibernate.Linq.Visitors;
using NHibernate.Util;
using NUnit.Framework;
using NHibernate.Linq;

namespace NHibernate.Test.NHSpecificTest.GH1879
{
using System.Threading.Tasks;
[TestFixture]
public class ExpansionRegressionTestsAsync : GH1879BaseFixtureAsync<Invoice>
{
protected override void OnSetUp()
{
using (var session = OpenSession())
using (var transaction = session.BeginTransaction())
{
session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false });
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true });
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false });
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true });

session.Flush();
transaction.Commit();
}
}

protected override void Configure(Configuration configuration)
{
configuration.LinqToHqlGeneratorsRegistry<TestLinqToHqlGeneratorsRegistry>();
}

private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry
{
public TestLinqToHqlGeneratorsRegistry()
{
this.Merge(new ObjectEquality());
}
}

private class ObjectEquality : IHqlGeneratorForMethod
{
public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression());
}

public IEnumerable<MethodInfo> SupportedMethods
{
get
{
yield return ReflectHelper.GetMethodDefinition<object>(x => x.Equals(x));
}
}
}

[Test]
public async Task MethodShouldNotExpandForNonConditionalOrCoalesceAsync()
{
using (var session = OpenSession())
{
Assert.That(await (session.Query<Invoice>().CountAsync(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110))), Is.EqualTo(2));
}
}

[Test]
public async Task MethodShouldNotExpandForConditionalWithPropertyAccessorAsync()
{
using (var session = OpenSession())
{
Assert.That(await (session.Query<Invoice>().CountAsync(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10))), Is.EqualTo(2));
}
}

[Test]
public async Task MethodShouldNotExpandForCoalesceWithPropertyAccessorAsync()
{
using (var session = OpenSession())
{
Assert.That(await (session.Query<Invoice>().CountAsync(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100))), Is.EqualTo(2));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ protected override HbmMapping GetMappings()
{
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
rc.Property(x => x.InvoiceNumber);
rc.Property(x => x.Amount);
rc.Property(x => x.SpecialAmount);
rc.Property(x => x.Paid);
rc.ManyToOne(x => x.Project, m => m.Column("ProjectId"));
rc.ManyToOne(x => x.Issue, m => m.Column("IssueId"));
});
Expand Down Expand Up @@ -122,9 +125,9 @@ protected async Task AreEqualAsync<TResult>(
expectedResult = await (expectedQuery(session.Query<T>()).ToListAsync(cancellationToken));
}
catch (OperationCanceledException) { throw; }
catch
catch (Exception e)
{
Assert.Ignore("Not currently supported query");
Assert.Ignore($"Not currently supported query: {e}");
}

var testResult = await (actualQuery(session.Query<T>()).ToListAsync(cancellationToken));
Expand Down
3 changes: 3 additions & 0 deletions src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,8 @@ public class Invoice
public virtual int InvoiceNumber { get; set; }
public virtual Project Project { get; set; }
public virtual Issue Issue { get; set; }
public virtual int Amount { get; set; }
public virtual int? SpecialAmount { get; set; }
public virtual bool Paid { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Cfg;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
using NHibernate.Linq.Visitors;
using NHibernate.Util;
using NUnit.Framework;

namespace NHibernate.Test.NHSpecificTest.GH1879
{
[TestFixture]
public class ExpansionRegressionTests : GH1879BaseFixture<Invoice>
{
protected override void OnSetUp()
{
using (var session = OpenSession())
using (var transaction = session.BeginTransaction())
{
session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false });
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true });
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false });
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true });

session.Flush();
transaction.Commit();
}
}

protected override void Configure(Configuration configuration)
{
configuration.LinqToHqlGeneratorsRegistry<TestLinqToHqlGeneratorsRegistry>();
}

private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry
{
public TestLinqToHqlGeneratorsRegistry()
{
this.Merge(new ObjectEquality());
}
}

private class ObjectEquality : IHqlGeneratorForMethod
{
public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression());
}

public IEnumerable<MethodInfo> SupportedMethods
{
get
{
yield return ReflectHelper.GetMethodDefinition<object>(x => x.Equals(x));
}
}
}

[Test]
public void MethodShouldNotExpandForNonConditionalOrCoalesce()
{
using (var session = OpenSession())
{
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110)), Is.EqualTo(2));
}
}

[Test]
public void MethodShouldNotExpandForConditionalWithPropertyAccessor()
{
using (var session = OpenSession())
{
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10)), Is.EqualTo(2));
}
}

[Test]
public void MethodShouldNotExpandForCoalesceWithPropertyAccessor()
{
using (var session = OpenSession())
{
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100)), Is.EqualTo(2));
}
}
}
}
7 changes: 5 additions & 2 deletions src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ protected override HbmMapping GetMappings()
{
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
rc.Property(x => x.InvoiceNumber);
rc.Property(x => x.Amount);
rc.Property(x => x.SpecialAmount);
rc.Property(x => x.Paid);
rc.ManyToOne(x => x.Project, m => m.Column("ProjectId"));
rc.ManyToOne(x => x.Issue, m => m.Column("IssueId"));
});
Expand Down Expand Up @@ -108,9 +111,9 @@ protected void AreEqual<TResult>(
{
expectedResult = expectedQuery(session.Query<T>()).ToList();
}
catch
catch (Exception e)
{
Assert.Ignore("Not currently supported query");
Assert.Ignore($"Not currently supported query: {e}");
}

var testResult = actualQuery(session.Query<T>()).ToList();
Expand Down
49 changes: 27 additions & 22 deletions src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void Transform(ResultOperatorBase resultOperator)
protected override Expression VisitMember(MemberExpression node)
{
var result = (MemberExpression) base.VisitMember(node);
if (QueryReferenceCounter.CountReferences(result.Expression) > 1)
if (ShouldRewrite(result.Expression))
{
return ConditionalQueryReferenceMemberExpressionRewriter.Rewrite(result.Expression, node);
}
Expand All @@ -90,39 +90,44 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
var isExtension = node.Method.GetCustomAttributes<ExtensionAttribute>().Any();
var methodObject = isExtension ? node.Arguments[0] : node.Object;

if (methodObject != null && QueryReferenceCounter.CountReferences(methodObject) > 1)
if (ShouldRewrite(methodObject))
{
return ConditionalQueryReferenceMethodCallExpressionRewriter.Rewrite(methodObject, node);
}
return result;
}
}

private class QueryReferenceCounter : RelinqExpressionVisitor
{
private readonly System.Type _queryType;
private int _queryReferenceCount;

private QueryReferenceCounter(System.Type queryType)
private bool ShouldRewrite(Expression expr, System.Type queryType = null)
{
_queryType = queryType;
}
if (expr == null)
{
return false;
}

// Strip Converts
while (expr.NodeType == ExpressionType.Convert || expr.NodeType == ExpressionType.ConvertChecked)
{
expr = ((UnaryExpression)expr).Operand;
}

protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression)
{
if (_queryType.IsAssignableFrom(expression.Type))
if (expr is QuerySourceReferenceExpression && queryType?.IsAssignableFrom(expr.Type) == true)
{
_queryReferenceCount++;
return true;
}

return base.VisitQuerySourceReference(expression);
}
queryType = queryType ?? expr.Type;

public static int CountReferences(Expression node)
{
var visitor = new QueryReferenceCounter(node.Type);
visitor.Visit(node);
return visitor._queryReferenceCount;
if (expr.NodeType == ExpressionType.Coalesce && expr is BinaryExpression coalesce)
{
return ShouldRewrite(coalesce.Left, queryType) && ShouldRewrite(coalesce.Right, queryType);
}

if (expr.NodeType == ExpressionType.Conditional && expr is ConditionalExpression conditional)
{
return ShouldRewrite(conditional.IfFalse, queryType) && ShouldRewrite(conditional.IfTrue, queryType);
}

return false;
}
}

Expand Down