fix: type checking #993


Merged · 10 commits · Mar 8, 2025 · Changes from all commits
5 changes: 3 additions & 2 deletions python/datafusion/catalog.py
@@ -66,11 +66,12 @@ def __init__(self, table: df_internal.Table) -> None:
         """This constructor is not typically called by the end user."""
         self.table = table

+    @property

Contributor commented:
This is changing the user-facing API, no?

chenkovsky (Author) commented on Jan 13, 2025:
I checked the test; I think we should use @property. Previously we didn't do the wrapping in one place, so the raw table was returned; that's why the test passed.

Contributor commented:
I've become convinced that the current code is broken and that this is the right change. Thank you @chenkovsky for doing this. I do think we need to add notes to that effect in the user-facing changes section of the PR description.

     def schema(self) -> pyarrow.Schema:
         """Returns the schema associated with this table."""
-        return self.table.schema()
+        return self.table.schema

     @property
     def kind(self) -> str:
         """Returns the kind of table."""
-        return self.table.kind()
+        return self.table.kind
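
For illustration, a minimal sketch of the resulting user-facing API. The registration step is a placeholder, and the exact kind string is an assumption:

    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.from_pydict({"a": [1, 2, 3]}, name="t")

    table = ctx.catalog().database().table("t")
    # schema and kind are now properties, so no call parentheses:
    print(table.schema)  # pyarrow.Schema describing column "a"
    print(table.kind)    # table kind, e.g. "physical" (assumed value)
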
19 changes: 13 additions & 6 deletions python/datafusion/context.py
@@ -713,7 +713,7 @@ def register_table(self, name: str, table: Table) -> None:
             name: Name of the resultant table.
             table: DataFusion table to add to the session context.
         """
-        self.ctx.register_table(name, table)
+        self.ctx.register_table(name, table.table)

     def deregister_table(self, name: str) -> None:
         """Remove a table from the session."""
@@ -752,7 +752,7 @@ def register_parquet(
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
         schema: pyarrow.Schema | None = None,
-        file_sort_order: list[list[Expr]] | None = None,
+        file_sort_order: list[list[SortExpr]] | None = None,
     ) -> None:
         """Register a Parquet file as a table.

@@ -783,7 +783,9 @@ def register_parquet(
             file_extension,
             skip_metadata,
             schema,
-            file_sort_order,
+            [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
+            if file_sort_order is not None
+            else None,
         )

     def register_csv(
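
A sketch of calling the updated register_parquet, where each sort key is a SortExpr (for example built with Expr.sort()) and the wrapper converts it to the raw sort list before crossing into Rust; the path and column name are illustrative:

    from datafusion import SessionContext, col

    ctx = SessionContext()
    ctx.register_parquet(
        "events",
        "data/events.parquet",
        # file_sort_order now takes SortExpr keys, not bare Expr.
        file_sort_order=[[col("ts").sort(ascending=True)]],
    )
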
@@ -919,7 +921,7 @@ def register_udwf(self, udwf: WindowUDF) -> None:

     def catalog(self, name: str = "datafusion") -> Catalog:
         """Retrieve a catalog by name."""
-        return self.ctx.catalog(name)
+        return Catalog(self.ctx.catalog(name))

     @deprecated(
         "Use the catalog provider interface ``SessionContext.Catalog`` to "
@@ -1039,7 +1041,7 @@ def read_parquet(
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
         schema: pyarrow.Schema | None = None,
-        file_sort_order: list[list[Expr]] | None = None,
+        file_sort_order: list[list[Expr | SortExpr]] | None = None,
     ) -> DataFrame:
         """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.

@@ -1063,6 +1065,11 @@
         """
        if table_partition_cols is None:
            table_partition_cols = []
+        file_sort_order = (
+            [sort_list_to_raw_sort_list(f) for f in file_sort_order]
+            if file_sort_order is not None
+            else None
+        )
         return DataFrame(
             self.ctx.read_parquet(
                 str(path),
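
read_parquet stays looser than register_parquet: each sort key may be a plain Expr, which sort_or_default upgrades to an ascending, nulls-first SortExpr, or an explicit SortExpr. A sketch with an illustrative path and column:

    from datafusion import SessionContext, col

    ctx = SessionContext()
    df = ctx.read_parquet(
        "data/events.parquet",
        # A bare column expression gets the default ordering.
        file_sort_order=[[col("ts")]],
    )
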
@@ -1106,7 +1113,7 @@ def read_table(self, table: Table) -> DataFrame:
         :py:class:`~datafusion.catalog.ListingTable`, create a
         :py:class:`~datafusion.dataframe.DataFrame`.
         """
-        return DataFrame(self.ctx.read_table(table))
+        return DataFrame(self.ctx.read_table(table.table))

     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
3 changes: 2 additions & 1 deletion python/datafusion/dataframe.py
@@ -49,6 +49,7 @@
 from enum import Enum

 from datafusion._internal import DataFrame as DataFrameInternal
+from datafusion._internal import expr as expr_internal
 from datafusion.expr import Expr, SortExpr, sort_or_default

@@ -270,7 +271,7 @@ def with_columns(

         def _simplify_expression(
             *exprs: Expr | Iterable[Expr], **named_exprs: Expr
-        ) -> list[Expr]:
+        ) -> list[expr_internal.Expr]:
             expr_list = []
             for expr in exprs:
                 if isinstance(expr, Expr):
8 changes: 4 additions & 4 deletions python/datafusion/expr.py
@@ -172,7 +172,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
     """Helper function to return a default Sort if an Expr is provided."""
     if isinstance(e, SortExpr):
         return e.raw_sort
-    return SortExpr(e.expr, True, True).raw_sort
+    return SortExpr(e, True, True).raw_sort


 def sort_list_to_raw_sort_list(
@@ -227,7 +227,7 @@ def variant_name(self) -> str:

     def __richcmp__(self, other: Expr, op: int) -> Expr:
         """Comparison operator."""
-        return Expr(self.expr.__richcmp__(other, op))
+        return Expr(self.expr.__richcmp__(other.expr, op))

     def __repr__(self) -> str:
         """Generate a string representation of this expression."""
@@ -413,7 +413,7 @@ def sort(self, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
             ascending: If true, sort in ascending order.
             nulls_first: Return null values first.
         """
-        return SortExpr(self.expr, ascending=ascending, nulls_first=nulls_first)
+        return SortExpr(self, ascending=ascending, nulls_first=nulls_first)

     def is_null(self) -> Expr:
         """Returns ``True`` if this expression is null."""
@@ -785,7 +785,7 @@ class SortExpr:

     def __init__(self, expr: Expr, ascending: bool, nulls_first: bool) -> None:
         """This constructor should not be called by the end user."""
-        self.raw_sort = expr_internal.SortExpr(expr, ascending, nulls_first)
+        self.raw_sort = expr_internal.SortExpr(expr.expr, ascending, nulls_first)

     def expr(self) -> Expr:
         """Return the raw expr backing the SortExpr."""
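
Taken together, these fixes enforce one layering rule: public Expr and SortExpr objects wrap the expr_internal types and are unwrapped exactly once, at the boundary. A sketch of the pattern from user code:

    from datafusion import col
    from datafusion.expr import SortExpr, sort_or_default

    e = col("a")                 # public Expr wrapping an expr_internal.Expr
    s = e.sort(ascending=False)  # public SortExpr built from the public Expr

    raw_default = sort_or_default(e)   # bare Expr gets ascending, nulls first
    raw_explicit = sort_or_default(s)  # SortExpr passes through its raw_sort
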
10 changes: 7 additions & 3 deletions python/datafusion/functions.py
@@ -366,7 +366,7 @@ def concat_ws(separator: str, *args: Expr) -> Expr:

 def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
     """Creates a new sort expression."""
-    return SortExpr(expr.expr, ascending=ascending, nulls_first=nulls_first)
+    return SortExpr(expr, ascending=ascending, nulls_first=nulls_first)


 def alias(expr: Expr, name: str) -> Expr:
@@ -942,6 +942,7 @@ def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr:

     See :py:func:`to_timestamp` for a description on how to use formatters.
     """
+    formatters = [f.expr for f in formatters]
     return Expr(f.to_timestamp_millis(arg.expr, *formatters))


@@ -950,6 +951,7 @@ def to_timestamp_micros(arg: Expr, *formatters: Expr) -> Expr:

     See :py:func:`to_timestamp` for a description on how to use formatters.
     """
+    formatters = [f.expr for f in formatters]
     return Expr(f.to_timestamp_micros(arg.expr, *formatters))


@@ -958,6 +960,7 @@ def to_timestamp_nanos(arg: Expr, *formatters: Expr) -> Expr:

     See :py:func:`to_timestamp` for a description on how to use formatters.
     """
+    formatters = [f.expr for f in formatters]
     return Expr(f.to_timestamp_nanos(arg.expr, *formatters))


@@ -966,6 +969,7 @@ def to_timestamp_seconds(arg: Expr, *formatters: Expr) -> Expr:

     See :py:func:`to_timestamp` for a description on how to use formatters.
     """
+    formatters = [f.expr for f in formatters]
     return Expr(f.to_timestamp_seconds(arg.expr, *formatters))
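
With the formatters unwrapped, the to_timestamp_* family accepts format strings as ordinary literal expressions. A sketch against a throwaway one-row frame:

    from datafusion import SessionContext, literal
    from datafusion import functions as f

    ctx = SessionContext()
    df = ctx.from_pydict({"x": [1]})
    df = df.select(
        f.to_timestamp_millis(
            literal("2023-09-07 05:06:14.523952"),
            literal("%Y-%m-%d %H:%M:%S.%f"),
        )
    )
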


@@ -1078,9 +1082,9 @@ def range(start: Expr, stop: Expr, step: Expr) -> Expr:
     return Expr(f.range(start.expr, stop.expr, step.expr))


-def uuid(arg: Expr) -> Expr:
+def uuid() -> Expr:
     """Returns uuid v4 as a string value."""
-    return Expr(f.uuid(arg.expr))
+    return Expr(f.uuid())


 def struct(*args: Expr) -> Expr:
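
And with the corrected uuid signature, the function takes no argument and produces one uuid v4 string per row, e.g. (reusing the ctx and df from the sketch above):

    df = df.select(f.uuid())
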
10 changes: 5 additions & 5 deletions python/datafusion/input/location.py
@@ -37,20 +37,20 @@ def is_correct_input(self, input_item: Any, table_name: str, **kwargs):

     def build_table(
         self,
-        input_file: str,
+        input_item: str,
         table_name: str,
         **kwargs,
     ) -> SqlTable:
         """Create a table from the input source."""
-        _, extension = os.path.splitext(input_file)
+        _, extension = os.path.splitext(input_item)
         format = extension.lstrip(".").lower()
         num_rows = 0  # Total number of rows in the file. Used for statistics
         columns = []
         if format == "parquet":
             import pyarrow.parquet as pq

             # Read the Parquet metadata
-            metadata = pq.read_metadata(input_file)
+            metadata = pq.read_metadata(input_item)
             num_rows = metadata.num_rows
             # Iterate through the schema and build the SqlTable
             for col in metadata.schema:
@@ -69,7 +69,7 @@ def build_table(
             # to get that information. However, this should only be occurring
             # at table creation time and therefore shouldn't
             # slow down query performance.
-            with open(input_file, "r") as file:
+            with open(input_item, "r") as file:
                 reader = csv.reader(file)
                 header_row = next(reader)
                 print(header_row)
@@ -84,6 +84,6 @@
             )

         # Input could possibly be multiple files. Create a list if so
-        input_files = glob.glob(input_file)
+        input_files = glob.glob(input_item)

         return SqlTable(table_name, columns, num_rows, input_files)
7 changes: 4 additions & 3 deletions python/datafusion/udf.py
@@ -85,7 +85,7 @@ class ScalarUDF:

     def __init__(
         self,
-        name: Optional[str],
+        name: str,
         func: Callable[..., _R],
         input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: _R,
@@ -182,7 +182,7 @@ class AggregateUDF:

     def __init__(
         self,
-        name: Optional[str],
+        name: str,

Contributor commented:
Why remove the Optional[str] if we still have the logic below to allow None?

chenkovsky (Author) commented on Feb 22, 2025:
It seems that in the Rust binding it's not optional:

    name: &str,

         accumulator: Callable[[], Accumulator],
         input_types: list[pyarrow.DataType],
         return_type: pyarrow.DataType,
@@ -277,6 +277,7 @@ def sum_bias_10() -> Summarize:
     )
     if name is None:
         name = accum.__call__().__class__.__qualname__.lower()
+    assert name is not None
     if isinstance(input_types, pyarrow.DataType):
         input_types = [input_types]
     return AggregateUDF(
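
To illustrate the None handling discussed in this thread: udaf still accepts a missing name and derives one from the accumulator class, so AggregateUDF is always constructed with a concrete str, matching the Rust binding. The accumulator body below is an assumption modeled on the project's documented Summarize example:

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import Accumulator, udaf

    class Summarize(Accumulator):
        """Simple sum accumulator."""

        def __init__(self) -> None:
            self._sum = pa.scalar(0.0)

        def state(self) -> list[pa.Scalar]:
            return [self._sum]

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: list[pa.Array]) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

        def evaluate(self) -> pa.Scalar:
            return self._sum

    # name omitted: falls back to "summarize" (lowercased qualname), so the
    # assert guarantees a non-None str before AggregateUDF is constructed.
    my_udaf = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "stable")
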
@@ -462,7 +463,7 @@ class WindowUDF:

     def __init__(
         self,
-        name: Optional[str],
+        name: str,
         func: Callable[[], WindowEvaluator],
         input_types: list[pyarrow.DataType],
         return_type: pyarrow.DataType,
30 changes: 30 additions & 0 deletions python/tests/test_functions.py
@@ -871,7 +871,22 @@ def test_temporal_functions(df):
         f.to_timestamp_millis(literal("2023-09-07 05:06:14.523952")),
         f.to_timestamp_micros(literal("2023-09-07 05:06:14.523952")),
         f.extract(literal("day"), column("d")),
+        f.to_timestamp(
+            literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
+        ),
+        f.to_timestamp_seconds(
+            literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
+        ),
+        f.to_timestamp_millis(
+            literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
+        ),
+        f.to_timestamp_micros(
+            literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
+        ),
         f.to_timestamp_nanos(literal("2023-09-07 05:06:14.523952")),
+        f.to_timestamp_nanos(
+            literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
+        ),
     )
     result = df.collect()
     assert len(result) == 1
@@ -913,6 +928,21 @@ def test_temporal_functions(df):
     assert result.column(11) == pa.array(
         [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
     )
+    assert result.column(12) == pa.array(
+        [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s")
+    )
+    assert result.column(13) == pa.array(
+        [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms")
+    )
+    assert result.column(14) == pa.array(
+        [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us")
+    )
+    assert result.column(15) == pa.array(
+        [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
+    )
+    assert result.column(16) == pa.array(
+        [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
+    )


def test_arrow_cast(df):