Skip to content

Commit 3b5a588

Browse files
authored
fix dpnp.linalg.qr and dpnp.linalg.det functions (#1592)
* modifying dpnp.linalg.qr function * modifying dpnp.linalg.det function * fix pre-commit
1 parent 7b96b9b commit 3b5a588

File tree

4 files changed

+57
-13
lines changed

4 files changed

+57
-13
lines changed

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,9 @@ DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
612612
(void)dep_event_vec_ref;
613613

614614
DPCTLSyclEventRef event_ref = nullptr;
615+
if (!size_m || !size_n) {
616+
return event_ref;
617+
}
615618
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
616619

617620
sycl::event event;

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,9 @@ cpdef object dpnp_cond(object input, object p):
142142
cpdef utils.dpnp_descriptor dpnp_det(utils.dpnp_descriptor input):
143143
cdef shape_type_c input_shape = input.shape
144144
cdef size_t n = input.shape[-1]
145-
cdef size_t size_out = 1
145+
cdef shape_type_c result_shape = (1,)
146146
if input.ndim != 2:
147-
output_shape = tuple((list(input.shape))[:-2])
148-
for i in range(len(output_shape)):
149-
size_out *= output_shape[i]
150-
151-
cdef shape_type_c result_shape = (size_out,)
152-
if size_out > 1:
153-
result_shape = output_shape
147+
result_shape = tuple((list(input.shape))[:-2])
154148

155149
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
156150

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def det(input):
159159

160160
x1_desc = dpnp.get_dpnp_descriptor(input, copy_when_nondefault_queue=False)
161161
if x1_desc:
162-
if x1_desc.shape[-1] == x1_desc.shape[-2]:
162+
if x1_desc.ndim < 2:
163+
pass
164+
elif x1_desc.shape[-1] == x1_desc.shape[-2]:
163165
result_obj = dpnp_det(x1_desc).get_pyobj()
164166
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
165167

@@ -488,7 +490,9 @@ def qr(x1, mode="reduced"):
488490

489491
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
490492
if x1_desc:
491-
if mode != "reduced":
493+
if x1_desc.ndim != 2:
494+
pass
495+
elif mode != "reduced":
492496
pass
493497
else:
494498
result_tup = dpnp_qr(x1_desc, mode)

tests/test_linalg.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ def test_det(array):
128128
assert_allclose(expected, result)
129129

130130

131+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
132+
def test_det_empty():
133+
a = numpy.empty((0, 0, 2, 2), dtype=numpy.float32)
134+
ia = inp.array(a)
135+
136+
np_det = numpy.linalg.det(a)
137+
dpnp_det = inp.linalg.det(ia)
138+
139+
assert dpnp_det.dtype == np_det.dtype
140+
assert dpnp_det.shape == np_det.shape
141+
142+
assert_allclose(np_det, dpnp_det)
143+
144+
131145
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
132146
@pytest.mark.parametrize("size", [2, 4, 8, 16, 300])
133147
def test_eig_arange(type, size):
@@ -358,8 +372,8 @@ def test_norm3(array, ord, axis):
358372
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
359373
@pytest.mark.parametrize(
360374
"shape",
361-
[(2, 2), (3, 4), (5, 3), (16, 16)],
362-
ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)"],
375+
[(2, 2), (3, 4), (5, 3), (16, 16), (0, 0), (0, 2), (2, 0)],
376+
ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)", "(0,0)", "(0,2)", "(2,0)"],
363377
)
364378
@pytest.mark.parametrize(
365379
"mode", ["complete", "reduced"], ids=["complete", "reduced"]
@@ -388,7 +402,7 @@ def test_qr(type, shape, mode):
388402
# check decomposition
389403
assert_allclose(
390404
ia,
391-
numpy.dot(inp.asnumpy(dpnp_q), inp.asnumpy(dpnp_r)),
405+
inp.dot(dpnp_q, dpnp_r),
392406
rtol=tol,
393407
atol=tol,
394408
)
@@ -409,6 +423,35 @@ def test_qr(type, shape, mode):
409423
assert_allclose(dpnp_r, np_r, rtol=tol, atol=tol)
410424

411425

426+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
427+
def test_qr_not_2D():
428+
a = numpy.arange(12, dtype=numpy.float32).reshape((3, 2, 2))
429+
ia = inp.array(a)
430+
431+
np_q, np_r = numpy.linalg.qr(a)
432+
dpnp_q, dpnp_r = inp.linalg.qr(ia)
433+
434+
assert dpnp_q.dtype == np_q.dtype
435+
assert dpnp_r.dtype == np_r.dtype
436+
assert dpnp_q.shape == np_q.shape
437+
assert dpnp_r.shape == np_r.shape
438+
439+
assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r))
440+
441+
a = numpy.empty((0, 3, 2), dtype=numpy.float32)
442+
ia = inp.array(a)
443+
444+
np_q, np_r = numpy.linalg.qr(a)
445+
dpnp_q, dpnp_r = inp.linalg.qr(ia)
446+
447+
assert dpnp_q.dtype == np_q.dtype
448+
assert dpnp_r.dtype == np_r.dtype
449+
assert dpnp_q.shape == np_q.shape
450+
assert dpnp_r.shape == np_r.shape
451+
452+
assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r))
453+
454+
412455
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
413456
@pytest.mark.parametrize(
414457
"shape",

0 commit comments

Comments
 (0)