From 578ad92b11612f7f7750de6d4d4e38518b5cccaa Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 23 Oct 2023 10:18:29 +0100 Subject: [PATCH 1/5] rename get_column_by_name to col --- .../dataframe_api/dataframe_object.py | 4 ++-- .../examples/01_standardise_columns.py | 2 +- spec/API_specification/examples/02_plotting.py | 4 ++-- spec/API_specification/examples/tpch/q5.py | 14 +++++--------- spec/design_topics/python_builtin_types.md | 4 ++-- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/spec/API_specification/dataframe_api/dataframe_object.py b/spec/API_specification/dataframe_api/dataframe_object.py index 6b5a1362..0aa35117 100644 --- a/spec/API_specification/dataframe_api/dataframe_object.py +++ b/spec/API_specification/dataframe_api/dataframe_object.py @@ -89,7 +89,7 @@ def group_by(self, keys: str | list[str], /) -> GroupBy: """ ... - def get_column_by_name(self, name: str, /) -> Column: + def col(self, name: str, /) -> Column: """ Select a column by name. @@ -192,7 +192,7 @@ def assign(self, columns: Column | Sequence[Column], /) -> DataFrame: .. code-block:: python - new_column = df.get_column_by_name('a') + 1 + new_column = df.col('a') + 1 df = df.assign(new_column.rename('b')) Parameters diff --git a/spec/API_specification/examples/01_standardise_columns.py b/spec/API_specification/examples/01_standardise_columns.py index 476a4b10..e7b9d78e 100644 --- a/spec/API_specification/examples/01_standardise_columns.py +++ b/spec/API_specification/examples/01_standardise_columns.py @@ -11,7 +11,7 @@ def my_dataframe_agnostic_function(df_non_standard: SupportsDataFrameAPI) -> Any for column_name in df.column_names: if column_name == 'species': continue - new_column = df.get_column_by_name(column_name) + new_column = df.col(column_name) new_column = (new_column - new_column.mean()) / new_column.std() df = df.assign(new_column.rename(f'{column_name}_scaled')) diff --git a/spec/API_specification/examples/02_plotting.py b/spec/API_specification/examples/02_plotting.py index 31e12253..999b1ce2 100644 --- a/spec/API_specification/examples/02_plotting.py +++ b/spec/API_specification/examples/02_plotting.py @@ -23,7 +23,7 @@ def group_by_and_plot( df = namespace.dataframe_from_dict({"x": x, "y": y, "color": color}) agg = df.group_by("color").mean() - x = agg.get_column_by_name("x").to_array_object(namespace.Float64()) - y = agg.get_column_by_name("y").to_array_object(namespace.Float64()) + x = agg.col("x").to_array_object(namespace.Float64()) + y = agg.col("y").to_array_object(namespace.Float64()) my_plotting_function(x, y) diff --git a/spec/API_specification/examples/tpch/q5.py b/spec/API_specification/examples/tpch/q5.py index cdca0806..30c98665 100644 --- a/spec/API_specification/examples/tpch/q5.py +++ b/spec/API_specification/examples/tpch/q5.py @@ -53,19 +53,15 @@ def query( ) ) mask = ( - ( - result.get_column_by_name("c_nationkey") - == result.get_column_by_name("s_nationkey") - ) - & (result.get_column_by_name("r_name") == "ASIA") - & (result.get_column_by_name("o_orderdate") >= namespace.date(1994, 1, 1)) # type: ignore - & (result.get_column_by_name("o_orderdate") < namespace.date(1995, 1, 1)) # type: ignore + (result.col("c_nationkey") == result.col("s_nationkey")) + & (result.col("r_name") == "ASIA") + & (result.col("o_orderdate") >= namespace.date(1994, 1, 1)) # type: ignore + & (result.col("o_orderdate") < namespace.date(1995, 1, 1)) # type: ignore ) result = result.filter(mask) new_column = ( - result.get_column_by_name("l_extendedprice") - * (1 - result.get_column_by_name("l_discount")) + result.col("l_extendedprice") * (1 - result.col("l_discount")) ).rename("revenue") result = result.assign(new_column) result = result.select(["revenue", "n_name"]) diff --git a/spec/design_topics/python_builtin_types.md b/spec/design_topics/python_builtin_types.md index 93de5c53..c85812eb 100644 --- a/spec/design_topics/python_builtin_types.md +++ b/spec/design_topics/python_builtin_types.md @@ -14,14 +14,14 @@ the `float` it is documented to return, in combination with the `__gt__` method class DataFrame: def __gt__(self, other: DataFrame | Scalar) -> DataFrame: ... - def get_column_by_name(self, name: str, /) -> Column: + def col(self, name: str, /) -> Column: ... class Column: def mean(self, skip_nulls: bool = True) -> float | NullType: ... -larger = df2 > df1.get_column_by_name('foo').mean() +larger = df2 > df1.col('foo').mean() ``` For a GPU dataframe library, it is desirable for all data to reside on the GPU, From 4310a0f9bba8e69da4b4b2e1e0ab642bba6a2d85 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 23 Oct 2023 10:36:01 +0100 Subject: [PATCH 2/5] type DataFrame, Column, and GroupBy as Protocol --- spec/API_specification/.mypy.ini | 3 --- spec/API_specification/dataframe_api/__init__.py | 1 + spec/API_specification/dataframe_api/column_object.py | 4 ++-- spec/API_specification/dataframe_api/dataframe_object.py | 4 ++-- spec/API_specification/dataframe_api/groupby_object.py | 4 ++-- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/spec/API_specification/.mypy.ini b/spec/API_specification/.mypy.ini index b165602a..7caa5bf5 100644 --- a/spec/API_specification/.mypy.ini +++ b/spec/API_specification/.mypy.ini @@ -1,5 +1,2 @@ [mypy] strict=True - -[mypy-dataframe_api.*] -disable_error_code=empty-body diff --git a/spec/API_specification/dataframe_api/__init__.py b/spec/API_specification/dataframe_api/__init__.py index 7f4d17d4..10835a02 100644 --- a/spec/API_specification/dataframe_api/__init__.py +++ b/spec/API_specification/dataframe_api/__init__.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="empty-body" """ Function stubs and API documentation for the DataFrame API standard. """ diff --git a/spec/API_specification/dataframe_api/column_object.py b/spec/API_specification/dataframe_api/column_object.py index 980e1e03..e96d2812 100644 --- a/spec/API_specification/dataframe_api/column_object.py +++ b/spec/API_specification/dataframe_api/column_object.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any,NoReturn, TYPE_CHECKING, Literal, Generic +from typing import Any,NoReturn, TYPE_CHECKING, Literal, Protocol if TYPE_CHECKING: from .typing import NullType, Scalar, DType, Namespace @@ -9,7 +9,7 @@ __all__ = ['Column'] -class Column: +class Column(Protocol): """ Column object diff --git a/spec/API_specification/dataframe_api/dataframe_object.py b/spec/API_specification/dataframe_api/dataframe_object.py index 6b5a1362..0bc7b9cf 100644 --- a/spec/API_specification/dataframe_api/dataframe_object.py +++ b/spec/API_specification/dataframe_api/dataframe_object.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal, Mapping, Sequence, Union, TYPE_CHECKING, NoReturn +from typing import Any, Literal, Mapping, Sequence, TYPE_CHECKING, NoReturn, Protocol if TYPE_CHECKING: @@ -12,7 +12,7 @@ __all__ = ["DataFrame"] -class DataFrame: +class DataFrame(Protocol): """ DataFrame object diff --git a/spec/API_specification/dataframe_api/groupby_object.py b/spec/API_specification/dataframe_api/groupby_object.py index 0ccefebe..062bb2d5 100644 --- a/spec/API_specification/dataframe_api/groupby_object.py +++ b/spec/API_specification/dataframe_api/groupby_object.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol if TYPE_CHECKING: from .dataframe_object import DataFrame @@ -9,7 +9,7 @@ __all__ = ['GroupBy'] -class GroupBy: +class GroupBy(Protocol): """ GroupBy object. From 467ba1633d0aa4d55e22a51c3a5f84fb936eaabc Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 25 Oct 2023 09:14:12 +0100 Subject: [PATCH 3/5] add py.typed --- spec/API_specification/dataframe_api/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 spec/API_specification/dataframe_api/py.typed diff --git a/spec/API_specification/dataframe_api/py.typed b/spec/API_specification/dataframe_api/py.typed new file mode 100644 index 00000000..e69de29b From 0ef48526b1d04c8d38e5455db679bb3787692b24 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 25 Oct 2023 10:30:23 +0100 Subject: [PATCH 4/5] iterable argument in select --- spec/API_specification/dataframe_api/dataframe_object.py | 4 ++-- spec/API_specification/examples/tpch/q5.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/spec/API_specification/dataframe_api/dataframe_object.py b/spec/API_specification/dataframe_api/dataframe_object.py index 475797ab..e7899260 100644 --- a/spec/API_specification/dataframe_api/dataframe_object.py +++ b/spec/API_specification/dataframe_api/dataframe_object.py @@ -109,13 +109,13 @@ def get_column_by_name(self, name: str, /) -> Column: """ ... - def select(self, names: Sequence[str], /) -> Self: + def select(self, *names: str) -> Self: """ Select multiple columns by name. Parameters ---------- - names : Sequence[str] + *names : str Returns ------- diff --git a/spec/API_specification/examples/tpch/q5.py b/spec/API_specification/examples/tpch/q5.py index cdca0806..b69b8300 100644 --- a/spec/API_specification/examples/tpch/q5.py +++ b/spec/API_specification/examples/tpch/q5.py @@ -68,7 +68,7 @@ def query( * (1 - result.get_column_by_name("l_discount")) ).rename("revenue") result = result.assign(new_column) - result = result.select(["revenue", "n_name"]) + result = result.select("revenue", "n_name") result = result.group_by("n_name").sum() return result.dataframe From c372fc11e39c1213cd31f06bc60730ab2d47b2ec Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 26 Oct 2023 10:51:21 +0100 Subject: [PATCH 5/5] update tpch q1 --- spec/API_specification/examples/tpch/q1.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spec/API_specification/examples/tpch/q1.py b/spec/API_specification/examples/tpch/q1.py index b5c11287..21a1f4a5 100644 --- a/spec/API_specification/examples/tpch/q1.py +++ b/spec/API_specification/examples/tpch/q1.py @@ -8,16 +8,16 @@ def query(lineitem_raw: SupportsDataFrameAPI) -> Any: lineitem = lineitem_raw.__dataframe_consortium_standard__() namespace = lineitem.__dataframe_namespace__() - mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2) + mask = lineitem.col("l_shipdate") <= namespace.date(1998, 9, 2) lineitem = lineitem.assign( ( - lineitem.get_column_by_name("l_extended_price") - * (1 - lineitem.get_column_by_name("l_discount")) + lineitem.col("l_extended_price") + * (1 - lineitem.col("l_discount")) ).rename("l_disc_price"), ( - lineitem.get_column_by_name("l_extended_price") - * (1 - lineitem.get_column_by_name("l_discount")) - * (1 + lineitem.get_column_by_name("l_tax")) + lineitem.col("l_extended_price") + * (1 - lineitem.col("l_discount")) + * (1 + lineitem.col("l_tax")) ).rename("l_charge"), ) result = (