Skip to content

enhance sql-using-python-udf example #1054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 87 additions & 25 deletions examples/sql-using-python-udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,111 @@
# under the License.

import pyarrow as pa
from datafusion import SessionContext, udf
from datafusion import SessionContext, udf, DataFrame

# Print version information for debugging
import datafusion
import pyarrow

# Report library versions up front to simplify debugging environment issues
print(f"DataFusion version: {datafusion.__version__}")
print(f"PyArrow version: {pyarrow.__version__}")


# Define a user-defined function (UDF) that checks if a value is null
def is_null(array: pa.Array) -> pa.Array:
    """Mark which elements of *array* are null.

    Args:
        array: The PyArrow array to inspect.

    Returns:
        A boolean PyArrow array that is ``True`` wherever the
        corresponding input element is null.
    """
    null_mask = array.is_null()
    return null_mask


# Create the UDF definition. Arguments: the Python callable, the input
# type(s), the return type, the volatility, and the SQL-visible name.
is_null_arr = udf(
    is_null,  # The Python function to wrap
    [pa.int64()],  # Input type(s) - here we expect one int64 column
    pa.bool_(),  # Output type - returns boolean
    "stable",  # Volatility - "stable" means same input = same output
    # SQL name for the function; if not specified it defaults to the
    # Python function's name.
    name="is_null",
)

# Create a DataFusion session context
ctx = SessionContext()

# Register an in-memory table named "t" built from a Python dictionary.
# Column "b" deliberately contains a null so the UDF has something to find.
ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]}, name="t")
# Resulting table:
# +---+---+
# | a | b |
# +---+---+
# | 1 | 4 |
# | 2 |   |
# | 3 | 6 |
# +---+---+

# The example data used by both registration attempts below.
data = {"a": [1, 2, 3], "b": [4, None, 6]}

try:
    # Method 1: Using DataFrame.from_pydict (for newer DataFusion versions)
    print("\nTrying Method 1: DataFrame.from_pydict")
    df = DataFrame.from_pydict(ctx, data)
    df.create_or_replace_table("t")
except Exception as e:
    print(f"Method 1 failed: {e}")

try:
    # Method 2: Using arrow table directly
    print("\nTrying Method 2: Register arrow table")
    ctx.register_table("t", pa.table(data))
except Exception as e:
    print(f"Method 2 failed: {e}")

# Method 3: Using explicit record batch creation
print("\nTrying Method 3: Explicit record batch creation")

# Describe the two int64 columns up front
schema = pa.schema([
    pa.field('a', pa.int64()),
    pa.field('b', pa.int64()),
])

# Build the column data, then assemble it into a single record batch
column_a = pa.array([1, 2, 3], type=pa.int64())
column_b = pa.array([4, None, 6], type=pa.int64())
batch = pa.record_batch([column_a, column_b], schema=schema)

# Register the record batch with DataFusion
# Note: The double list [[batch]] is required by the API
ctx.register_record_batches("t", [[batch]])

# Register our UDF with the context
ctx.register_udf(is_null_arr)

# Query the table using SQL, exercising the UDF
print("\nExecuting SQL query...")
result_df = ctx.sql("select a, is_null(b) as b_is_null from t")

# Expected output:
# +---+-----------+
# | a | b_is_null |
# +---+-----------+
# | 1 | false     |
# | 2 | true      |
# | 3 | false     |
# +---+-----------+

# Convert result to dictionary and display
result_dict = result_df.to_pydict()
print("\nQuery Results:")
print("Result:", result_dict)

# Verify the results (only the second row has a null in column b)
assert result_dict["b_is_null"] == [False, True, False], "Unexpected results from UDF"
print("\nAssert passed - UDF working as expected!")

# Print a formatted version of the results
print("\nFormatted Results:")
print("+---+-----------+")
print("| a | b_is_null |")
print("+---+-----------+")
for a_value, b_is_null_value in zip(result_dict["a"], result_dict["b_is_null"]):
    print(f"| {a_value} | {str(b_is_null_value).lower():9} |")
print("+---+-----------+")

37 changes: 18 additions & 19 deletions examples/substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,34 @@
# specific language governing permissions and limitations
# under the License.


import os
from datafusion import SessionContext
from datafusion import substrait as ss

# Get the directory of the current script so the example works regardless
# of the current working directory it is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Construct the path to the CSV file
# Using os.path.join for cross-platform compatibility
csv_file_path = os.path.join(script_dir, '..', 'testing', 'data', 'csv', 'aggregate_test_100.csv')

# Create a DataFusion context
ctx = SessionContext()

try:
    # Register table with context
    ctx.register_csv("aggregate_test_data", csv_file_path)
except Exception as e:
    # A missing data file is the most common failure mode for this example,
    # so print the resolved path before re-raising.
    print(f"Error registering CSV file: {e}")
    print(f"Looking for file at: {csv_file_path}")
    raise

# Create Substrait plan from SQL query
substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx)
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>

# Encode it to bytes
substrait_bytes = substrait_plan.encode()
# type(substrait_bytes) -> <class 'bytes'>; at this point the bytes can be distributed
# to file, network, etc safely, where they could subsequently be deserialized on the
# receiving end.

# Alternative one-step serialization: SQL straight to bytes, skipping the plan object
substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx)

# Imagine here bytes would be read from network, file, etc ... for example brevity this is omitted and variable is simply reused
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>
substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes)

# type(df_logical_plan) -> <class 'substrait.LogicalPlan'>
df_logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan)

# Back to Substrait Plan just for demonstration purposes
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>
substrait_plan = ss.Producer.to_substrait_plan(df_logical_plan, ctx)