diff --git a/examples/python-udf-comparisons.py b/examples/python-udf-comparisons.py new file mode 100644 index 000000000..e2d856749 --- /dev/null +++ b/examples/python-udf-comparisons.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datafusion import SessionContext, col, lit, udf, functions as F +import os +import pyarrow as pa +import pyarrow.compute as pc +import time + +path = os.path.dirname(os.path.abspath(__file__)) +filepath = os.path.join(path, "../tpch/data/lineitem.parquet") + +# This example serves to demonstrate alternate approaches to answering the +# question "return all of the rows that have a specific combination of these +# values". We have the combinations we care about provided as a python +# list of tuples. There is no built in function that supports this operation, +# but it can be explicitly specified via a single expression or we can +# use a user defined function. + +ctx = SessionContext() + +# These part keys and suppliers are chosen because there are +# cases where two suppliers each have two of the part keys +# but we are interested in these specific combinations. 
# Each tuple is one (l_partkey, l_suppkey, l_returnflag) combination that we
# want to locate in the lineitem table.
values_of_interest = [
    (1530, 4031, "N"),
    (6530, 1531, "N"),
    (5618, 619, "N"),
    (8118, 8119, "N"),
]

# Per-column literal lists used by the simple (and incorrect) filter below.
partkeys = [lit(r[0]) for r in values_of_interest]
suppkeys = [lit(r[1]) for r in values_of_interest]
returnflags = [lit(r[2]) for r in values_of_interest]

df_lineitem = ctx.read_parquet(filepath).select(
    "l_partkey", "l_suppkey", "l_returnflag"
)

start_time = time.time()

# Simple approach: check every column against its own list of values.
# The columns are tested independently of one another, so a row can pass
# with a part key taken from one tuple and a supplier key taken from another.
df_simple_filter = df_lineitem.filter(
    F.in_list(col("l_partkey"), partkeys),
    F.in_list(col("l_suppkey"), suppkeys),
    F.in_list(col("l_returnflag"), returnflags),
)

num_rows = df_simple_filter.count()
print(
    f"Simple filtering has number {num_rows} rows and took {time.time() - start_time} s"
)
print("This is the incorrect number of rows!")
start_time = time.time()

# Explicitly check for the combinations of interest.
# This works but is not scalable: every tuple has to be spelled out by hand.

filter_expr = (
    (
        (col("l_partkey") == values_of_interest[0][0])
        & (col("l_suppkey") == values_of_interest[0][1])
        & (col("l_returnflag") == values_of_interest[0][2])
    )
    | (
        (col("l_partkey") == values_of_interest[1][0])
        & (col("l_suppkey") == values_of_interest[1][1])
        & (col("l_returnflag") == values_of_interest[1][2])
    )
    | (
        (col("l_partkey") == values_of_interest[2][0])
        & (col("l_suppkey") == values_of_interest[2][1])
        & (col("l_returnflag") == values_of_interest[2][2])
    )
    | (
        (col("l_partkey") == values_of_interest[3][0])
        & (col("l_suppkey") == values_of_interest[3][1])
        & (col("l_returnflag") == values_of_interest[3][2])
    )
)

df_explicit_filter = df_lineitem.filter(filter_expr)

num_rows = df_explicit_filter.count()
print(
    f"Explicit filtering has number {num_rows} rows and took {time.time() - start_time} s"
)
start_time = time.time()

# Instead try a python UDF


def is_of_interest_impl(
    partkey_arr: pa.Array,
    suppkey_arr: pa.Array,
    returnflag_arr: pa.Array,
) -> pa.Array:
    """Flag the rows whose column values form one of the wanted combinations.

    Every row is converted to a Python tuple and looked up in
    ``values_of_interest``. This row-by-row conversion is deliberate: it is
    the straightforward UDF that the other approaches are compared against.

    Args:
        partkey_arr: Part key column values.
        suppkey_arr: Supplier key column values, aligned with ``partkey_arr``.
        returnflag_arr: Return flag column values, aligned with
            ``partkey_arr``.

    Returns:
        A boolean array with one entry per input row, ``True`` where the
        (partkey, suppkey, returnflag) triple is in ``values_of_interest``.
    """
    # Hash-based lookup instead of scanning the list once per row.
    wanted = set(values_of_interest)
    matches = [
        (partkey.as_py(), suppkey.as_py(), returnflag.as_py()) in wanted
        for partkey, suppkey, returnflag in zip(
            partkey_arr, suppkey_arr, returnflag_arr
        )
    ]
    # The explicit type keeps an empty batch boolean; without it pyarrow
    # infers a null-typed array from an empty list.
    return pa.array(matches, type=pa.bool_())


# Wrap the python function as a UDF taking (int32, int32, utf8) arguments and
# producing a boolean.
is_of_interest = udf(
    is_of_interest_impl,
    [pa.int32(), pa.int32(), pa.utf8()],
    pa.bool_(),
    "stable",
)

df_udf_filter = df_lineitem.filter(
    is_of_interest(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_filter.count()
print(f"UDF filtering has number {num_rows} rows and took {time.time() - start_time} s")
start_time = time.time()

# Now use a user defined function but lean on the built in pyarrow array
# functions so we never convert rows to python objects.

# To see other pyarrow compute functions see
# https://arrow.apache.org/docs/python/api/compute.html
#
# It is important that the number of rows in the returned array
# matches the original array, so we cannot use functions like
# filtered_partkey_arr.filter(filtered_suppkey_arr).


def udf_using_pyarrow_compute_impl(
    partkey_arr: pa.Array,
    suppkey_arr: pa.Array,
    returnflag_arr: pa.Array,
) -> pa.Array:
    """Flag the rows matching a wanted combination using pyarrow compute only.

    For each tuple in ``values_of_interest`` an element-wise mask is built
    with ``pc.equal`` / ``pc.and_``; the per-tuple masks are then combined
    with ``pc.or_``. No row is converted to a Python object.

    Args:
        partkey_arr: Part key column values.
        suppkey_arr: Supplier key column values, aligned with ``partkey_arr``.
        returnflag_arr: Return flag column values, aligned with
            ``partkey_arr``.

    Returns:
        A boolean array with the same number of rows as the inputs. When
        ``values_of_interest`` is empty every entry is ``False``.
    """
    # Start from "no row matches" so that an empty list of combinations still
    # yields a boolean array of the right length instead of None.
    results = pa.repeat(pa.scalar(False), len(partkey_arr))

    for partkey, suppkey, returnflag in values_of_interest:
        partkey_mask = pc.equal(partkey_arr, partkey)
        suppkey_mask = pc.equal(suppkey_arr, suppkey)
        returnflag_mask = pc.equal(returnflag_arr, returnflag)

        # A row matches this combination only when all three columns agree.
        combo_mask = pc.and_(pc.and_(partkey_mask, suppkey_mask), returnflag_mask)

        # A row is of interest when it matches any of the combinations.
        results = pc.or_(results, combo_mask)

    return results


# Wrap the python function as a UDF taking (int32, int32, utf8) arguments and
# producing a boolean.
udf_using_pyarrow_compute = udf(
    udf_using_pyarrow_compute_impl,
    [pa.int32(), pa.int32(), pa.utf8()],
    pa.bool_(),
    "stable",
)

df_udf_pyarrow_compute = df_lineitem.filter(
    udf_using_pyarrow_compute(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_pyarrow_compute.count()
print(
    f"UDF filtering using pyarrow compute has number {num_rows} rows and took {time.time() - start_time} s"
)
start_time = time.time()