diff --git a/CHANGELOG.md b/CHANGELOG.md index 2657b0e916..a63891ee61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed * Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033) -* Changed implementation of `DPCTLPlatform_GetDefaultContext` from using deprecated `ext_oneapi_get_default_context` to `khr_get_default_context` [#2042](https://github.com/IntelPython/dpctl/pull/2042). +* Changed implementation of `DPCTLPlatform_GetDefaultContext` from using deprecated `ext_oneapi_get_default_context` to `khr_get_default_context` [#2042](https://github.com/IntelPython/dpctl/pull/2042) +* Updated `repr` to show the shape of the abbreviated arrays and show the shape and data type of zero-size arrays [#2067](https://github.com/IntelPython/dpctl/pull/2067) ### Fixed diff --git a/dpctl/tensor/_print.py b/dpctl/tensor/_print.py index d491ab0852..0156422083 100644 --- a/dpctl/tensor/_print.py +++ b/dpctl/tensor/_print.py @@ -40,6 +40,17 @@ } +def _move_to_next_line(string, s, line_width, prefix): + """ + Move string to next line if it doesn't fit in the current line. + """ + bottom_len = len(s) - (s.rfind("\n") + 1) + next_line = bottom_len + len(string) + 1 > line_width + string = ",\n" + " " * len(prefix) + string if next_line else ", " + string + + return string + + def _options_dict( linewidth=None, edgeitems=None, @@ -463,16 +474,18 @@ def usm_ndarray_repr( suffix=suffix, ) - if show_dtype: - dtype_str = "dtype={}".format(x.dtype.name) - bottom_len = len(s) - (s.rfind("\n") + 1) - next_line = bottom_len + len(dtype_str) + 1 > line_width - dtype_str = ( - ",\n" + " " * len(prefix) + dtype_str - if next_line - else ", " + dtype_str - ) + if show_dtype or x.size == 0: + dtype_str = f"dtype={x.dtype.name}" + dtype_str = _move_to_next_line(dtype_str, s, line_width, prefix) else: dtype_str = "" - return prefix + s + dtype_str + suffix + options = get_print_options() + threshold = options["threshold"] + if (x.size == 0 and x.shape != (0,)) or x.size > threshold: + shape_str = f"shape={x.shape}" + shape_str = _move_to_next_line(shape_str, s, line_width, prefix) + else: + shape_str = "" + + return prefix + s + shape_str + dtype_str + suffix diff --git a/dpctl/tests/test_usm_ndarray_print.py b/dpctl/tests/test_usm_ndarray_print.py index 48a83b7c88..677f0a9bf2 100644 --- a/dpctl/tests/test_usm_ndarray_print.py +++ b/dpctl/tests/test_usm_ndarray_print.py @@ -282,9 +282,7 @@ def test_print_repr(self): ) x = dpt.arange(4, dtype="i4", sycl_queue=q) - x.sycl_queue.wait() - r = repr(x) - assert r == "usm_ndarray([0, 1, 2, 3], dtype=int32)" + assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)" dpt.set_print_options(linewidth=1) np.testing.assert_equal( @@ -296,22 +294,27 @@ def test_print_repr(self): "\n dtype=int32)", ) + # zero-size array + dpt.set_print_options(linewidth=75) + x = dpt.ones((9, 0), dtype="i4", sycl_queue=q) + assert repr(x) == "usm_ndarray([], shape=(9, 0), dtype=int32)" + def test_print_repr_abbreviated(self): q = get_queue_or_skip() dpt.set_print_options(threshold=0, edgeitems=1) x = dpt.arange(9, dtype="int64", sycl_queue=q) - assert repr(x) == "usm_ndarray([0, ..., 8])" + assert repr(x) == "usm_ndarray([0, ..., 8], shape=(9,))" y = dpt.asarray(x, dtype="i4", copy=True) - assert repr(y) == "usm_ndarray([0, ..., 8], dtype=int32)" + assert repr(y) == "usm_ndarray([0, ..., 8], shape=(9,), dtype=int32)" x = dpt.reshape(x, (3, 3)) np.testing.assert_equal( repr(x), "usm_ndarray([[0, ..., 2]," "\n ...," - "\n [6, ..., 8]])", + "\n [6, ..., 8]], shape=(3, 3))", ) y = dpt.reshape(y, (3, 3)) @@ -319,7 +322,7 @@ def test_print_repr_abbreviated(self): repr(y), "usm_ndarray([[0, ..., 2]," "\n ...," - "\n [6, ..., 8]], dtype=int32)", + "\n [6, ..., 8]], shape=(3, 3), dtype=int32)", ) dpt.set_print_options(linewidth=1) @@ -332,6 +335,7 @@ def test_print_repr_abbreviated(self): "\n [6," "\n ...," "\n 8]]," + "\n shape=(3, 3)," "\n dtype=int32)", )