@@ -93,14 +93,24 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType)
93
93
94
94
95
95
def assert_n_axis_shape (
96
- func_name : str , * , x : Array , n : Optional [int ], axis : int , out : Array
96
+ func_name : str ,
97
+ * ,
98
+ x : Array ,
99
+ n : Optional [int ],
100
+ axis : int ,
101
+ out : Array ,
102
+ size_gt_1 = False ,
97
103
):
104
+ _axis = len (x .shape ) - 1 if axis == - 1 else axis
98
105
if n is None :
99
- expected_shape = x .shape
106
+ if size_gt_1 :
107
+ axis_side = 2 * (x .shape [_axis ] - 1 )
108
+ else :
109
+ axis_side = x .shape [_axis ]
100
110
else :
101
- _axis = len ( x . shape ) - 1 if axis == - 1 else axis
102
- expected_shape = x .shape [:_axis ] + (n ,) + x .shape [_axis + 1 :]
103
- ph .assert_shape (func_name , out_shape = out .shape , expected = expected_shape )
111
+ axis_side = n
112
+ expected = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
113
+ ph .assert_shape (func_name , out_shape = out .shape , expected = expected )
104
114
105
115
106
116
def assert_s_axes_shape (
@@ -198,7 +208,14 @@ def test_irfft(x, data):
198
208
out = xp .fft .irfft (x , ** kwargs )
199
209
200
210
assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
201
- # TODO: assert shape
211
+
212
+ _axis = x .ndim - 1 if axis == - 1 else axis
213
+ if n is None :
214
+ axis_side = 2 * (x .shape [_axis ] - 1 )
215
+ else :
216
+ axis_side = n
217
+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
218
+ ph .assert_shape ("irfft" , out_shape = out .shape , expected = expected_shape )
202
219
203
220
204
221
@given (
@@ -224,7 +241,7 @@ def test_irfftn(x, data):
224
241
out = xp .fft .irfftn (x , ** kwargs )
225
242
226
243
assert_fft_dtype ("irfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
227
- assert_s_axes_shape ( "irfftn" , x = x , s = s , axes = axes , out = out )
244
+ # TODO: shape
228
245
229
246
230
247
@given (
@@ -237,7 +254,14 @@ def test_hfft(x, data):
237
254
out = xp .fft .hfft (x , ** kwargs )
238
255
239
256
assert_fft_dtype ("hfft" , in_dtype = x .dtype , out_dtype = out .dtype )
240
- # TODO: shape
257
+
258
+ _axis = x .ndim - 1 if axis == - 1 else axis
259
+ if n is None :
260
+ axis_side = 2 * (x .shape [_axis ] - 1 )
261
+ else :
262
+ axis_side = n
263
+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
264
+ ph .assert_shape ("hfft" , out_shape = out .shape , expected = expected_shape )
241
265
242
266
243
267
@given (
@@ -250,9 +274,10 @@ def test_ihfft(x, data):
250
274
out = xp .fft .ihfft (x , ** kwargs )
251
275
252
276
assert_fft_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
253
- # TODO: shape
277
+ assert_n_axis_shape ( "ihfft" , x = x , n = n , axis = axis , out = out , size_gt_1 = True )
254
278
255
279
280
+ # TODO:
256
281
# fftfreq
257
282
# rfftfreq
258
283
# fftshift
0 commit comments