Skip to content

Commit 6029ba3

Browse files
committed
Turn isna() and notna() into TypeGuards
1 parent b12acc1 commit 6029ba3

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed

pandas-stubs/core/dtypes/missing.pyi

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from typing import (
2-
Literal,
3-
overload,
4-
)
1+
from typing import overload
52

63
import numpy as np
74
from numpy import typing as npt
@@ -10,12 +7,14 @@ from pandas import (
107
Index,
118
Series,
129
)
10+
from typing_extensions import TypeGuard
1311

1412
from pandas._libs.missing import NAType
1513
from pandas._libs.tslibs import NaTType
1614
from pandas._typing import (
1715
ArrayLike,
1816
Scalar,
17+
ScalarT,
1918
)
2019

2120
isposinf_scalar = ...
@@ -28,9 +27,9 @@ def isna(obj: Series) -> Series[bool]: ...
2827
@overload
2928
def isna(obj: Index | list | ArrayLike) -> npt.NDArray[np.bool_]: ...
3029
@overload
31-
def isna(obj: Scalar) -> bool: ...
32-
@overload
33-
def isna(obj: NaTType | NAType | None) -> Literal[True]: ...
30+
def isna(
31+
obj: Scalar | NaTType | NAType | None,
32+
) -> TypeGuard[NaTType | NAType | None]: ...
3433

3534
isnull = isna
3635

@@ -41,8 +40,6 @@ def notna(obj: Series) -> Series[bool]: ...
4140
@overload
4241
def notna(obj: Index | list | ArrayLike) -> npt.NDArray[np.bool_]: ...
4342
@overload
44-
def notna(obj: Scalar) -> bool: ...
45-
@overload
46-
def notna(obj: NaTType | NAType | None) -> Literal[False]: ...
43+
def notna(obj: ScalarT | NaTType | NAType | None) -> TypeGuard[ScalarT]: ...
4744

4845
notnull = notna

tests/test_pandas.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3+
import random
34
from typing import (
45
TYPE_CHECKING,
56
Any,
6-
Literal,
77
Union,
88
)
99

@@ -14,6 +14,8 @@
1414
import pytest
1515
from typing_extensions import assert_type
1616

17+
from pandas._libs.missing import NAType
18+
from pandas._libs.tslibs import NaTType
1719
from pandas._typing import Scalar
1820

1921
from tests import check
@@ -133,18 +135,42 @@ def test_isna() -> None:
133135
idx2 = pd.Index([1, 2])
134136
check(assert_type(pd.notna(idx2), npt.NDArray[np.bool_]), np.ndarray, np.bool_)
135137

136-
assert check(assert_type(pd.isna(pd.NA), Literal[True]), bool)
137-
assert not check(assert_type(pd.notna(pd.NA), Literal[False]), bool)
138+
assert check(assert_type(pd.isna(pd.NA), bool), bool)
139+
assert not check(assert_type(pd.notna(pd.NA), bool), bool)
138140

139-
assert check(assert_type(pd.isna(pd.NaT), Literal[True]), bool)
140-
assert not check(assert_type(pd.notna(pd.NaT), Literal[False]), bool)
141+
assert check(assert_type(pd.isna(pd.NaT), bool), bool)
142+
assert not check(assert_type(pd.notna(pd.NaT), bool), bool)
141143

142-
assert check(assert_type(pd.isna(None), Literal[True]), bool)
143-
assert not check(assert_type(pd.notna(None), Literal[False]), bool)
144+
assert check(assert_type(pd.isna(None), bool), bool)
145+
assert not check(assert_type(pd.notna(None), bool), bool)
144146

145147
check(assert_type(pd.isna(2.5), bool), bool)
146148
check(assert_type(pd.notna(2.5), bool), bool)
147149

150+
# Check type guard functionality
151+
nullable1 = random.choice(["value", None, pd.NA, pd.NaT])
152+
if pd.notna(nullable1):
153+
check(assert_type(nullable1, str), str)
154+
if pd.isna(nullable1):
155+
assert_type(nullable1, Union[NaTType, NAType, None])
156+
157+
nullable2 = random.choice([2, None])
158+
if pd.notna(nullable2):
159+
check(assert_type(nullable2, int), int)
160+
if pd.isna(nullable2):
161+
# TODO: Due to limitations in TypeGuard spec, the true annotation is not viable at this time
162+
# There is a proposal being floated for a StrictTypeGuard that will have more rigid narrowing semantics
163+
# assert_type(nullable2, None)
164+
assert_type(nullable2, Union[NaTType, NAType, None])
165+
166+
nullable3 = random.choice([2, None, pd.NA])
167+
if pd.notna(nullable3):
168+
check(assert_type(nullable3, int), int)
169+
if pd.isna(nullable3):
170+
# TODO: See comment above about the limitations of TypeGuard
171+
# assert_type(nullable3, Union[NAType, None])
172+
assert_type(nullable3, Union[NaTType, NAType, None])
173+
148174

149175
# GH 55
150176
def test_read_xml() -> None:

0 commit comments

Comments
 (0)