diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
new file mode 100644
index 00000000..a11bde7a
--- /dev/null
+++ b/.github/workflows/mypy.yml
@@ -0,0 +1,32 @@
+name: mypy
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  tox:
+    strategy:
+      matrix:
+        python-version: ["3.8", "3.11"]
+        os: [ubuntu-latest]
+
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache multiple paths
+        uses: actions/cache@v3
+        with:
+          path: |
+            ~/.cache/pip
+            $RUNNER_TOOL_CACHE/Python/*
+            ~\AppData\Local\pip\Cache
+          key: ${{ runner.os }}-build-${{ matrix.python-version }}
+      - name: install-reqs
+        run: python -m pip install --upgrade mypy==1.4.0
+      - name: run mypy
+        run: cd spec/API_specification && mypy dataframe_api
diff --git a/spec/API_specification/.mypy.ini b/spec/API_specification/.mypy.ini
new file mode 100644
index 00000000..b165602a
--- /dev/null
+++ b/spec/API_specification/.mypy.ini
@@ -0,0 +1,5 @@
+[mypy]
+strict=True
+
+[mypy-dataframe_api.*]
+disable_error_code=empty-body
diff --git a/spec/API_specification/dataframe_api/__init__.py b/spec/API_specification/dataframe_api/__init__.py
index b8b7014c..ed664363 100644
--- a/spec/API_specification/dataframe_api/__init__.py
+++ b/spec/API_specification/dataframe_api/__init__.py
@@ -6,18 +6,19 @@ from typing import Mapping, Sequence, Any
 
 from .column_object import *
-from .dataframe_object import *
+from .dataframe_object import DataFrame
 from .groupby_object import *
-
+from ._types import DType
 
 __all__ = [
-    "__dataframe_api_version",
+    "__dataframe_api_version__",
+    "DataFrame",
+    "Column",
     "column_from_sequence",
     "concat",
     "dataframe_from_dict",
     "is_null",
     "null",
-    "DType",
     "Int64",
     "Int32",
     "Int16",
@@ -59,7 +60,7 @@ def concat(dataframes: Sequence[DataFrame]) -> DataFrame:
     """
     ...
 
-def column_from_sequence(sequence: Sequence[object], *, dtype: DType) -> Column:
+def column_from_sequence(sequence: Sequence[Any], *, dtype: Any) -> Column[Any]:
     """
     Construct Column from sequence of elements.
 
@@ -78,7 +79,7 @@ def column_from_sequence(sequence: Sequence[object], *, dtype: DType) -> Column:
     """
     ...
 
-def dataframe_from_dict(data: Mapping[str, Column]) -> DataFrame:
+def dataframe_from_dict(data: Mapping[str, Column[Any]]) -> DataFrame:
     """
     Construct DataFrame from map of column names to Columns.
 
@@ -144,38 +145,35 @@ def is_null(value: object, /) -> bool:
 # Dtypes #
 ##########
 
-class DType:
-    """Base class for all dtypes."""
-
-class Int64(DType):
+class Int64:
     """Integer type with 64 bits of precision."""
 
-class Int32(DType):
+class Int32:
     """Integer type with 32 bits of precision."""
 
-class Int16(DType):
+class Int16:
     """Integer type with 16 bits of precision."""
 
-class Int8(DType):
+class Int8:
     """Integer type with 8 bits of precision."""
 
-class UInt64(DType):
+class UInt64:
     """Unsigned integer type with 64 bits of precision."""
 
-class UInt32(DType):
+class UInt32:
     """Unsigned integer type with 32 bits of precision."""
 
-class UInt16(DType):
+class UInt16:
     """Unsigned integer type with 16 bits of precision."""
 
-class UInt8(DType):
+class UInt8:
     """Unsigned integer type with 8 bits of precision."""
 
-class Float64(DType):
+class Float64:
     """Floating point type with 64 bits of precision."""
 
-class Float32(DType):
+class Float32:
     """Floating point type with 32 bits of precision."""
 
-class Bool(DType):
+class Bool:
     """Boolean type with 8 bits of precision."""
diff --git a/spec/API_specification/dataframe_api/_types.py b/spec/API_specification/dataframe_api/_types.py
index 2874ba4c..dde7795a 100644
--- a/spec/API_specification/dataframe_api/_types.py
+++ b/spec/API_specification/dataframe_api/_types.py
@@ -20,8 +20,11 @@
 )
 from enum import Enum
 
+# Type alias: Mypy needs Any, but for readability we need to make clear this
+# is a Python scalar (i.e., an instance of `bool`, `int`, `float`, `str`, etc.)
+Scalar = Any
+
 array = TypeVar("array")
-Scalar = TypeVar("Scalar")
 device = TypeVar("device")
 DType = TypeVar("DType")
 SupportsDLPack = TypeVar("SupportsDLPack")
diff --git a/spec/API_specification/dataframe_api/column_object.py b/spec/API_specification/dataframe_api/column_object.py
index 10ba9a35..ffcaca31 100644
--- a/spec/API_specification/dataframe_api/column_object.py
+++ b/spec/API_specification/dataframe_api/column_object.py
@@ -1,16 +1,18 @@
 from __future__ import annotations
 
-from typing import Any,NoReturn, Sequence, TYPE_CHECKING, Literal
+from typing import Any,NoReturn, Sequence, TYPE_CHECKING, Literal, Generic
+
+from ._types import DType
 
 if TYPE_CHECKING:
+    from . import Bool, null
     from ._types import Scalar
-    from . import DType
 
 
 __all__ = ['Column']
 
 
-class Column:
+class Column(Generic[DType]):
     """
     Column object
 
@@ -21,7 +23,7 @@ class Column:
     """
 
     def __column_namespace__(
-        self: Column, /, *, api_version: str | None = None
+        self, /, *, api_version: str | None = None
     ) -> Any:
         """
         Returns an object that has all the Dataframe Standard API functions on it.
@@ -48,7 +50,7 @@ def __column_namespace__(
         """
 
     @property
-    def column(self) -> object:
+    def column(self) -> Any:
         """
         Return underlying (not-necessarily-Standard-compliant) column.
 
@@ -74,12 +76,12 @@ def __iter__(self) -> NoReturn:
         raise NotImplementedError("'__iter__' is intentionally not implemented.")
 
     @property
-    def dtype(self) -> DType:
+    def dtype(self) -> Any:
         """
         Return data type of column.
         """
 
-    def get_rows(self, indices: Column[int]) -> Column:
+    def get_rows(self: Column[DType], indices: Column[Any]) -> Column[DType]:
         """
         Select a subset of rows, similar to `ndarray.take`.
 
@@ -112,7 +114,7 @@ def sorted_indices(
         *,
         ascending: bool = True,
         nulls_position: Literal['first', 'last'] = 'last',
-    ) -> Column[int]:
+    ) -> Column[Any]:
         """
         Return row numbers which would sort column.
 
@@ -137,7 +139,7 @@ def sorted_indices(
         """
         ...
 
-    def __eq__(self, other: Column | Scalar) -> Column:
+    def __eq__(self, other: Column[Any] | Scalar) -> Column[Bool]:  # type: ignore[override]
         """
         Compare for equality.
 
@@ -155,7 +157,7 @@ def __eq__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __ne__(self, other: Column | Scalar) -> Column:
+    def __ne__(self: Column[DType], other: Column[DType] | Scalar) -> Column[Bool]:  # type: ignore[override]
         """
         Compare for non-equality.
 
@@ -173,7 +175,7 @@ def __ne__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __ge__(self, other: Column | Scalar) -> Column:
+    def __ge__(self: Column[DType], other: Column[DType] | Scalar) -> Column[Bool]:
         """
         Compare for "greater than or equal to" `other`.
 
@@ -189,7 +191,7 @@ def __ge__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __gt__(self, other: Column | Scalar) -> Column:
+    def __gt__(self: Column[DType], other: Column[DType] | Scalar) -> Column[Bool]:
         """
         Compare for "greater than" `other`.
 
@@ -205,7 +207,7 @@ def __gt__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __le__(self, other: Column | Scalar) -> Column:
+    def __le__(self: Column[DType], other: Column[DType] | Scalar) -> Column[Bool]:
         """
         Compare for "less than or equal to" `other`.
 
@@ -221,7 +223,7 @@ def __le__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __lt__(self, other: Column | Scalar) -> Column:
+    def __lt__(self: Column[DType], other: Column[DType] | Scalar) -> Column[Bool]:
         """
         Compare for "less than" `other`.
 
@@ -237,7 +239,7 @@ def __lt__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __and__(self, other: Column[bool] | bool) -> Column[bool]:
+    def __and__(self: Column[Bool], other: Column[Bool] | bool) -> Column[Bool]:
         """
         Apply logical 'and' to `other` Column (or scalar) and this Column.
 
@@ -258,7 +260,7 @@ def __and__(self, other: Column[bool] | bool) -> Column[bool]:
         If `self` or `other` is not boolean.
         """
 
-    def __or__(self, other: Column[bool] | bool) -> Column[bool]:
+    def __or__(self: Column[Bool], other: Column[Bool] | bool) -> Column[Bool]:
         """
         Apply logical 'or' to `other` Column (or scalar) and this column.
 
@@ -279,7 +281,7 @@ def __or__(self, other: Column[bool] | bool) -> Column[bool]:
         If `self` or `other` is not boolean.
         """
 
-    def __add__(self, other: Column | Scalar) -> Column:
+    def __add__(self: Column[Any], other: Column[Any] | Scalar) -> Column[Any]:
         """
         Add `other` column or scalar to this column.
 
@@ -295,7 +297,7 @@ def __add__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __sub__(self, other: Column | Scalar) -> Column:
+    def __sub__(self: Column[Any], other: Column[Any] | Scalar) -> Column[Any]:
         """
         Subtract `other` column or scalar from this column.
 
@@ -311,7 +313,7 @@ def __sub__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __mul__(self, other: Column | Scalar) -> Column:
+    def __mul__(self, other: Column[Any] | Scalar) -> Column[Any]:
         """
         Multiply `other` column or scalar with this column.
 
@@ -327,7 +329,7 @@ def __mul__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __truediv__(self, other: Column | Scalar) -> Column:
+    def __truediv__(self, other: Column[Any] | Scalar) -> Column[Any]:
         """
         Divide this column by `other` column or scalar. True division, returns floats.
 
@@ -343,7 +345,7 @@ def __truediv__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __floordiv__(self, other: Column | Scalar) -> Column:
+    def __floordiv__(self, other: Column[Any] | Scalar) -> Column[Any]:
         """
         Floor-divide `other` column or scalar to this column.
 
@@ -359,7 +361,7 @@ def __floordiv__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __pow__(self, other: Column | Scalar) -> Column:
+    def __pow__(self, other: Column[Any] | Scalar) -> Column[Any]:
         """
         Raise this column to the power of `other`.
 
@@ -379,7 +381,7 @@ def __pow__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __mod__(self, other: Column | Scalar) -> Column:
+    def __mod__(self, other: Column[Any] | Scalar) -> Column[Any]:
         """
         Returns modulus of this column by `other` (`%` operator).
 
@@ -395,7 +397,7 @@ def __mod__(self, other: Column | Scalar) -> Column:
         Column
         """
 
-    def __divmod__(self, other: Column | Scalar) -> tuple[Column, Column]:
+    def __divmod__(self, other: Column[Any] | Scalar) -> tuple[Column[Any], Column[Any]]:
         """
         Return quotient and remainder of integer division. See `divmod` builtin function.
 
@@ -411,7 +413,7 @@ def __divmod__(self, other: Column | Scalar) -> tuple[Column, Column]:
         Column
         """
 
-    def __invert__(self) -> Column:
+    def __invert__(self: Column[Bool]) -> Column[Bool]:
         """
         Invert truthiness of (boolean) elements.
 
@@ -421,7 +423,7 @@ def __invert__(self) -> Column:
         If any of the Column's columns is not boolean.
         """
 
-    def any(self, *, skip_nulls: bool = True) -> bool:
+    def any(self: Column[Bool], *, skip_nulls: bool = True) -> bool:
         """
         Reduction returns a bool.
 
@@ -431,7 +433,7 @@ def any(self, *, skip_nulls: bool = True) -> bool:
         If column is not boolean.
         """
 
-    def all(self, *, skip_nulls: bool = True) -> bool:
+    def all(self: Column[Bool], *, skip_nulls: bool = True) -> bool:
         """
         Reduction returns a bool.
 
@@ -525,33 +527,33 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> Scalar
         Whether to skip null values.
         """
 
-    def cumulative_max(self) -> Column:
+    def cumulative_max(self: Column[DType]) -> Column[DType]:
         """
         Reduction returns a Column. Any data type that supports comparisons
         must be supported. The returned value has the same dtype as the column.
         """
 
-    def cumulative_min(self) -> Column:
+    def cumulative_min(self: Column[DType]) -> Column[DType]:
         """
         Reduction returns a Column. Any data type that supports comparisons
         must be supported. The returned value has the same dtype as the column.
         """
 
-    def cumulative_sum(self) -> Column:
+    def cumulative_sum(self: Column[DType]) -> Column[DType]:
         """
         Reduction returns a Column. Must be supported for numerical and
         datetime data types. The returned value has the same dtype as the column.
         """
 
-    def cumulative_prod(self) -> Column:
+    def cumulative_prod(self: Column[DType]) -> Column[DType]:
         """
         Reduction returns a Column. Must be supported for numerical and
         datetime data types. The returned value has the same dtype as the column.
         """
 
-    def is_null(self) -> Column:
+    def is_null(self) -> Column[Bool]:
         """
         Check for 'missing' or 'null' entries.
 
@@ -570,7 +572,7 @@ def is_null(self) -> Column:
         but note that the Standard makes no guarantees about them.
         """
 
-    def is_nan(self) -> Column:
+    def is_nan(self) -> Column[Bool]:
         """
         Check for nan entries.
 
@@ -589,7 +591,7 @@ def is_nan(self) -> Column:
         In particular, does not check for `np.timedelta64('NaT')`.
         """
 
-    def is_in(self, values: Column) -> Column[bool]:
+    def is_in(self: Column[DType], values: Column[DType]) -> Column[Bool]:
         """
         Indicate whether the value at each row matches any value in `values`.
 
@@ -607,7 +609,7 @@ def is_in(self, values: Column) -> Column[bool]:
         Column[bool]
         """
 
-    def unique_indices(self, *, skip_nulls: bool = True) -> Column[int]:
+    def unique_indices(self, *, skip_nulls: bool = True) -> Column[Any]:
         """
         Return indices corresponding to unique values in Column.
 
@@ -628,7 +630,7 @@ def unique_indices(self, *, skip_nulls: bool = True) -> Column[int]:
         """
         ...
 
-    def fill_nan(self, value: float | 'null', /) -> Column:
+    def fill_nan(self: Column[DType], value: float | 'null', /) -> Column[DType]:
         """
         Fill floating point ``nan`` values with the given fill value.
 
@@ -642,7 +644,7 @@ def fill_nan(self, value: float | 'null', /) -> Column:
         """
         ...
 
-    def fill_null(self, value: Scalar, /) -> Column:
+    def fill_null(self: Column[DType], value: Scalar, /) -> Column[DType]:
         """
         Fill null values with the given fill value.
 
diff --git a/spec/API_specification/dataframe_api/dataframe_object.py b/spec/API_specification/dataframe_api/dataframe_object.py
index f56dae67..827d0f4c 100644
--- a/spec/API_specification/dataframe_api/dataframe_object.py
+++ b/spec/API_specification/dataframe_api/dataframe_object.py
@@ -6,6 +6,7 @@
 if TYPE_CHECKING:
     from .column_object import Column
     from .groupby_object import GroupBy
+    from . import Bool, null
     from ._types import Scalar
 
 
@@ -37,7 +38,7 @@ class DataFrame:
     """
 
    def __dataframe_namespace__(
-        self: DataFrame, /, *, api_version: str | None = None
+        self, /, *, api_version: str | None = None
     ) -> Any:
         """
         Returns an object that has all the dataframe API functions on it.
@@ -102,7 +103,7 @@ def groupby(self, keys: Sequence[str], /) -> GroupBy:
         """
         ...
 
-    def get_column_by_name(self, name: str, /) -> Column:
+    def get_column_by_name(self, name: str, /) -> Column[Any]:
         """
         Select a column by name.
 
@@ -140,7 +141,7 @@ def get_columns_by_name(self, names: Sequence[str], /) -> DataFrame:
         """
         ...
 
-    def get_rows(self, indices: "Column[int]") -> DataFrame:
+    def get_rows(self, indices: Column[Any]) -> DataFrame:
         """
         Select a subset of rows, similar to `ndarray.take`.
 
@@ -173,7 +174,7 @@ def slice_rows(
         """
         ...
 
-    def get_rows_by_mask(self, mask: "Column[bool]") -> DataFrame:
+    def get_rows_by_mask(self, mask: Column[Bool]) -> DataFrame:
         """
         Select a subset of rows corresponding to a mask.
 
@@ -192,7 +193,7 @@ def get_rows_by_mask(self, mask: "Column[bool]") -> DataFrame:
         """
         ...
 
-    def insert(self, loc: int, label: str, value: Column) -> DataFrame:
+    def insert(self, loc: int, label: str, value: Column[Any]) -> DataFrame:
         """
         Insert column into DataFrame at specified location.
 
@@ -256,7 +257,7 @@ def sorted_indices(
         *,
         ascending: Sequence[bool] | bool = True,
         nulls_position: Literal['first', 'last'] = 'last',
-    ) -> Column[int]:
+    ) -> Column[Any]:
         """
         Return row numbers which would sort according to given columns.
 
@@ -291,7 +292,7 @@ def sorted_indices(
         """
         ...
 
-    def __eq__(self, other: DataFrame | Scalar) -> DataFrame:
+    def __eq__(self, other: DataFrame | Scalar) -> DataFrame:  # type: ignore[override]
         """
         Compare for equality.
 
@@ -310,7 +311,7 @@ def __eq__(self, other: DataFrame | Scalar) -> DataFrame:
         """
         ...
 
-    def __ne__(self, other: DataFrame | Scalar) -> DataFrame:
+    def __ne__(self, other: DataFrame | Scalar) -> DataFrame:  # type: ignore[override]
         """
         Compare for non-equality.
 
@@ -397,7 +398,7 @@ def __lt__(self, other: DataFrame | Scalar) -> DataFrame:
         """
         ...
 
-    def __and__(self, other: DataFrame[bool] | bool) -> DataFrame[bool]:
+    def __and__(self, other: DataFrame | bool) -> DataFrame:
         """
         Apply logical 'and' to `other` DataFrame (or scalar) and this dataframe.
 
@@ -418,7 +419,7 @@ def __and__(self, other: DataFrame[bool] | bool) -> DataFrame[bool]:
         If `self` or `other` is not boolean.
         """
 
-    def __or__(self, other: DataFrame[bool] | bool) -> DataFrame[bool]:
+    def __or__(self, other: DataFrame | bool) -> DataFrame:
         """
         Apply logical 'or' to `other` DataFrame (or scalar) and this DataFrame.
 
@@ -624,7 +625,7 @@ def all(self, *, skip_nulls: bool = True) -> DataFrame:
         """
         ...
 
-    def any_rowwise(self, *, skip_nulls: bool = True) -> Column:
+    def any_rowwise(self, *, skip_nulls: bool = True) -> Column[Bool]:
         """
         Reduction returns a Column.
 
@@ -638,7 +639,7 @@ def any_rowwise(self, *, skip_nulls: bool = True) -> Column:
         """
         ...
 
-    def all_rowwise(self, *, skip_nulls: bool = True) -> Column:
+    def all_rowwise(self, *, skip_nulls: bool = True) -> Column[Bool]:
         """
         Reduction returns a Column.
 
diff --git a/spec/API_specification/dataframe_api/groupby_object.py b/spec/API_specification/dataframe_api/groupby_object.py
index 11a7b102..c020be9d 100644
--- a/spec/API_specification/dataframe_api/groupby_object.py
+++ b/spec/API_specification/dataframe_api/groupby_object.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -17,35 +19,35 @@ class GroupBy:
     **Methods**
 
     """
-    def any(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def any(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def all(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def all(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def min(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def min(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def max(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def max(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def sum(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def sum(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def prod(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def prod(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def median(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def median(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def mean(self, *, skip_nulls: bool = True) -> "DataFrame":
+    def mean(self, *, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def std(self, *, correction: int | float = 1, skip_nulls: bool = True) -> "DataFrame":
+    def std(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> "DataFrame":
+    def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFrame:
         ...
 
-    def size(self) -> "DataFrame":
+    def size(self) -> DataFrame:
         ...
diff --git a/spec/API_specification/index.rst b/spec/API_specification/index.rst
index 32b81a12..b90d3320 100644
--- a/spec/API_specification/index.rst
+++ b/spec/API_specification/index.rst
@@ -16,7 +16,6 @@ of objects and functions in the top-level namespace. The latter are:
   __dataframe_api_version__
   is_null
   null
-  DType
   Int64
   Int32
   Int16
diff --git a/spec/conf.py b/spec/conf.py
index 8d3d7800..86ae0c06 100644
--- a/spec/conf.py
+++ b/spec/conf.py
@@ -82,6 +82,7 @@
     ('py:class', 'enum.Enum'),
     ('py:class', 'ellipsis'),
     ('py:class', 'Scalar'),
+    ('py:class', 'Bool'),
 ]
 # NOTE: this alias handling isn't used yet - added in anticipation of future
 # need based on dataframe API aliases.
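
Below is a minimal usage sketch, not part of the patch above, of how the now-generic Column[DType] annotations might be consumed under the strict mypy configuration added in spec/API_specification/.mypy.ini; the function name, column name, and threshold are invented for illustration, and a Standard-compliant `dataframe_api` package is assumed to be importable.

# Hypothetical consumer code (illustration only).
from __future__ import annotations

from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from dataframe_api import Bool, Column, DataFrame


def keep_positive(df: DataFrame, column_name: str) -> DataFrame:
    # get_column_by_name now returns Column[Any].
    col: Column[Any] = df.get_column_by_name(column_name)
    # Comparisons return Column[Bool] under the new annotations.
    mask: Column[Bool] = col > 0
    # get_rows_by_mask expects a Column[Bool] mask.
    return df.get_rows_by_mask(mask)

The same check the new workflow runs can be reproduced locally with `python -m pip install --upgrade mypy==1.4.0` followed by `cd spec/API_specification && mypy dataframe_api`.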