Skip to content

Commit 3b134b0

Browse files
committed
TYP: fix typing errors in common._fft
1 parent 6a17007 commit 3b134b0

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

array_api_compat/common/_fft.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from __future__ import annotations
2-
31
from collections.abc import Sequence
4-
from typing import Union, Optional, Literal
2+
from typing import Literal, TypeAlias
3+
4+
from ._typing import Array, Device, DType, Namespace
55

6-
from ._typing import Device, Array, DType, Namespace
6+
_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
77

88
# Note: NumPy fft functions improperly upcast float32 and complex64 to
99
# complex128, which is why we require wrapping them all here.
@@ -13,9 +13,9 @@ def fft(
1313
/,
1414
xp: Namespace,
1515
*,
16-
n: Optional[int] = None,
16+
n: int | None = None,
1717
axis: int = -1,
18-
norm: Literal["backward", "ortho", "forward"] = "backward",
18+
norm: _Norm = "backward",
1919
) -> Array:
2020
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
2121
if x.dtype in [xp.float32, xp.complex64]:
@@ -27,9 +27,9 @@ def ifft(
2727
/,
2828
xp: Namespace,
2929
*,
30-
n: Optional[int] = None,
30+
n: int | None = None,
3131
axis: int = -1,
32-
norm: Literal["backward", "ortho", "forward"] = "backward",
32+
norm: _Norm = "backward",
3333
) -> Array:
3434
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
3535
if x.dtype in [xp.float32, xp.complex64]:
@@ -41,9 +41,9 @@ def fftn(
4141
/,
4242
xp: Namespace,
4343
*,
44-
s: Sequence[int] = None,
45-
axes: Sequence[int] = None,
46-
norm: Literal["backward", "ortho", "forward"] = "backward",
44+
s: Sequence[int] | None = None,
45+
axes: Sequence[int] | None = None,
46+
norm: _Norm = "backward",
4747
) -> Array:
4848
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
4949
if x.dtype in [xp.float32, xp.complex64]:
@@ -55,9 +55,9 @@ def ifftn(
5555
/,
5656
xp: Namespace,
5757
*,
58-
s: Sequence[int] = None,
59-
axes: Sequence[int] = None,
60-
norm: Literal["backward", "ortho", "forward"] = "backward",
58+
s: Sequence[int] | None = None,
59+
axes: Sequence[int] | None = None,
60+
norm: _Norm = "backward",
6161
) -> Array:
6262
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
6363
if x.dtype in [xp.float32, xp.complex64]:
@@ -69,9 +69,9 @@ def rfft(
6969
/,
7070
xp: Namespace,
7171
*,
72-
n: Optional[int] = None,
72+
n: int | None = None,
7373
axis: int = -1,
74-
norm: Literal["backward", "ortho", "forward"] = "backward",
74+
norm: _Norm = "backward",
7575
) -> Array:
7676
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
7777
if x.dtype == xp.float32:
@@ -83,9 +83,9 @@ def irfft(
8383
/,
8484
xp: Namespace,
8585
*,
86-
n: Optional[int] = None,
86+
n: int | None = None,
8787
axis: int = -1,
88-
norm: Literal["backward", "ortho", "forward"] = "backward",
88+
norm: _Norm = "backward",
8989
) -> Array:
9090
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
9191
if x.dtype == xp.complex64:
@@ -97,9 +97,9 @@ def rfftn(
9797
/,
9898
xp: Namespace,
9999
*,
100-
s: Sequence[int] = None,
101-
axes: Sequence[int] = None,
102-
norm: Literal["backward", "ortho", "forward"] = "backward",
100+
s: Sequence[int] | None = None,
101+
axes: Sequence[int] | None = None,
102+
norm: _Norm = "backward",
103103
) -> Array:
104104
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
105105
if x.dtype == xp.float32:
@@ -111,9 +111,9 @@ def irfftn(
111111
/,
112112
xp: Namespace,
113113
*,
114-
s: Sequence[int] = None,
115-
axes: Sequence[int] = None,
116-
norm: Literal["backward", "ortho", "forward"] = "backward",
114+
s: Sequence[int] | None = None,
115+
axes: Sequence[int] | None = None,
116+
norm: _Norm = "backward",
117117
) -> Array:
118118
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
119119
if x.dtype == xp.complex64:
@@ -125,9 +125,9 @@ def hfft(
125125
/,
126126
xp: Namespace,
127127
*,
128-
n: Optional[int] = None,
128+
n: int | None = None,
129129
axis: int = -1,
130-
norm: Literal["backward", "ortho", "forward"] = "backward",
130+
norm: _Norm = "backward",
131131
) -> Array:
132132
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
133133
if x.dtype in [xp.float32, xp.complex64]:
@@ -139,9 +139,9 @@ def ihfft(
139139
/,
140140
xp: Namespace,
141141
*,
142-
n: Optional[int] = None,
142+
n: int | None = None,
143143
axis: int = -1,
144-
norm: Literal["backward", "ortho", "forward"] = "backward",
144+
norm: _Norm = "backward",
145145
) -> Array:
146146
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
147147
if x.dtype in [xp.float32, xp.complex64]:
@@ -154,8 +154,8 @@ def fftfreq(
154154
xp: Namespace,
155155
*,
156156
d: float = 1.0,
157-
dtype: Optional[DType] = None,
158-
device: Optional[Device] = None,
157+
dtype: DType | None = None,
158+
device: Device | None = None,
159159
) -> Array:
160160
if device not in ["cpu", None]:
161161
raise ValueError(f"Unsupported device {device!r}")
@@ -170,8 +170,8 @@ def rfftfreq(
170170
xp: Namespace,
171171
*,
172172
d: float = 1.0,
173-
dtype: Optional[DType] = None,
174-
device: Optional[Device] = None,
173+
dtype: DType | None = None,
174+
device: Device | None = None,
175175
) -> Array:
176176
if device not in ["cpu", None]:
177177
raise ValueError(f"Unsupported device {device!r}")
@@ -181,12 +181,12 @@ def rfftfreq(
181181
return res
182182

183183
def fftshift(
184-
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
184+
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
185185
) -> Array:
186186
return xp.fft.fftshift(x, axes=axes)
187187

188188
def ifftshift(
189-
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
189+
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
190190
) -> Array:
191191
return xp.fft.ifftshift(x, axes=axes)
192192

0 commit comments

Comments
 (0)