Skip to content

Commit 8c84d2c

Browse files
beliefer authored and cloud-fan committed
[SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression
### What changes were proposed in this pull request?

The `hashCode()` of `UserDefinedScalarFunc` and `GeneralScalarExpression` is not good enough. For example, `GeneralScalarExpression` uses `Objects.hash(name, children)`, which combines the hash code of `name` with the hash code of the `children` array reference to form the `GeneralScalarExpression`'s hash code. Instead, the hash code should be derived from each element of `children`.

Because `UserDefinedAggregateFunc` and `GeneralAggregateFunc` are missing `hashCode()`, this PR also adds it to them.

This PR also improves the toString for `UserDefinedAggregateFunc` and `GeneralAggregateFunc` by using a boolean primitive comparison instead of `Objects.equals`, because a boolean primitive comparison performs better than `Objects.equals`.

### Why are the changes needed?

To improve the hash code of some DS V2 expressions.

### Does this PR introduce _any_ user-facing change?

'Yes'.

### How was this patch tested?

N/A

Closes #41543 from beliefer/SPARK-44018.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 8ccacb8 commit 8c84d2c

File tree

4 files changed

+62
-7
lines changed

4 files changed

+62
-7
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.sql.connector.expressions;
1919

2020
import java.util.Arrays;
21-
import java.util.Objects;
2221

2322
import org.apache.spark.annotation.Evolving;
2423
import org.apache.spark.sql.connector.expressions.filter.Predicate;
@@ -441,12 +440,17 @@ public GeneralScalarExpression(String name, Expression[] children) {
441440
public boolean equals(Object o) {
442441
if (this == o) return true;
443442
if (o == null || getClass() != o.getClass()) return false;
443+
444444
GeneralScalarExpression that = (GeneralScalarExpression) o;
445-
return Objects.equals(name, that.name) && Arrays.equals(children, that.children);
445+
446+
if (!name.equals(that.name)) return false;
447+
return Arrays.equals(children, that.children);
446448
}
447449

448450
@Override
449451
public int hashCode() {
450-
return Objects.hash(name, children);
452+
int result = name.hashCode();
453+
result = 31 * result + Arrays.hashCode(children);
454+
return result;
451455
}
452456
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.sql.connector.expressions;
1919

2020
import java.util.Arrays;
21-
import java.util.Objects;
2221

2322
import org.apache.spark.annotation.Evolving;
2423
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
@@ -51,13 +50,19 @@ public UserDefinedScalarFunc(String name, String canonicalName, Expression[] chi
5150
public boolean equals(Object o) {
5251
if (this == o) return true;
5352
if (o == null || getClass() != o.getClass()) return false;
53+
5454
UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
55-
return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
56-
Arrays.equals(children, that.children);
55+
56+
if (!name.equals(that.name)) return false;
57+
if (!canonicalName.equals(that.canonicalName)) return false;
58+
return Arrays.equals(children, that.children);
5759
}
5860

5961
@Override
6062
public int hashCode() {
61-
return Objects.hash(name, canonicalName, children);
63+
int result = name.hashCode();
64+
result = 31 * result + canonicalName.hashCode();
65+
result = 31 * result + Arrays.hashCode(children);
66+
return result;
6267
}
6368
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

20+
import java.util.Arrays;
21+
2022
import org.apache.spark.annotation.Evolving;
2123
import org.apache.spark.sql.connector.expressions.Expression;
2224
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
@@ -60,4 +62,24 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr
6062

6163
@Override
6264
public Expression[] children() { return children; }
65+
66+
@Override
67+
public boolean equals(Object o) {
68+
if (this == o) return true;
69+
if (o == null || getClass() != o.getClass()) return false;
70+
71+
GeneralAggregateFunc that = (GeneralAggregateFunc) o;
72+
73+
if (isDistinct != that.isDistinct) return false;
74+
if (!name.equals(that.name)) return false;
75+
return Arrays.equals(children, that.children);
76+
}
77+
78+
@Override
79+
public int hashCode() {
80+
int result = name.hashCode();
81+
result = 31 * result + (isDistinct ? 1 : 0);
82+
result = 31 * result + Arrays.hashCode(children);
83+
return result;
84+
}
6385
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

20+
import java.util.Arrays;
21+
2022
import org.apache.spark.annotation.Evolving;
2123
import org.apache.spark.sql.connector.expressions.Expression;
2224
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
@@ -50,4 +52,26 @@ public UserDefinedAggregateFunc(
5052

5153
@Override
5254
public Expression[] children() { return children; }
55+
56+
@Override
57+
public boolean equals(Object o) {
58+
if (this == o) return true;
59+
if (o == null || getClass() != o.getClass()) return false;
60+
61+
UserDefinedAggregateFunc that = (UserDefinedAggregateFunc) o;
62+
63+
if (isDistinct != that.isDistinct) return false;
64+
if (!name.equals(that.name)) return false;
65+
if (!canonicalName.equals(that.canonicalName)) return false;
66+
return Arrays.equals(children, that.children);
67+
}
68+
69+
@Override
70+
public int hashCode() {
71+
int result = name.hashCode();
72+
result = 31 * result + canonicalName.hashCode();
73+
result = 31 * result + (isDistinct ? 1 : 0);
74+
result = 31 * result + Arrays.hashCode(children);
75+
return result;
76+
}
5377
}

0 commit comments

Comments
 (0)