Skip to content

Commit 1f81476

Browse files
authored
Add details to expectations for scalars (#308)
1 parent 0c171a4 commit 1f81476

File tree

9 files changed

+368
-138
lines changed

9 files changed

+368
-138
lines changed

spec/API_specification/dataframe_api/column_object.py

Lines changed: 66 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,14 @@
55
if TYPE_CHECKING:
66
from typing_extensions import Self
77

8-
from dataframe_api.dataframe_object import DataFrame
9-
10-
from .typing import DType, Namespace, NullType, Scalar
8+
from .typing import (
9+
AnyScalar,
10+
DataFrame,
11+
DType,
12+
Namespace,
13+
NullType,
14+
Scalar,
15+
)
1116

1217

1318
__all__ = ["Column"]
@@ -224,7 +229,7 @@ def sorted_indices(
224229
"""
225230
...
226231

227-
def __eq__(self, other: Self | Scalar) -> Self: # type: ignore[override]
232+
def __eq__(self, other: Self | AnyScalar) -> Self: # type: ignore[override]
228233
"""Compare for equality.
229234
230235
Nulls should follow Kleene Logic.
@@ -247,7 +252,7 @@ def __eq__(self, other: Self | Scalar) -> Self: # type: ignore[override]
247252
"""
248253
...
249254

250-
def __ne__(self, other: Self | Scalar) -> Self: # type: ignore[override]
255+
def __ne__(self, other: Self | AnyScalar) -> Self: # type: ignore[override]
251256
"""Compare for non-equality.
252257
253258
Nulls should follow Kleene Logic.
@@ -270,7 +275,7 @@ def __ne__(self, other: Self | Scalar) -> Self: # type: ignore[override]
270275
"""
271276
...
272277

273-
def __ge__(self, other: Self | Scalar) -> Self:
278+
def __ge__(self, other: Self | AnyScalar) -> Self:
274279
"""Compare for "greater than or equal to" `other`.
275280
276281
Parameters
@@ -291,7 +296,7 @@ def __ge__(self, other: Self | Scalar) -> Self:
291296
"""
292297
...
293298

294-
def __gt__(self, other: Self | Scalar) -> Self:
299+
def __gt__(self, other: Self | AnyScalar) -> Self:
295300
"""Compare for "greater than" `other`.
296301
297302
Parameters
@@ -312,7 +317,7 @@ def __gt__(self, other: Self | Scalar) -> Self:
312317
"""
313318
...
314319

315-
def __le__(self, other: Self | Scalar) -> Self:
320+
def __le__(self, other: Self | AnyScalar) -> Self:
316321
"""Compare for "less than or equal to" `other`.
317322
318323
Parameters
@@ -333,7 +338,7 @@ def __le__(self, other: Self | Scalar) -> Self:
333338
"""
334339
...
335340

336-
def __lt__(self, other: Self | Scalar) -> Self:
341+
def __lt__(self, other: Self | AnyScalar) -> Self:
337342
"""Compare for "less than" `other`.
338343
339344
Parameters
@@ -354,7 +359,7 @@ def __lt__(self, other: Self | Scalar) -> Self:
354359
"""
355360
...
356361

357-
def __and__(self, other: Self | bool) -> Self:
362+
def __and__(self, other: Self | bool | Scalar) -> Self:
358363
"""Apply logical 'and' to `other` Column (or scalar) and this Column.
359364
360365
Nulls should follow Kleene Logic.
@@ -380,7 +385,7 @@ def __and__(self, other: Self | bool) -> Self:
380385
"""
381386
...
382387

383-
def __or__(self, other: Self | bool) -> Self:
388+
def __or__(self, other: Self | bool | Scalar) -> Self:
384389
"""Apply logical 'or' to `other` Column (or scalar) and this column.
385390
386391
Nulls should follow Kleene Logic.
@@ -406,7 +411,7 @@ def __or__(self, other: Self | bool) -> Self:
406411
"""
407412
...
408413

409-
def __add__(self, other: Self | Scalar) -> Self:
414+
def __add__(self, other: Self | AnyScalar) -> Self:
410415
"""Add `other` column or scalar to this column.
411416
412417
Parameters
@@ -427,7 +432,7 @@ def __add__(self, other: Self | Scalar) -> Self:
427432
"""
428433
...
429434

430-
def __sub__(self, other: Self | Scalar) -> Self:
435+
def __sub__(self, other: Self | AnyScalar) -> Self:
431436
"""Subtract `other` column or scalar from this column.
432437
433438
Parameters
@@ -448,7 +453,7 @@ def __sub__(self, other: Self | Scalar) -> Self:
448453
"""
449454
...
450455

451-
def __mul__(self, other: Self | Scalar) -> Self:
456+
def __mul__(self, other: Self | AnyScalar) -> Self:
452457
"""Multiply `other` column or scalar with this column.
453458
454459
Parameters
@@ -469,7 +474,7 @@ def __mul__(self, other: Self | Scalar) -> Self:
469474
"""
470475
...
471476

472-
def __truediv__(self, other: Self | Scalar) -> Self:
477+
def __truediv__(self, other: Self | AnyScalar) -> Self:
473478
"""Divide this column by `other` column or scalar. True division, returns floats.
474479
475480
Parameters
@@ -490,7 +495,7 @@ def __truediv__(self, other: Self | Scalar) -> Self:
490495
"""
491496
...
492497

493-
def __floordiv__(self, other: Self | Scalar) -> Self:
498+
def __floordiv__(self, other: Self | AnyScalar) -> Self:
494499
"""Floor-divide this column by `other` column or scalar.
495500
496501
Parameters
@@ -511,7 +516,7 @@ def __floordiv__(self, other: Self | Scalar) -> Self:
511516
"""
512517
...
513518

514-
def __pow__(self, other: Self | Scalar) -> Self:
519+
def __pow__(self, other: Self | AnyScalar) -> Self:
515520
"""Raise this column to the power of `other`.
516521
517522
Integer dtype to the power of non-negative integer dtype is integer dtype.
@@ -536,7 +541,7 @@ def __pow__(self, other: Self | Scalar) -> Self:
536541
"""
537542
...
538543

539-
def __mod__(self, other: Self | Scalar) -> Self:
544+
def __mod__(self, other: Self | AnyScalar) -> Self:
540545
"""Return modulus of this column by `other` (`%` operator).
541546
542547
Parameters
@@ -557,7 +562,7 @@ def __mod__(self, other: Self | Scalar) -> Self:
557562
"""
558563
...
559564

560-
def __divmod__(self, other: Self | Scalar) -> tuple[Column, Column]:
565+
def __divmod__(self, other: Self | AnyScalar) -> tuple[Column, Column]:
561566
"""Return quotient and remainder of integer division. See `divmod` builtin.
562567
563568
Parameters
@@ -578,16 +583,16 @@ def __divmod__(self, other: Self | Scalar) -> tuple[Column, Column]:
578583
"""
579584
...
580585

581-
def __radd__(self, other: Self | Scalar) -> Self:
586+
def __radd__(self, other: Self | AnyScalar) -> Self:
582587
...
583588

584-
def __rsub__(self, other: Self | Scalar) -> Self:
589+
def __rsub__(self, other: Self | AnyScalar) -> Self:
585590
...
586591

587-
def __rmul__(self, other: Self | Scalar) -> Self:
592+
def __rmul__(self, other: Self | AnyScalar) -> Self:
588593
...
589594

590-
def __rtruediv__(self, other: Self | Scalar) -> Self:
595+
def __rtruediv__(self, other: Self | AnyScalar) -> Self:
591596
...
592597

593598
def __rand__(self, other: Self | bool) -> Self:
@@ -596,13 +601,13 @@ def __rand__(self, other: Self | bool) -> Self:
596601
def __ror__(self, other: Self | bool) -> Self:
597602
...
598603

599-
def __rfloordiv__(self, other: Self | Scalar) -> Self:
604+
def __rfloordiv__(self, other: Self | AnyScalar) -> Self:
600605
...
601606

602-
def __rpow__(self, other: Self | Scalar) -> Self:
607+
def __rpow__(self, other: Self | AnyScalar) -> Self:
603608
...
604609

605-
def __rmod__(self, other: Self | Scalar) -> Self:
610+
def __rmod__(self, other: Self | AnyScalar) -> Self:
606611
...
607612

608613
def __invert__(self) -> Self:
@@ -615,7 +620,7 @@ def __invert__(self) -> Self:
615620
"""
616621
...
617622

618-
def any(self, *, skip_nulls: bool = True) -> bool | NullType:
623+
def any(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
619624
"""Reduction returns a bool.
620625
621626
Raises
@@ -625,7 +630,7 @@ def any(self, *, skip_nulls: bool = True) -> bool | NullType:
625630
"""
626631
...
627632

628-
def all(self, *, skip_nulls: bool = True) -> bool | NullType:
633+
def all(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
629634
"""Reduction returns a bool.
630635
631636
Raises
@@ -635,23 +640,23 @@ def all(self, *, skip_nulls: bool = True) -> bool | NullType:
635640
"""
636641
...
637642

638-
def min(self, *, skip_nulls: bool = True) -> Scalar | NullType:
643+
def min(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
639644
"""Reduction returns a scalar.
640645
641646
Any data type that supports comparisons
642647
must be supported. The returned value has the same dtype as the column.
643648
"""
644649
...
645650

646-
def max(self, *, skip_nulls: bool = True) -> Scalar | NullType:
651+
def max(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
647652
"""Reduction returns a scalar.
648653
649654
Any data type that supports comparisons
650655
must be supported. The returned value has the same dtype as the column.
651656
"""
652657
...
653658

654-
def sum(self, *, skip_nulls: bool = True) -> Scalar | NullType:
659+
def sum(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
655660
"""Reduction returns a scalar.
656661
657662
Must be supported for numerical and
@@ -660,15 +665,15 @@ def sum(self, *, skip_nulls: bool = True) -> Scalar | NullType:
660665
"""
661666
...
662667

663-
def prod(self, *, skip_nulls: bool = True) -> Scalar | NullType:
668+
def prod(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
664669
"""Reduction returns a scalar.
665670
666671
Must be supported for numerical data types.
667672
The returned value has the same dtype as the column.
668673
"""
669674
...
670675

671-
def median(self, *, skip_nulls: bool = True) -> Scalar | NullType:
676+
def median(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
672677
"""Reduction returns a scalar.
673678
674679
Must be supported for numerical and
@@ -678,7 +683,7 @@ def median(self, *, skip_nulls: bool = True) -> Scalar | NullType:
678683
"""
679684
...
680685

681-
def mean(self, *, skip_nulls: bool = True) -> Scalar | NullType:
686+
def mean(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
682687
"""Reduction returns a scalar.
683688
684689
Must be supported for numerical and
@@ -691,9 +696,9 @@ def mean(self, *, skip_nulls: bool = True) -> Scalar | NullType:
691696
def std(
692697
self,
693698
*,
694-
correction: int | float = 1,
695-
skip_nulls: bool = True,
696-
) -> Scalar | NullType:
699+
correction: float = 1,
700+
skip_nulls: bool | Scalar = True,
701+
) -> Scalar:
697702
"""Reduction returns a scalar.
698703
699704
Must be supported for numerical and
@@ -724,9 +729,9 @@ def std(
724729
def var(
725730
self,
726731
*,
727-
correction: int | float = 1,
728-
skip_nulls: bool = True,
729-
) -> Scalar | NullType:
732+
correction: float | Scalar = 1,
733+
skip_nulls: bool | Scalar = True,
734+
) -> Scalar:
730735
"""Reduction returns a scalar.
731736
732737
Must be supported for numerical and
@@ -835,7 +840,7 @@ def is_in(self, values: Self) -> Self:
835840
"""
836841
...
837842

838-
def unique_indices(self, *, skip_nulls: bool = True) -> Self:
843+
def unique_indices(self, *, skip_nulls: bool | Scalar = True) -> Self:
839844
"""Return indices corresponding to unique values in Column.
840845
841846
Returns
@@ -855,7 +860,7 @@ def unique_indices(self, *, skip_nulls: bool = True) -> Self:
855860
"""
856861
...
857862

858-
def fill_nan(self, value: float | NullType, /) -> Self:
863+
def fill_nan(self, value: float | NullType | Scalar, /) -> Self:
859864
"""Fill floating point ``nan`` values with the given fill value.
860865
861866
Parameters
@@ -868,7 +873,7 @@ def fill_nan(self, value: float | NullType, /) -> Self:
868873
"""
869874
...
870875

871-
def fill_null(self, value: Scalar, /) -> Self:
876+
def fill_null(self, value: AnyScalar, /) -> Self:
872877
"""Fill null values with the given fill value.
873878
874879
Parameters
@@ -914,7 +919,7 @@ def to_array(self) -> Any:
914919
"""
915920
...
916921

917-
def rename(self, name: str) -> Self:
922+
def rename(self, name: str | Scalar) -> Self:
918923
"""Rename column.
919924
920925
Parameters
@@ -929,17 +934,17 @@ def rename(self, name: str) -> Self:
929934
"""
930935
...
931936

932-
def shift(self, offset: int) -> Self:
937+
def shift(self, offset: int | Scalar) -> Self:
933938
"""Shift values by `offset` positions, filling missing values with `null`.
934939
935940
For example, if the original column contains values `[1, 4, 2]`, then:
936941
937942
- `.shift(1)` will return `[null, 1, 4]`,
938943
- `.shift(-1)` will return `[4, 2, null]`,
939-
944+
940945
Parameters
941946
----------
942-
offset
947+
offset : int
943948
How many positions to shift by.
944949
"""
945950
...
@@ -1020,7 +1025,7 @@ def iso_weekday(self) -> Self:
10201025
"""
10211026
...
10221027

1023-
def unix_timestamp(self, *, time_unit: Literal["s", "ms", "us"] = "s") -> Self:
1028+
def unix_timestamp(self, *, time_unit: str | Scalar = "s") -> Self:
10241029
"""Return number of seconds / milliseconds / microseconds since the Unix epoch.
10251030
10261031
The Unix epoch is 00:00:00 UTC on 1 January 1970.
@@ -1039,3 +1044,16 @@ def unix_timestamp(self, *, time_unit: Literal["s", "ms", "us"] = "s") -> Self:
10391044
discarded.
10401045
"""
10411046
...
1047+
1048+
def persist(self) -> Self:
1049+
"""Hint that computation prior to this point should not be repeated.
1050+
1051+
This is intended as a hint, rather than as a directive. Implementations
1052+
which do not separate lazy vs eager execution may ignore this method and
1053+
treat it as a no-op.
1054+
1055+
.. note::
1056+
This method may trigger execution. If necessary, it should be called
1057+
at most once per dataframe, and as late as possible in the pipeline.
1058+
"""
1059+
...

0 commit comments

Comments
 (0)