Skip to content

Commit 86c812e

Browse files
committed
Check constructor argument types
1 parent 13cbd01 commit 86c812e

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

dpctl/_sycl_queue.pyx

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ from ._backend cimport ( # noqa: E211
6464
from .memory._memory cimport _Memory
6565

6666
import ctypes
67+
import numbers
6768

6869
from .enum_types import backend_type
6970

@@ -1586,14 +1587,24 @@ cdef class WorkGroupMemory:
15861587
f"arguments, but {len(args)} were given")
15871588

15881589
if len(args) == 1:
1590+
if not isinstance(args[0], numbers.Integral):
1591+
raise TypeError("WorkGroupMemory single argument constructor"
1592+
"expects number of bytes as integer value")
15891593
nbytes = <size_t>(args[0])
15901594
else:
1595+
if not isinstance(args[0], str) or not isinstance(args[1], numbers.Integral):
1596+
raise TypeError("WorkGroupMemory constructor expects type as"
1597+
"string and number of bytes as integer value.")
15911598
dtype = <str>(args[0])
15921599
count = <size_t>(args[1])
1593-
ty = dtype[0]
1594-
if not ty in ["i", "u", "f"]:
1600+
if not dtype[0] in ["i", "u", "f"]:
15951601
raise TypeError(f"Unrecognized type value: '{dtype}'")
1596-
byte_size = <size_t>(int(dtype[1:]))
1602+
try:
1603+
bit_width = int(dtype[1:])
1604+
except ValueError:
1605+
raise TypeError(f"Unrecognized type value: '{dtype}'")
1606+
1607+
byte_size = <size_t>bit_width
15971608
nbytes = count * byte_size
15981609

15991610
self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)

0 commit comments

Comments
 (0)