Skip to content

Commit 86d5c69

Browse files
committed
Add tests for reminder
1 parent d07857b commit 86d5c69

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

dpnp/tests/test_binary_ufuncs.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -245,81 +245,83 @@ def test_invalid_out(self, out):
245245
assert_raises(TypeError, numpy.divide, a.asnumpy(), 2, out)
246246

247247

248-
class TestFloorDivide:
248+
@pytest.mark.parametrize("func", ["floor_divide", "remainder"])
249+
class TestFloorDivideRemainder:
249250
ALL_DTYPES = get_all_dtypes(no_none=True, no_bool=True, no_complex=True)
250251

252+
def do_inplace_op(self, base, other, func):
253+
if func == "floor_divide":
254+
base //= other
255+
else:
256+
base %= other
257+
251258
@pytest.mark.usefixtures("suppress_divide_numpy_warnings")
252259
@pytest.mark.parametrize("dtype", ALL_DTYPES)
253-
def test_floor_divide(self, dtype):
254-
a, b, expected = _get_numpy_arrays_2in_1out(
255-
"floor_divide", dtype, [-5, 5, 10]
256-
)
260+
def test_basic(self, func, dtype):
261+
a, b, expected = _get_numpy_arrays_2in_1out(func, dtype, [-5, 5, 10])
257262

258263
ia, ib = dpnp.array(a), dpnp.array(b)
259264
iout = dpnp.empty(expected.shape, dtype=dtype)
260-
result = dpnp.floor_divide(ia, ib, out=iout)
265+
result = getattr(dpnp, func)(ia, ib, out=iout)
261266

262267
assert result is iout
263268
assert_dtype_allclose(result, expected)
264269

265-
@pytest.mark.usefixtures("suppress_divide_numpy_warnings")
266270
@pytest.mark.parametrize("dtype", ALL_DTYPES)
267-
def test_out_overlap(self, dtype):
271+
def test_out_overlap(self, func, dtype):
268272
size = 15
269-
a = numpy.arange(2 * size, dtype=dtype)
273+
a = numpy.arange(1, 2 * size + 1, dtype=dtype)
270274
ia = dpnp.array(a)
271275

272-
dpnp.floor_divide(ia[size::], ia[::2], out=ia[:size:])
273-
numpy.floor_divide(a[size::], a[::2], out=a[:size:])
276+
getattr(dpnp, func)(ia[size::], ia[::2], out=ia[:size:])
277+
getattr(numpy, func)(a[size::], a[::2], out=a[:size:])
274278

275279
assert_dtype_allclose(ia, a)
276280

277281
@pytest.mark.parametrize("dtype", ALL_DTYPES)
278-
def test_inplace_strides(self, dtype):
282+
def test_inplace_strides(self, func, dtype):
279283
size = 21
280284

281285
a = numpy.arange(size, dtype=dtype)
282-
a[::3] //= 4
286+
self.do_inplace_op(a[::3], 4, func)
283287

284288
ia = dpnp.arange(size, dtype=dtype)
285-
ia[::3] //= 4
289+
self.do_inplace_op(ia[::3], 4, func)
286290

287291
assert_dtype_allclose(ia, a)
288292

289-
@pytest.mark.parametrize(
290-
"dtype1", get_all_dtypes(no_none=True, no_complex=True)
291-
)
293+
@pytest.mark.parametrize("dtype1", [dpnp.bool] + ALL_DTYPES)
292294
@pytest.mark.parametrize("dtype2", get_float_dtypes())
293-
def test_inplace_dtype(self, dtype1, dtype2):
295+
def test_inplace_dtype(self, func, dtype1, dtype2):
294296
a = numpy.array([[-7, 6, -3, 2, -1], [0, -3, 4, 5, -6]], dtype=dtype1)
295297
b = numpy.array([5, -2, -10, 1, 10], dtype=dtype2)
296298
ia, ib = dpnp.array(a), dpnp.array(b)
297299

298300
if numpy.can_cast(dtype2, dtype1, casting="same_kind"):
299-
a //= b
300-
ia //= ib
301+
self.do_inplace_op(a, b, func)
302+
self.do_inplace_op(ia, ib, func)
301303
assert_dtype_allclose(ia, a)
302304
else:
303305
with pytest.raises(TypeError):
304-
a //= b
306+
self.do_inplace_op(a, b, func)
305307

306308
with pytest.raises(ValueError):
307-
ia //= ib
309+
self.do_inplace_op(ia, ib, func)
308310

309311
@pytest.mark.parametrize("shape", [(0,), (15,), (2, 2)])
310-
def test_invalid_shape(self, shape):
312+
def test_invalid_shape(self, func, shape):
311313
a, b = dpnp.arange(10), dpnp.arange(10)
312314
out = dpnp.empty(shape)
313315

314316
with pytest.raises(ValueError):
315-
dpnp.floor_divide(a, b, out=out)
317+
getattr(dpnp, func)(a, b, out=out)
316318

317319
@pytest.mark.parametrize("out", [4, (), [], (3, 7), [2, 4]])
318-
def test_invalid_out(self, out):
320+
def test_invalid_out(self, func, out):
319321
a = dpnp.arange(10)
320322

321-
assert_raises(TypeError, dpnp.floor_divide, a, 2, out)
322-
assert_raises(TypeError, numpy.floor_divide, a.asnumpy(), 2, out)
323+
assert_raises(TypeError, getattr(dpnp, func), a, 2, out)
324+
assert_raises(TypeError, getattr(numpy, func), a.asnumpy(), 2, out)
323325

324326

325327
class TestFmaxFmin:

0 commit comments

Comments
 (0)