@@ -109,44 +109,45 @@ def assert_s_axes_shape(
109
109
ph .assert_shape (func_name , out_shape = out .shape , expected = tuple (expected ))
110
110
111
111
112
- @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
113
- def test_fft (x , data ):
114
- n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
112
+ if hh .complex_dtypes :
113
+ @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
114
+ def test_fft (x , data ):
115
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
115
116
116
- out = xp .fft .fft (x , ** kwargs )
117
+ out = xp .fft .fft (x , ** kwargs )
117
118
118
- ph .assert_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
119
- assert_n_axis_shape ("fft" , x = x , n = n , axis = axis , out = out )
119
+ ph .assert_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
120
+ assert_n_axis_shape ("fft" , x = x , n = n , axis = axis , out = out )
120
121
121
122
122
- @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
123
- def test_ifft (x , data ):
124
- n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
123
+ @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
124
+ def test_ifft (x , data ):
125
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
125
126
126
- out = xp .fft .ifft (x , ** kwargs )
127
+ out = xp .fft .ifft (x , ** kwargs )
127
128
128
- ph .assert_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
129
- assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
129
+ ph .assert_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
130
+ assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
130
131
131
132
132
- @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
133
- def test_fftn (x , data ):
134
- s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
133
+ @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
134
+ def test_fftn (x , data ):
135
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
135
136
136
- out = xp .fft .fftn (x , ** kwargs )
137
+ out = xp .fft .fftn (x , ** kwargs )
137
138
138
- ph .assert_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
139
- assert_s_axes_shape ("fftn" , x = x , s = s , axes = axes , out = out )
139
+ ph .assert_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
140
+ assert_s_axes_shape ("fftn" , x = x , s = s , axes = axes , out = out )
140
141
141
142
142
- @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
143
- def test_ifftn (x , data ):
144
- s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
143
+ @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
144
+ def test_ifftn (x , data ):
145
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
145
146
146
- out = xp .fft .ifftn (x , ** kwargs )
147
+ out = xp .fft .ifftn (x , ** kwargs )
147
148
148
- ph .assert_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
149
- assert_s_axes_shape ("ifftn" , x = x , s = s , axes = axes , out = out )
149
+ ph .assert_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
150
+ assert_s_axes_shape ("ifftn" , x = x , s = s , axes = axes , out = out )
150
151
151
152
152
153
@given (x = hh .arrays (dtype = hh .real_floating_dtypes , shape = fft_shapes_strat ), data = st .data ())
@@ -166,26 +167,27 @@ def test_rfft(x, data):
166
167
ph .assert_shape ("rfft" , out_shape = out .shape , expected = expected_shape )
167
168
168
169
169
- @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
170
- def test_irfft (x , data ):
171
- n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
170
+ if hh .complex_dtypes :
171
+ @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
172
+ def test_irfft (x , data ):
173
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
172
174
173
- out = xp .fft .irfft (x , ** kwargs )
175
+ out = xp .fft .irfft (x , ** kwargs )
174
176
175
- ph .assert_dtype (
176
- "irfft" ,
177
- in_dtype = x .dtype ,
178
- out_dtype = out .dtype ,
179
- expected = dh .dtype_components [x .dtype ],
180
- )
177
+ ph .assert_dtype (
178
+ "irfft" ,
179
+ in_dtype = x .dtype ,
180
+ out_dtype = out .dtype ,
181
+ expected = dh .dtype_components [x .dtype ],
182
+ )
181
183
182
- _axis = x .ndim - 1 if axis == - 1 else axis
183
- if n is None :
184
- axis_side = 2 * (x .shape [_axis ] - 1 )
185
- else :
186
- axis_side = n
187
- expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
188
- ph .assert_shape ("irfft" , out_shape = out .shape , expected = expected_shape )
184
+ _axis = x .ndim - 1 if axis == - 1 else axis
185
+ if n is None :
186
+ axis_side = 2 * (x .shape [_axis ] - 1 )
187
+ else :
188
+ axis_side = n
189
+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
190
+ ph .assert_shape ("irfft" , out_shape = out .shape , expected = expected_shape )
189
191
190
192
191
193
@given (x = hh .arrays (dtype = hh .real_floating_dtypes , shape = fft_shapes_strat ), data = st .data ())
@@ -209,59 +211,60 @@ def test_rfftn(x, data):
209
211
ph .assert_shape ("rfftn" , out_shape = out .shape , expected = tuple (expected ))
210
212
211
213
212
- @given (
213
- x = hh .arrays (
214
- dtype = hh .complex_dtypes , shape = fft_shapes_strat .filter (lambda s : s [- 1 ] > 1 )
215
- ),
216
- data = st .data (),
217
- )
218
- def test_irfftn (x , data ):
219
- s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
220
-
221
- out = xp .fft .irfftn (x , ** kwargs )
222
-
223
- ph .assert_dtype (
224
- "irfftn" ,
225
- in_dtype = x .dtype ,
226
- out_dtype = out .dtype ,
227
- expected = dh .dtype_components [x .dtype ],
228
- )
229
-
230
- # TODO: assert shape correctly
231
- # _axes = sh.normalize_axis(axes, x.ndim)
232
- # _s = x.shape if s is None else s
233
- # expected = []
234
- # for i in range(x.ndim):
235
- # if i in _axes:
236
- # side = _s[_axes.index(i)]
237
- # else:
238
- # side = x.shape[i]
239
- # expected.append(side)
240
- # last_axis = max(_axes)
241
- # expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
242
- # ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
243
-
244
-
245
- @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
246
- def test_hfft (x , data ):
247
- n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
248
-
249
- out = xp .fft .hfft (x , ** kwargs )
250
-
251
- ph .assert_dtype (
252
- "hfft" ,
253
- in_dtype = x .dtype ,
254
- out_dtype = out .dtype ,
255
- expected = dh .dtype_components [x .dtype ],
214
+ if hh .complex_dtypes :
215
+ @given (
216
+ x = hh .arrays (
217
+ dtype = hh .complex_dtypes , shape = fft_shapes_strat .filter (lambda s : s [- 1 ] > 1 )
218
+ ),
219
+ data = st .data (),
256
220
)
221
+ def test_irfftn (x , data ):
222
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
223
+
224
+ out = xp .fft .irfftn (x , ** kwargs )
225
+
226
+ ph .assert_dtype (
227
+ "irfftn" ,
228
+ in_dtype = x .dtype ,
229
+ out_dtype = out .dtype ,
230
+ expected = dh .dtype_components [x .dtype ],
231
+ )
232
+
233
+ # TODO: assert shape correctly
234
+ # _axes = sh.normalize_axis(axes, x.ndim)
235
+ # _s = x.shape if s is None else s
236
+ # expected = []
237
+ # for i in range(x.ndim):
238
+ # if i in _axes:
239
+ # side = _s[_axes.index(i)]
240
+ # else:
241
+ # side = x.shape[i]
242
+ # expected.append(side)
243
+ # last_axis = max(_axes)
244
+ # expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
245
+ # ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
246
+
247
+
248
+ @given (x = hh .arrays (dtype = hh .complex_dtypes , shape = fft_shapes_strat ), data = st .data ())
249
+ def test_hfft (x , data ):
250
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
251
+
252
+ out = xp .fft .hfft (x , ** kwargs )
253
+
254
+ ph .assert_dtype (
255
+ "hfft" ,
256
+ in_dtype = x .dtype ,
257
+ out_dtype = out .dtype ,
258
+ expected = dh .dtype_components [x .dtype ],
259
+ )
257
260
258
- _axis = x .ndim - 1 if axis == - 1 else axis
259
- if n is None :
260
- axis_side = 2 * (x .shape [_axis ] - 1 )
261
- else :
262
- axis_side = n
263
- expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
264
- ph .assert_shape ("hfft" , out_shape = out .shape , expected = expected_shape )
261
+ _axis = x .ndim - 1 if axis == - 1 else axis
262
+ if n is None :
263
+ axis_side = 2 * (x .shape [_axis ] - 1 )
264
+ else :
265
+ axis_side = n
266
+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
267
+ ph .assert_shape ("hfft" , out_shape = out .shape , expected = expected_shape )
265
268
266
269
267
270
@given (x = hh .arrays (dtype = hh .real_floating_dtypes , shape = fft_shapes_strat ), data = st .data ())
0 commit comments