Skip to content

Commit 59dc6db

Browse files
committed
Implement efficient, asynchronous fill method
Leverages dpctl's strided fill kernel or and zeros kernel
1 parent 8cab1af commit 59dc6db

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

dpnp/dpnp_algo/dpnp_fill.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# -*- coding: utf-8 -*-
2+
# *****************************************************************************
3+
# Copyright (c) 2016-2024, Intel Corporation
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
# - Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# - Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
24+
# THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
import dpctl.tensor as dpt
28+
import numpy as np
29+
from dpctl.tensor._tensor_impl import (
30+
_copy_usm_ndarray_into_usm_ndarray,
31+
_full_usm_ndarray,
32+
_zeros_usm_ndarray,
33+
)
34+
from dpctl.utils import SequentialOrderManager
35+
36+
import dpnp
37+
38+
39+
def dpnp_fill(arr, val):
40+
dpnp.check_supported_arrays_type(arr)
41+
arr = dpnp.get_usm_ndarray(arr)
42+
exec_q = arr.sycl_queue
43+
44+
dpnp.check_supported_arrays_type(val, scalar_type=True, all_scalars=True)
45+
# if val is an array, process it
46+
if isinstance(val, (dpnp.dpnp_array, dpt.usm_ndarray)):
47+
val = dpnp.get_usm_ndarray(val)
48+
if val.shape != ():
49+
raise ValueError("`val` must be a scalar")
50+
# asarray moves scalar to the correct device
51+
# and casts to the expected dtype
52+
a_val = dpt.asarray(
53+
val,
54+
dtype=arr.dtype,
55+
usm_type=arr.usm_type,
56+
sycl_queue=exec_q,
57+
)
58+
a_val = dpt.broadcast_to(a_val, arr.shape)
59+
_manager = SequentialOrderManager[exec_q]
60+
dep_evs = _manager.submitted_events
61+
h_ev, c_ev = _copy_usm_ndarray_into_usm_ndarray(
62+
src=a_val, dst=arr, sycl_queue=exec_q, depends=dep_evs
63+
)
64+
_manager.add_event_pair(h_ev, c_ev)
65+
return
66+
67+
dt = arr.dtype
68+
val_type = type(val)
69+
if val_type in [float, complex] and dpnp.issubdtype(dt, dpnp.integer):
70+
val = int(val.real)
71+
elif val_type is complex and dpnp.issubdtype(dt, dpnp.floating):
72+
val = val.real
73+
elif val_type is int and dpnp.issubdtype(dt, dpnp.integer):
74+
val = np.asarray(val, dtype=dt)[()]
75+
76+
_manager = SequentialOrderManager[exec_q]
77+
dep_evs = _manager.submitted_events
78+
# can leverage efficient memset when val is 0
79+
if arr.flags["FORC"] and val == 0:
80+
h_ev, zeros_ev = _zeros_usm_ndarray(arr, exec_q)
81+
_manager.add_event_pair(h_ev, zeros_ev)
82+
else:
83+
h_ev, fill_ev = _full_usm_ndarray(val, arr, exec_q)
84+
_manager.add_event_pair(h_ev, fill_ev)

dpnp/dpnp_array.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -903,8 +903,10 @@ def fill(self, value):
903903
904904
"""
905905

906-
for i in range(self.size):
907-
self.flat[i] = value
906+
# lazy import avoids circular imports
907+
from .dpnp_algo.dpnp_fill import dpnp_fill
908+
909+
dpnp_fill(self, value)
908910

909911
@property
910912
def flags(self):

0 commit comments

Comments
 (0)