Skip to content

Commit 7226c9d

Browse files
committed
Turn isna() and notna() into TypeGuards
1 parent 5cfc849 commit 7226c9d

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

pandas-stubs/core/dtypes/missing.pyi

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from typing import (
2-
Literal,
3-
overload,
4-
)
1+
from typing import overload
52

63
import numpy as np
74
from numpy import typing as npt
@@ -10,12 +7,14 @@ from pandas import (
107
Index,
118
Series,
129
)
10+
from typing_extensions import TypeGuard
1311

1412
from pandas._libs.missing import NAType
1513
from pandas._libs.tslibs import NaTType
1614
from pandas._typing import (
1715
ArrayLike,
1816
Scalar,
17+
ScalarT,
1918
)
2019

2120
isposinf_scalar = ...
@@ -28,9 +27,9 @@ def isna(obj: Series) -> Series[bool]: ...
2827
@overload
2928
def isna(obj: Index | list | ArrayLike) -> npt.NDArray[np.bool_]: ...
3029
@overload
31-
def isna(obj: Scalar) -> bool: ...
32-
@overload
33-
def isna(obj: NaTType | NAType | None) -> Literal[True]: ...
30+
def isna(
31+
obj: Scalar | NaTType | NAType | None,
32+
) -> TypeGuard[NaTType | NAType | None]: ...
3433

3534
isnull = isna
3635

@@ -41,8 +40,6 @@ def notna(obj: Series) -> Series[bool]: ...
4140
@overload
4241
def notna(obj: Index | list | ArrayLike) -> npt.NDArray[np.bool_]: ...
4342
@overload
44-
def notna(obj: Scalar) -> bool: ...
45-
@overload
46-
def notna(obj: NaTType | NAType | None) -> Literal[False]: ...
43+
def notna(obj: ScalarT | NaTType | NAType | None) -> TypeGuard[ScalarT]: ...
4744

4845
notnull = notna

tests/test_pandas.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import (
66
TYPE_CHECKING,
77
Any,
8-
Literal,
98
Union,
109
)
1110

@@ -17,6 +16,8 @@
1716
import pytest
1817
from typing_extensions import assert_type
1918

19+
from pandas._libs.missing import NAType
20+
from pandas._libs.tslibs import NaTType
2021
from pandas._typing import Scalar
2122

2223
from tests import (
@@ -246,18 +247,42 @@ def test_isna() -> None:
246247
idx2 = pd.Index([1, 2])
247248
check(assert_type(pd.notna(idx2), npt.NDArray[np.bool_]), np.ndarray, np.bool_)
248249

249-
assert check(assert_type(pd.isna(pd.NA), Literal[True]), bool)
250-
assert not check(assert_type(pd.notna(pd.NA), Literal[False]), bool)
250+
assert check(assert_type(pd.isna(pd.NA), bool), bool)
251+
assert not check(assert_type(pd.notna(pd.NA), bool), bool)
251252

252-
assert check(assert_type(pd.isna(pd.NaT), Literal[True]), bool)
253-
assert not check(assert_type(pd.notna(pd.NaT), Literal[False]), bool)
253+
assert check(assert_type(pd.isna(pd.NaT), bool), bool)
254+
assert not check(assert_type(pd.notna(pd.NaT), bool), bool)
254255

255-
assert check(assert_type(pd.isna(None), Literal[True]), bool)
256-
assert not check(assert_type(pd.notna(None), Literal[False]), bool)
256+
assert check(assert_type(pd.isna(None), bool), bool)
257+
assert not check(assert_type(pd.notna(None), bool), bool)
257258

258259
check(assert_type(pd.isna(2.5), bool), bool)
259260
check(assert_type(pd.notna(2.5), bool), bool)
260261

262+
# Check type guard functionality
263+
nullable1 = random.choice(["value", None, pd.NA, pd.NaT])
264+
if pd.notna(nullable1):
265+
check(assert_type(nullable1, str), str)
266+
if pd.isna(nullable1):
267+
assert_type(nullable1, Union[NaTType, NAType, None])
268+
269+
nullable2 = random.choice([2, None])
270+
if pd.notna(nullable2):
271+
check(assert_type(nullable2, int), int)
272+
if pd.isna(nullable2):
273+
# TODO: Due to limitations in TypeGuard spec, the true annotation is not viable at this time
274+
# There is a proposal being floated for a StrictTypeGuard that will have more rigid narrowing semantics
275+
# assert_type(nullable2, None)
276+
assert_type(nullable2, Union[NaTType, NAType, None])
277+
278+
nullable3 = random.choice([2, None, pd.NA])
279+
if pd.notna(nullable3):
280+
check(assert_type(nullable3, int), int)
281+
if pd.isna(nullable3):
282+
# TODO: See comment above about the limitations of TypeGuard
283+
# assert_type(nullable3, Union[NAType, None])
284+
assert_type(nullable3, Union[NaTType, NAType, None])
285+
261286

262287
# GH 55
263288
def test_read_xml() -> None:

0 commit comments

Comments
 (0)