1
- from __future__ import annotations
2
-
3
1
from collections .abc import Sequence
4
- from typing import Union , Optional , Literal
2
+ from typing import Literal , TypeAlias
3
+
4
+ from ._typing import Array , Device , DType , Namespace
5
5
6
- from . _typing import Device , Array , DType , Namespace
6
+ _Norm : TypeAlias = Literal [ "backward" , "ortho" , "forward" ]
7
7
8
8
# Note: NumPy fft functions improperly upcast float32 and complex64 to
9
9
# complex128, which is why we require wrapping them all here.
@@ -13,9 +13,9 @@ def fft(
13
13
/ ,
14
14
xp : Namespace ,
15
15
* ,
16
- n : Optional [ int ] = None ,
16
+ n : int | None = None ,
17
17
axis : int = - 1 ,
18
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
18
+ norm : _Norm = "backward" ,
19
19
) -> Array :
20
20
res = xp .fft .fft (x , n = n , axis = axis , norm = norm )
21
21
if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -27,9 +27,9 @@ def ifft(
27
27
/ ,
28
28
xp : Namespace ,
29
29
* ,
30
- n : Optional [ int ] = None ,
30
+ n : int | None = None ,
31
31
axis : int = - 1 ,
32
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
32
+ norm : _Norm = "backward" ,
33
33
) -> Array :
34
34
res = xp .fft .ifft (x , n = n , axis = axis , norm = norm )
35
35
if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -41,9 +41,9 @@ def fftn(
41
41
/ ,
42
42
xp : Namespace ,
43
43
* ,
44
- s : Sequence [int ] = None ,
45
- axes : Sequence [int ] = None ,
46
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
44
+ s : Sequence [int ] | None = None ,
45
+ axes : Sequence [int ] | None = None ,
46
+ norm : _Norm = "backward" ,
47
47
) -> Array :
48
48
res = xp .fft .fftn (x , s = s , axes = axes , norm = norm )
49
49
if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -55,9 +55,9 @@ def ifftn(
55
55
/ ,
56
56
xp : Namespace ,
57
57
* ,
58
- s : Sequence [int ] = None ,
59
- axes : Sequence [int ] = None ,
60
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
58
+ s : Sequence [int ] | None = None ,
59
+ axes : Sequence [int ] | None = None ,
60
+ norm : _Norm = "backward" ,
61
61
) -> Array :
62
62
res = xp .fft .ifftn (x , s = s , axes = axes , norm = norm )
63
63
if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -69,9 +69,9 @@ def rfft(
69
69
/ ,
70
70
xp : Namespace ,
71
71
* ,
72
- n : Optional [ int ] = None ,
72
+ n : int | None = None ,
73
73
axis : int = - 1 ,
74
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
74
+ norm : _Norm = "backward" ,
75
75
) -> Array :
76
76
res = xp .fft .rfft (x , n = n , axis = axis , norm = norm )
77
77
if x .dtype == xp .float32 :
@@ -83,9 +83,9 @@ def irfft(
83
83
/ ,
84
84
xp : Namespace ,
85
85
* ,
86
- n : Optional [ int ] = None ,
86
+ n : int | None = None ,
87
87
axis : int = - 1 ,
88
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
88
+ norm : _Norm = "backward" ,
89
89
) -> Array :
90
90
res = xp .fft .irfft (x , n = n , axis = axis , norm = norm )
91
91
if x .dtype == xp .complex64 :
@@ -97,9 +97,9 @@ def rfftn(
97
97
/ ,
98
98
xp : Namespace ,
99
99
* ,
100
- s : Sequence [int ] = None ,
101
- axes : Sequence [int ] = None ,
102
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
100
+ s : Sequence [int ] | None = None ,
101
+ axes : Sequence [int ] | None = None ,
102
+ norm : _Norm = "backward" ,
103
103
) -> Array :
104
104
res = xp .fft .rfftn (x , s = s , axes = axes , norm = norm )
105
105
if x .dtype == xp .float32 :
@@ -111,9 +111,9 @@ def irfftn(
111
111
/ ,
112
112
xp : Namespace ,
113
113
* ,
114
- s : Sequence [int ] = None ,
115
- axes : Sequence [int ] = None ,
116
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
114
+ s : Sequence [int ] | None = None ,
115
+ axes : Sequence [int ] | None = None ,
116
+ norm : _Norm = "backward" ,
117
117
) -> Array :
118
118
res = xp .fft .irfftn (x , s = s , axes = axes , norm = norm )
119
119
if x .dtype == xp .complex64 :
@@ -125,9 +125,9 @@ def hfft(
125
125
/ ,
126
126
xp : Namespace ,
127
127
* ,
128
- n : Optional [ int ] = None ,
128
+ n : int | None = None ,
129
129
axis : int = - 1 ,
130
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
130
+ norm : _Norm = "backward" ,
131
131
) -> Array :
132
132
res = xp .fft .hfft (x , n = n , axis = axis , norm = norm )
133
133
if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -139,9 +139,9 @@ def ihfft(
139
139
/ ,
140
140
xp : Namespace ,
141
141
* ,
142
- n : Optional [ int ] = None ,
142
+ n : int | None = None ,
143
143
axis : int = - 1 ,
144
- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
144
+ norm : _Norm = "backward" ,
145
145
) -> Array :
146
146
res = xp .fft .ihfft (x , n = n , axis = axis , norm = norm )
147
147
if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -154,8 +154,8 @@ def fftfreq(
154
154
xp : Namespace ,
155
155
* ,
156
156
d : float = 1.0 ,
157
- dtype : Optional [ DType ] = None ,
158
- device : Optional [ Device ] = None ,
157
+ dtype : DType | None = None ,
158
+ device : Device | None = None ,
159
159
) -> Array :
160
160
if device not in ["cpu" , None ]:
161
161
raise ValueError (f"Unsupported device { device !r} " )
@@ -170,8 +170,8 @@ def rfftfreq(
170
170
xp : Namespace ,
171
171
* ,
172
172
d : float = 1.0 ,
173
- dtype : Optional [ DType ] = None ,
174
- device : Optional [ Device ] = None ,
173
+ dtype : DType | None = None ,
174
+ device : Device | None = None ,
175
175
) -> Array :
176
176
if device not in ["cpu" , None ]:
177
177
raise ValueError (f"Unsupported device { device !r} " )
@@ -181,12 +181,12 @@ def rfftfreq(
181
181
return res
182
182
183
183
def fftshift (
184
- x : Array , / , xp : Namespace , * , axes : Union [ int , Sequence [int ]] = None
184
+ x : Array , / , xp : Namespace , * , axes : int | Sequence [int ] | None = None
185
185
) -> Array :
186
186
return xp .fft .fftshift (x , axes = axes )
187
187
188
188
def ifftshift (
189
- x : Array , / , xp : Namespace , * , axes : Union [ int , Sequence [int ]] = None
189
+ x : Array , / , xp : Namespace , * , axes : int | Sequence [int ] | None = None
190
190
) -> Array :
191
191
return xp .fft .ifftshift (x , axes = axes )
192
192
0 commit comments