Skip to content

Commit 1773938

Browse files
committed
fix: skip unsupported datatypes for SYCL device
1 parent ca00e65 commit 1773938

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from asv_runner.benchmarks.mark import SkipNotImplemented
2+
3+
4+
def skip_unsupported_datatype(queue, dtype):
5+
"""
6+
Skip the benchmark if the device does not support the given data type.
7+
"""
8+
if (
9+
(dtype.name == "float64" or dtype.name == "complex128")
10+
and not queue.sycl_device.has_aspect_fp64
11+
) or (dtype.name == "float16" and not queue.sycl_device.has_aspect_fp16):
12+
raise SkipNotImplemented(
13+
f"Skipping benchmark for {dtype.name} on this device"
14+
+ " as it is not supported."
15+
)

benchmarks/benchmarks/binary.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dpctl
22
import dpctl.tensor as dpt
33

4+
from . import benchmark_util as bench_utils
5+
46
SHARED_QUEUE = dpctl.SyclQueue(property="enable_profiling")
57

68

@@ -112,6 +114,7 @@ def generate_benchmark_functions():
112114
method_name = f"time_{fn_name}_{dtype1.name}_{dtype2.name}"
113115

114116
def benchmark_method(self, fn=fn, dtype1=dtype1, dtype2=dtype2):
117+
bench_utils.skip_unsupported_datatype(self.q, dtype1)
115118
return self.run_bench(
116119
self.q,
117120
self.iterations,
@@ -121,8 +124,8 @@ def benchmark_method(self, fn=fn, dtype1=dtype1, dtype2=dtype2):
121124
fn,
122125
)
123126

124-
# Attach the new method to the Binary class
125127
benchmark_method.__name__ = method_name
128+
# Attach the new method to the Binary class
126129
setattr(Binary, method_name, benchmark_method)
127130

128131

0 commit comments

Comments
 (0)