Skip to content

Commit 9f7eabb

Browse files
committed
Keep backward compatibility with dpctl 0.15.0
1 parent 89000a2 commit 9f7eabb

File tree

2 files changed

+81
-6
lines changed

2 files changed

+81
-6
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,80 @@ def check_nd_call_func(
167167
)
168168

169169

170+
"""
171+
Below utility functions are required to keep backward compitibility
172+
with DPC++ 2023.2 and dpctl 0.15.0. Otherwise we can't collect coverage report,
173+
since DPC++ 2024.0 has the reported crash issue while running for coverage
174+
and the latest dpctl (i.e. >0.15.0) can't be installed with DPC++ 2023.2.
175+
176+
TODO: remove the w/a once the above issue is resolved.
177+
"""
178+
179+
180+
def _get_impl_fn(dpt_fn):
181+
if hasattr(dpt_fn, "get_implementation_function"):
182+
return dpt_fn.get_implementation_function()
183+
184+
if hasattr(dpt_fn, "__name__"):
185+
if dpt_fn.__name__ == "UnaryElementwiseFunc":
186+
return dpt_fn.unary_fn_
187+
elif dpt_fn.__name__ == "BinaryElementwiseFunc":
188+
return dpt_fn.binary_fn_
189+
190+
raise TypeError(
191+
"Expected an instance of elementwise func class, but got {}".format(
192+
type(dpt_fn)
193+
)
194+
)
195+
196+
197+
def _get_type_resolver_fn(dpt_fn):
198+
if hasattr(dpt_fn, "get_type_result_resolver_function"):
199+
return dpt_fn.get_type_result_resolver_function()
200+
201+
if hasattr(dpt_fn, "result_type_resolver_fn_"):
202+
return dpt_fn.result_type_resolver_fn_
203+
204+
raise TypeError(
205+
"Expected an instance of elementwise func class, but got {}".format(
206+
type(dpt_fn)
207+
)
208+
)
209+
210+
211+
def _get_impl_inplace_fn(dpt_fn):
212+
if hasattr(dpt_fn, "get_implementation_inplace_function"):
213+
return dpt_fn.get_implementation_inplace_function()
214+
215+
if hasattr(dpt_fn, "binary_inplace_fn_"):
216+
return dpt_fn.binary_inplace_fn_
217+
218+
raise TypeError(
219+
"Expected an instance of elementwise func class, but got {}".format(
220+
type(dpt_fn)
221+
)
222+
)
223+
224+
225+
def _get_type_promotion_fn(dpt_fn):
226+
if hasattr(dpt_fn, "get_type_promotion_path_acceptance_function"):
227+
return dpt_fn.get_type_promotion_path_acceptance_function()
228+
229+
if hasattr(dpt_fn, "acceptance_fn_"):
230+
return dpt_fn.acceptance_fn_
231+
232+
raise TypeError(
233+
"Expected an instance of elementwise func class, but got {}".format(
234+
type(dpt_fn)
235+
)
236+
)
237+
238+
170239
def _make_unary_func(
171240
name, dpt_unary_fn, fn_docstring, mkl_fn_to_call=None, mkl_impl_fn=None
172241
):
173-
impl_fn = dpt_unary_fn.get_implementation_function()
174-
type_resolver_fn = dpt_unary_fn.get_type_result_resolver_function()
242+
impl_fn = _get_impl_fn(dpt_unary_fn)
243+
type_resolver_fn = _get_type_resolver_fn(dpt_unary_fn)
175244

176245
def _call_func(src, dst, sycl_queue, depends=None):
177246
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
@@ -193,10 +262,10 @@ def _call_func(src, dst, sycl_queue, depends=None):
193262
def _make_binary_func(
194263
name, dpt_binary_fn, fn_docstring, mkl_fn_to_call=None, mkl_impl_fn=None
195264
):
196-
impl_fn = dpt_binary_fn.get_implementation_function()
197-
type_resolver_fn = dpt_binary_fn.get_type_result_resolver_function()
198-
inplce_fn = dpt_binary_fn.get_implementation_inplace_function()
199-
acceptance_fn = dpt_binary_fn.get_type_promotion_path_acceptance_function()
265+
impl_fn = _get_impl_fn(dpt_binary_fn)
266+
type_resolver_fn = _get_type_resolver_fn(dpt_binary_fn)
267+
inplce_fn = _get_impl_inplace_fn(dpt_binary_fn)
268+
acceptance_fn = _get_type_promotion_fn(dpt_binary_fn)
200269

201270
def _call_func(src1, src2, dst, sycl_queue, depends=None):
202271
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""

tests/test_mathematical.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,12 @@ def test_complex_values(self):
18761876
dp_arr = dpnp.array(np_arr)
18771877
func = lambda x: x**2
18781878

1879+
# TODO: unmute the test once it's available
1880+
if is_win_platform():
1881+
pytest.skip(
1882+
"Until the latest dpctl is available on internal channel"
1883+
)
1884+
18791885
assert_dtype_allclose(func(dp_arr), func(np_arr))
18801886

18811887
@pytest.mark.parametrize("val", [0, 1], ids=["0", "1"])

0 commit comments

Comments
 (0)