Skip to content

Commit f2ceeb1

Browse files
committed
Stub tensor/vec dot tests
1 parent 859911b commit f2ceeb1

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,17 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
185185
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
186186

187187

188-
promotion_table_params: List[Tuple[Tuple[DT, DT], DT]] = []
188+
promotion_params: List[Tuple[Tuple[DT, DT], DT]] = []
189189
for (dtype1, dtype2), promoted_dtype in dh.promotion_table.items():
190190
p = pytest.param(
191191
(dtype1, dtype2),
192192
promoted_dtype,
193193
id=make_id('', (dtype1, dtype2), promoted_dtype),
194194
)
195-
promotion_table_params.append(p)
195+
promotion_params.append(p)
196196

197197

198-
@pytest.mark.parametrize('in_dtypes, out_dtype', promotion_table_params)
198+
@pytest.mark.parametrize('in_dtypes, out_dtype', promotion_params)
199199
@given(shapes=hh.mutually_broadcastable_shapes(3), data=st.data())
200200
def test_where(in_dtypes, out_dtype, shapes, data):
201201
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
@@ -205,10 +205,10 @@ def test_where(in_dtypes, out_dtype, shapes, data):
205205
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
206206

207207

208-
numeric_promotion_table_params = promotion_table_params[1:]
208+
numeric_promotion_params = promotion_params[1:]
209209

210210

211-
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_table_params)
211+
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
212212
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1), data=st.data())
213213
def test_matmul(in_dtypes, out_dtype, shapes, data):
214214
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
@@ -217,6 +217,18 @@ def test_matmul(in_dtypes, out_dtype, shapes, data):
217217
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
218218

219219

220+
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
221+
@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data())
222+
def test_tensordot(in_dtypes, out_dtype, shapes, data):
223+
pass # TODO: figure out acceptable shape behaviour
224+
225+
226+
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
227+
@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data())
228+
def test_vecdot(in_dtypes, out_dtype, shapes, data):
229+
pass # TODO: figure out acceptable shape behaviour
230+
231+
220232
op_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = []
221233
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
222234
for op, symbol in op_to_symbol.items():

0 commit comments

Comments
 (0)