Skip to content

Commit f5495f3

Browse files
committed
Proper support for IN clause for composite properties in Criteria
1 parent b734d00 commit f5495f3

File tree

9 files changed

+160
-54
lines changed

9 files changed

+160
-54
lines changed

src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ protected override string[] Mappings
3737
get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; }
3838
}
3939

40-
protected override bool AppliesTo(Dialect.Dialect dialect)
41-
{
42-
return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture
43-
}
44-
4540
protected override void OnSetUp()
4641
{
4742
id = new Id("stringKey", 3, firstDateTime);
@@ -51,9 +46,11 @@ protected override void OnSetUp()
5146
protected override void OnTearDown()
5247
{
5348
using (ISession s = Sfi.OpenSession())
49+
using (var t = s.BeginTransaction())
5450
{
5551
s.Delete("from ClassWithCompositeId");
5652
s.Flush();
53+
t.Commit();
5754
}
5855
}
5956

@@ -239,5 +236,26 @@ public async Task HqlAsync()
239236
Assert.AreEqual(1, results.Count);
240237
}
241238
}
239+
240+
[Test]
241+
public async Task QueryOverInClauseAsync()
242+
{
243+
// insert the new objects
244+
using (ISession s = OpenSession())
245+
using (ITransaction t = s.BeginTransaction())
246+
{
247+
await (s.SaveAsync(new ClassWithCompositeId(id) {OneProperty = 5}));
248+
await (s.SaveAsync(new ClassWithCompositeId(secondId) {OneProperty = 10}));
249+
await (s.SaveAsync(new ClassWithCompositeId(new Id(id.KeyString, id.GetKeyShort(), secondId.KeyDateTime))));
250+
251+
await (t.CommitAsync());
252+
}
253+
254+
using (var s = OpenSession())
255+
{
256+
var results = await (s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id, secondId}).ListAsync());
257+
Assert.That(results.Count, Is.EqualTo(2));
258+
}
259+
}
242260
}
243261
}

src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,6 @@ protected override string[] Mappings
2626
get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; }
2727
}
2828

29-
protected override bool AppliesTo(Dialect.Dialect dialect)
30-
{
31-
return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture
32-
}
33-
3429
protected override void OnSetUp()
3530
{
3631
id = new Id("stringKey", 3, firstDateTime);
@@ -40,9 +35,11 @@ protected override void OnSetUp()
4035
protected override void OnTearDown()
4136
{
4237
using (ISession s = Sfi.OpenSession())
38+
using (var t = s.BeginTransaction())
4339
{
4440
s.Delete("from ClassWithCompositeId");
4541
s.Flush();
42+
t.Commit();
4643
}
4744
}
4845

@@ -228,5 +225,26 @@ public void Hql()
228225
Assert.AreEqual(1, results.Count);
229226
}
230227
}
228+
229+
[Test]
230+
public void QueryOverInClause()
231+
{
232+
// insert the new objects
233+
using (ISession s = OpenSession())
234+
using (ITransaction t = s.BeginTransaction())
235+
{
236+
s.Save(new ClassWithCompositeId(id) {OneProperty = 5});
237+
s.Save(new ClassWithCompositeId(secondId) {OneProperty = 10});
238+
s.Save(new ClassWithCompositeId(new Id(id.KeyString, id.GetKeyShort(), secondId.KeyDateTime)));
239+
240+
t.Commit();
241+
}
242+
243+
using (var s = OpenSession())
244+
{
245+
var results = s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id, secondId}).List();
246+
Assert.That(results.Count, Is.EqualTo(2));
247+
}
248+
}
231249
}
232250
}

src/NHibernate.Test/CompositeId/Id.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public string KeyString
2828
set { _keyString = value; }
2929
}
3030

31+
public short GetKeyShort() => _keyShort;
3132
// public short KeyShort {
3233
// get { return _keyShort;}
3334
// set {_keyShort = value;}
@@ -55,4 +56,4 @@ public override bool Equals(object obj)
5556
return false;
5657
}
5758
}
58-
}
59+
}

src/NHibernate/Criterion/InExpression.cs

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using System;
2-
using System.Collections;
32
using System.Collections.Generic;
43
using System.Linq;
54
using NHibernate.Engine;
@@ -13,9 +12,6 @@ namespace NHibernate.Criterion
1312
/// An <see cref="ICriterion"/> that constrains the property
1413
/// to a specified list of values.
1514
/// </summary>
16-
/// <remarks>
17-
/// InExpression - should only be used with a Single Value column - no multicolumn properties...
18-
/// </remarks>
1915
[Serializable]
2016
public class InExpression : AbstractCriterion
2117
{
@@ -62,41 +58,50 @@ public override SqlString ToSqlString(ICriteria criteria, ICriteriaQuery criteri
6258
return new SqlString("1=0");
6359
}
6460

65-
//TODO: add default capacity
66-
SqlStringBuilder result = new SqlStringBuilder();
67-
SqlString[] columnNames =
68-
CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria);
69-
70-
// Generate SqlString of the form:
71-
// columnName1 in (values) and columnName2 in (values) and ...
72-
Parameter[] parameters = GetParameterTypedValues(criteria, criteriaQuery).SelectMany(t => criteriaQuery.NewQueryParameter(t)).ToArray();
61+
SqlString[] columns = CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria);
7362

74-
for (int columnIndex = 0; columnIndex < columnNames.Length; columnIndex++)
63+
var list = new List<Parameter>(columns.Length * Values.Length);
64+
foreach (var typedValue in GetParameterTypedValues(criteria, criteriaQuery))
7565
{
76-
SqlString columnName = columnNames[columnIndex];
77-
78-
if (columnIndex > 0)
79-
{
80-
result.Add(" and ");
81-
}
66+
//Must be executed after CriterionUtil.GetColumnNames (as it might add _projection parameters to criteria)
67+
list.AddRange(criteriaQuery.NewQueryParameter(typedValue));
68+
}
8269

83-
result
84-
.Add(columnName)
85-
.Add(" in (");
70+
var bogusParam = Parameter.Placeholder;
8671

87-
for (int i = 0; i < _values.Length; i++)
88-
{
89-
if (i > 0)
90-
{
91-
result.Add(StringHelper.CommaSpace);
92-
}
93-
result.Add(parameters[i]);
94-
}
72+
var sqlString = GetSqlString(criteriaQuery, columns, bogusParam);
73+
sqlString.SubstituteBogusParameters(list, bogusParam);
74+
return sqlString;
75+
}
9576

96-
result.Add(")");
77+
private SqlString GetSqlString(ICriteriaQuery criteriaQuery, SqlString[] columns, Parameter bogusParam)
78+
{
79+
if (columns.Length <= 1 || criteriaQuery.Factory.Dialect.SupportsRowValueConstructorSyntaxInInList)
80+
{
81+
var parens = columns.Length > 1 ? new[] {new SqlString("("), new SqlString(")"),} : null;
82+
SqlString comaSeparator = new SqlString(", ");
83+
var singleValueParam = SqlStringHelper.Repeat(new SqlString(bogusParam), columns.Length, comaSeparator, parens);
84+
85+
var parameters = SqlStringHelper.Repeat(singleValueParam, Values.Length, comaSeparator, null);
86+
87+
//single column: col1 in (?, ?)
88+
//multi column: (col1, col2) in ((?, ?), (?, ?))
89+
return new SqlString(
90+
parens?[0] ?? SqlString.Empty,
91+
SqlStringHelper.Join(comaSeparator, columns),
92+
parens?[1] ?? SqlString.Empty,
93+
" in (",
94+
parameters,
95+
")");
9796
}
9897

99-
return result.ToSqlString();
98+
//((col1 = ? and col2 = ?) or (col1 = ? and col2 = ?))
99+
var cols = new SqlString(
100+
" ( ",
101+
SqlStringHelper.Join(new SqlString(" = ", bogusParam, " and "), columns),
102+
new SqlString("= ", bogusParam, " ) "));
103+
cols = SqlStringHelper.Repeat(cols, Values.Length, "or ", new[] {" ( ", " ) "});
104+
return cols;
100105
}
101106

102107
private void AssertPropertyIsNotCollection(ICriteriaQuery criteriaQuery, ICriteria criteria)
@@ -127,16 +132,13 @@ private List<TypedValue> GetParameterTypedValues(ICriteria criteria, ICriteriaQu
127132
List<TypedValue> list = new List<TypedValue>();
128133
IAbstractComponentType actype = (IAbstractComponentType) type;
129134
IType[] types = actype.Subtypes;
130-
131-
for (int i = 0; i < types.Length; i++)
135+
for (int vi = 0; vi < _values.Length; vi++)
136+
for (int ti = 0; ti < types.Length; ti++)
132137
{
133-
for (int j = 0; j < _values.Length; j++)
134-
{
135-
object subval = _values[j] == null
136-
? null
137-
: actype.GetPropertyValues(_values[j])[i];
138-
list.Add(new TypedValue(types[i], subval, false));
139-
}
138+
object subval = _values[vi] == null
139+
? null
140+
: actype.GetPropertyValues(_values[vi])[ti];
141+
list.Add(new TypedValue(types[ti], subval, false));
140142
}
141143

142144
return list;

src/NHibernate/Dialect/MySQL57Dialect.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,8 @@ public MySQL57Dialect()
1616

1717
/// <inheritdoc />
1818
public override bool SupportsDateTimeScale => true;
19+
20+
/// <inheritdoc />
21+
public override bool SupportsRowValueConstructorSyntaxInInList => true;
1922
}
2023
}

src/NHibernate/Dialect/Oracle9iDialect.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,8 @@ public override CaseFragment CreateCaseFragment()
5656

5757
/// <inheritdoc />
5858
public override bool SupportsDateTimeScale => true;
59+
60+
/// <inheritdoc />
61+
public override bool SupportsRowValueConstructorSyntaxInInList => true;
5962
}
6063
}

src/NHibernate/Dialect/PostgreSQL82Dialect.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,8 @@ public override string GetDropSequenceString(string sequenceName)
2727
{
2828
return string.Concat("drop sequence if exists ", sequenceName);
2929
}
30+
31+
/// <inheritdoc />
32+
public override bool SupportsRowValueConstructorSyntaxInInList => true;
3033
}
31-
}
34+
}

src/NHibernate/SqlCommand/SqlString.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,19 @@ public SqlString GetSubselectString()
10081008
return new SubselectClauseExtractor(this).GetSqlString();
10091009
}
10101010

1011+
internal void SubstituteBogusParameters(IReadOnlyList<Parameter> actualParams, Parameter bogusParam)
1012+
{
1013+
int index = 0;
1014+
var keys = _parameters.Keys;
1015+
// ReSharper disable once ForCanBeConvertedToForeach
1016+
for (var i = 0; i < keys.Count; i++)
1017+
{
1018+
var key = keys[i];
1019+
if (ReferenceEquals(_parameters[key], bogusParam))
1020+
_parameters[key] = actualParams[index++];
1021+
}
1022+
}
1023+
10111024
[Serializable]
10121025
private struct Part : IEquatable<Part>
10131026
{

src/NHibernate/SqlCommand/SqlStringHelper.cs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,55 @@ public static bool IsNotEmpty(SqlString str)
5555
return !IsEmpty(str);
5656
}
5757

58-
5958
public static bool IsEmpty(SqlString str)
6059
{
6160
return str == null || str.Count == 0;
6261
}
62+
63+
internal static SqlString Repeat(SqlString placeholder, int count, string separator, string[] wrapResult)
64+
{
65+
return Repeat(
66+
placeholder,
67+
count,
68+
new SqlString(separator),
69+
wrapResult == null
70+
? null
71+
: new[]
72+
{
73+
new SqlString(wrapResult[0]),
74+
new SqlString(wrapResult[1]),
75+
});
76+
}
77+
78+
internal static SqlString Repeat(SqlString placeholder, int count, SqlString separator, SqlString[] wrapResult)
79+
{
80+
if (wrapResult == null)
81+
{
82+
if (count == 0)
83+
return SqlString.Empty;
84+
if (count == 1)
85+
return placeholder;
86+
}
87+
88+
var builder = new SqlStringBuilder(count * 2 + 1);
89+
if (wrapResult != null)
90+
{
91+
builder.Add(wrapResult[0]);
92+
}
93+
94+
if (count > 0)
95+
builder.Add(placeholder);
96+
97+
for (int i = 1; i < count; i++)
98+
{
99+
builder.Add(separator).Add(placeholder);
100+
}
101+
102+
if (wrapResult != null)
103+
{
104+
builder.Add(wrapResult[1]);
105+
}
106+
return builder.ToSqlString();
107+
}
63108
}
64109
}

0 commit comments

Comments
 (0)