Skip to content

Add in user example that compares a two different approaches to UDFs #770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions examples/python-udf-comparisons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datafusion import SessionContext, col, lit, udf, functions as F
import os
import pyarrow as pa
import pyarrow.compute as pc
import time

path = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(path, "../tpch/data/lineitem.parquet")

# This example serves to demonstrate alternate approaches to answering the
# question "return all of the rows that have a specific combination of these
# values". We have the combinations we care about provided as a python
# list of tuples. There is no built in function that supports this operation,
# but it can be explicilty specified via a single expression or we can
# use a user defined function.

ctx = SessionContext()

# These part keys and suppliers are chosen because there are
# cases where two suppliers each have two of the part keys
# but we are interested in these specific combinations.

values_of_interest = [
(1530, 4031, "N"),
(6530, 1531, "N"),
(5618, 619, "N"),
(8118, 8119, "N"),
]

partkeys = [lit(r[0]) for r in values_of_interest]
suppkeys = [lit(r[1]) for r in values_of_interest]
returnflags = [lit(r[2]) for r in values_of_interest]

df_lineitem = ctx.read_parquet(filepath).select(
"l_partkey", "l_suppkey", "l_returnflag"
)

start_time = time.time()

df_simple_filter = df_lineitem.filter(
F.in_list(col("l_partkey"), partkeys),
F.in_list(col("l_suppkey"), suppkeys),
F.in_list(col("l_returnflag"), returnflags),
)

num_rows = df_simple_filter.count()
print(
f"Simple filtering has number {num_rows} rows and took {time.time() - start_time} s"
)
print("This is the incorrect number of rows!")
start_time = time.time()

# Explicitly check for the combinations of interest.
# This works but is not scalable.

filter_expr = (
(
(col("l_partkey") == values_of_interest[0][0])
& (col("l_suppkey") == values_of_interest[0][1])
& (col("l_returnflag") == values_of_interest[0][2])
)
| (
(col("l_partkey") == values_of_interest[1][0])
& (col("l_suppkey") == values_of_interest[1][1])
& (col("l_returnflag") == values_of_interest[1][2])
)
| (
(col("l_partkey") == values_of_interest[2][0])
& (col("l_suppkey") == values_of_interest[2][1])
& (col("l_returnflag") == values_of_interest[2][2])
)
| (
(col("l_partkey") == values_of_interest[3][0])
& (col("l_suppkey") == values_of_interest[3][1])
& (col("l_returnflag") == values_of_interest[3][2])
)
)

df_explicit_filter = df_lineitem.filter(filter_expr)

num_rows = df_explicit_filter.count()
print(
f"Explicit filtering has number {num_rows} rows and took {time.time() - start_time} s"
)
start_time = time.time()

# Instead try a python UDF


def is_of_interest_impl(
partkey_arr: pa.Array,
suppkey_arr: pa.Array,
returnflag_arr: pa.Array,
) -> pa.Array:
result = []
for idx, partkey in enumerate(partkey_arr):
partkey = partkey.as_py()
suppkey = suppkey_arr[idx].as_py()
returnflag = returnflag_arr[idx].as_py()
value = (partkey, suppkey, returnflag)
result.append(value in values_of_interest)

return pa.array(result)


is_of_interest = udf(
is_of_interest_impl,
[pa.int32(), pa.int32(), pa.utf8()],
pa.bool_(),
"stable",
)

df_udf_filter = df_lineitem.filter(
is_of_interest(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_filter.count()
print(f"UDF filtering has number {num_rows} rows and took {time.time() - start_time} s")
start_time = time.time()

# Now use a user defined function but lean on the built in pyarrow array
# functions so we never convert rows to python objects.

# To see other pyarrow compute functions see
# https://arrow.apache.org/docs/python/api/compute.html
#
# It is important that the number of rows in the returned array
# matches the original array, so we cannot use functions like
# filtered_partkey_arr.filter(filtered_suppkey_arr).


def udf_using_pyarrow_compute_impl(
partkey_arr: pa.Array,
suppkey_arr: pa.Array,
returnflag_arr: pa.Array,
) -> pa.Array:
results = None
for partkey, suppkey, returnflag in values_of_interest:
filtered_partkey_arr = pc.equal(partkey_arr, partkey)
filtered_suppkey_arr = pc.equal(suppkey_arr, suppkey)
filtered_returnflag_arr = pc.equal(returnflag_arr, returnflag)

resultant_arr = pc.and_(filtered_partkey_arr, filtered_suppkey_arr)
resultant_arr = pc.and_(resultant_arr, filtered_returnflag_arr)

if results is None:
results = resultant_arr
else:
results = pc.or_(results, resultant_arr)

return results


udf_using_pyarrow_compute = udf(
udf_using_pyarrow_compute_impl,
[pa.int32(), pa.int32(), pa.utf8()],
pa.bool_(),
"stable",
)

df_udf_pyarrow_compute = df_lineitem.filter(
udf_using_pyarrow_compute(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_pyarrow_compute.count()
print(
f"UDF filtering using pyarrow compute has number {num_rows} rows and took {time.time() - start_time} s"
)
start_time = time.time()
Loading