Skip to content

Commit 8139fec

Browse files
Extend SyclTimer
SyclTimer now supports a device_timer keyword argument with two modes: the legacy "queue_barrier" behavior, and a new "order_manager" mode based on the sequential order manager, which inserts an empty task into the manager to record the start and the end of the timed block of code. The docstring of SyclTimer is updated. All data attributes needed for the timer to function are now created during class instance construction.
1 parent 77e3649 commit 8139fec

File tree

1 file changed

+77
-11
lines changed

1 file changed

+77
-11
lines changed

dpctl/_sycl_timer.py

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,44 @@ def device_dt(self):
4444
return self._device_dt
4545

4646

47+
class BaseDeviceTimer:
    """
    Base class for device timers; stores the queue the timer works with.

    The only instance attribute is ``queue`` (declared in ``__slots__``).
    """

    __slots__ = ["queue"]

    def __init__(self, sycl_queue):
        """
        Store ``sycl_queue`` on the instance.

        Args:
            sycl_queue (SyclQueue):
                Queue to associate with this timer.

        Raises:
            TypeError: if ``sycl_queue`` is not a ``SyclQueue`` instance.
        """
        if isinstance(sycl_queue, SyclQueue):
            self.queue = sycl_queue
        else:
            raise TypeError(f"Expected type SyclQueue, got {type(sycl_queue)}")
54+
55+
56+
class QueueBarrierDeviceTimer(BaseDeviceTimer):
    """
    Device timer whose timing events come from barriers submitted to
    the queue.
    """

    # No attributes beyond those declared by the base class.
    __slots__ = []

    def __init__(self, sycl_queue):
        """
        Initialize the timer for ``sycl_queue`` by delegating to the base
        class initializer.

        Args:
            sycl_queue (SyclQueue):
                Queue to which barriers are going to be submitted.
        """
        super().__init__(sycl_queue)

    def get_event(self):
        """
        Submit a barrier to the stored queue.

        Returns:
            The value returned by ``self.queue.submit_barrier()``.
        """
        barrier_event = self.queue.submit_barrier()
        return barrier_event
64+
65+
66+
class OrderManagerDeviceTimer(BaseDeviceTimer):
    """
    Device timer whose timing events come from empty tasks submitted
    through the sequential order manager associated with the queue.
    """

    __slots__ = ["_order_manager", "_submit_empty_task_fn"]

    def __init__(self, sycl_queue):
        """
        Initialize the timer for ``sycl_queue``.

        Looks up the order manager for the queue and keeps a reference to
        the function used to submit an empty task.

        Args:
            sycl_queue (SyclQueue):
                Queue to which empty tasks are going to be submitted.
        """
        # Imported inside the method, exactly where the original code
        # performs these imports.
        from dpctl.utils._seq_order_keeper import _submit_empty_task
        from dpctl.utils import SequentialOrderManager

        super().__init__(sycl_queue)
        self._order_manager = SequentialOrderManager[self.queue]
        self._submit_empty_task_fn = _submit_empty_task

    def get_event(self):
        """
        Submit an empty task that depends on the events currently reported
        by the order manager as submitted, and register it with the manager.

        Returns:
            The event returned by the empty task submission.
        """
        manager = self._order_manager
        marker_event = self._submit_empty_task_fn(
            sycl_queue=self.queue, depends=manager.submitted_events
        )
        # The same event is passed for both elements of the pair.
        manager.add_event_pair(marker_event, marker_event)
        return marker_event
83+
84+
4785
class SyclTimer:
4886
"""
4987
Context to measure device time and host wall-time of execution
@@ -58,7 +96,7 @@ class SyclTimer:
5896
q = dpctl.SyclQueue(property="enable_profiling")
5997
6098
# create the timer
61-
milliseconds_sc = 1e-3
99+
milliseconds_sc = 1e3
62100
timer = dpctl.SyclTimer(time_scale = milliseconds_sc)
63101
64102
# use the timer
@@ -73,25 +111,36 @@ class SyclTimer:
73111
wall_dt, device_dt = timer.dt
74112
75113
.. note::
76-
The timer submits barriers to the queue at the entrance and the
114+
The timer submits tasks to the queue at the entrance and the
77115
exit of the context and uses profiling information from events
78116
associated with these submissions to perform the timing. Thus
79117
:class:`dpctl.SyclTimer` requires the queue with ``"enable_profiling"``
80118
property. In order to be able to collect the profiling information,
81119
the ``dt`` property ensures that both submitted barriers complete their
82120
execution and thus effectively synchronizes the queue.
83121
122+
`device_timer` keyword argument controls the type of tasks submitted.
123+
With `device_timer="queue_barrier"`, queue barrier tasks are used. With
124+
`device_timer="order_manager"`, a single empty body task is inserted
125+
instead relying on order manager (used by `dpctl.tensor` operations) to
126+
order these tasks so that they fence operations performed within
127+
timer's context.
128+
84129
Args:
85130
host_timer (callable, optional):
86131
A callable such that host_timer() returns current
87132
host time in seconds.
88133
Default: :py:func:`timeit.default_timer`.
134+
device_timer (Literal["queue_barrier", "order_manager"], optional):
135+
Device timing method. Default: "queue_barrier".
89136
time_scale (Union[int, float], optional):
90137
Ratio of the unit of time of interest and one second.
91138
Default: ``1``.
92139
"""
93140

94-
def __init__(self, host_timer=timeit.default_timer, time_scale=1):
141+
def __init__(
142+
self, host_timer=timeit.default_timer, device_timer=None, time_scale=1
143+
):
95144
"""
96145
Create new instance of :class:`.SyclTimer`.
97146
@@ -100,6 +149,8 @@ def __init__(self, host_timer=timeit.default_timer, time_scale=1):
100149
A function that takes no arguments and returns a value
101150
measuring time.
102151
Default: :meth:`timeit.default_timer`.
152+
device_timer (Literal["queue_barrier", "order_manager"], optional):
153+
Device timing method. Default: "queue_barrier"
103154
time_scale (Union[int, float], optional):
104155
Scaling factor applied to durations measured by
105156
the host_timer. Default: ``1``.
@@ -109,11 +160,26 @@ def __init__(self, host_timer=timeit.default_timer, time_scale=1):
109160
self.queue = None
110161
self.host_times = []
111162
self.bracketing_events = []
163+
self._context_data = list()
164+
if device_timer is None:
165+
device_timer = "queue_barrier"
166+
if device_timer == "queue_barrier":
167+
self._device_timer_class = QueueBarrierDeviceTimer
168+
elif device_timer == "order_manager":
169+
self._device_timer_class = OrderManagerDeviceTimer
170+
else:
171+
raise ValueError(
172+
"Supported values for device_timer keyword are "
173+
"'queue_barrier', 'order_manager', got "
174+
f"'{device_timer}'"
175+
)
176+
self._device_timer = None
112177

113178
def __call__(self, queue=None):
114179
if isinstance(queue, SyclQueue):
115180
if queue.has_enable_profiling:
116181
self.queue = queue
182+
self._device_timer = self._device_timer_class(queue)
117183
else:
118184
raise ValueError(
119185
"The given queue was not created with the "
@@ -127,17 +193,17 @@ def __call__(self, queue=None):
127193
return self
128194

129195
def __enter__(self):
130-
self._event_start = self.queue.submit_barrier()
131-
self._host_start = self.timer()
196+
_event_start = self._device_timer.get_event()
197+
_host_start = self.timer()
198+
self._context_data.append((_event_start, _host_start))
132199
return self
133200

134201
def __exit__(self, *args):
135-
self.host_times.append((self._host_start, self.timer()))
136-
self.bracketing_events.append(
137-
(self._event_start, self.queue.submit_barrier())
138-
)
139-
del self._event_start
140-
del self._host_start
202+
_event_end = self._device_timer.get_event()
203+
_host_end = self.timer()
204+
_event_start, _host_start = self._context_data.pop()
205+
self.host_times.append((_host_start, _host_end))
206+
self.bracketing_events.append((_event_start, _event_end))
141207

142208
@property
143209
def dt(self):

0 commit comments

Comments
 (0)