Skip to content

Commit 9dffdc0

Browse files
committed
Tests for array API inspection utilities reflect fixes for device keyword
1 parent 348dd3d commit 9dffdc0

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

dpctl/tests/test_tensor_array_api_inspection.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"bool": dpt.bool,
3030
"float32": dpt.float32,
3131
"complex64": dpt.complex64,
32-
"complex128": dpt.complex128,
3332
"int8": dpt.int8,
3433
"int16": dpt.int16,
3534
"int32": dpt.int32,
@@ -41,12 +40,6 @@
4140
}
4241

4342

44-
class MockDevice:
45-
def __init__(self, fp16: bool, fp64: bool):
46-
self.has_aspect_fp16 = fp16
47-
self.has_aspect_fp64 = fp64
48-
49-
5043
def test_array_api_inspection_methods():
5144
info = dpt.__array_namespace_info__()
5245
assert info.capabilities()
@@ -125,17 +118,21 @@ def test_array_api_inspection_default_device_dtypes():
125118
dtypes = _dtypes_no_fp16_fp64.copy()
126119
if dev.has_aspect_fp64:
127120
dtypes["float64"] = dpt.float64
121+
dtypes["complex128"] = dpt.complex128
128122

129123
assert dtypes == dpt.__array_namespace_info__().dtypes()
130124

131125

132-
@pytest.mark.parametrize("fp16", [True, False])
133-
@pytest.mark.parametrize("fp64", [True, False])
134-
def test_array_api_inspection_device_dtypes(fp16, fp64):
135-
dev = MockDevice(fp16, fp64)
126+
def test_array_api_inspection_device_dtypes():
127+
info = dpt.__array_namespace_info__()
128+
try:
129+
dev = info.default_device()
130+
except dpctl.SyclDeviceCreationError:
131+
pytest.skip("No default device available")
136132
dtypes = _dtypes_no_fp16_fp64.copy()
137-
if fp64:
133+
if dev.has_aspect_fp64:
138134
dtypes["float64"] = dpt.float64
135+
dtypes["complex128"] = dpt.complex128
139136

140137
assert dtypes == dpt.__array_namespace_info__().dtypes(device=dev)
141138

0 commit comments

Comments
 (0)