@@ -64,6 +64,7 @@ from ._backend cimport ( # noqa: E211
64
64
from .memory._memory cimport _Memory
65
65
66
66
import ctypes
67
+ import numbers
67
68
68
69
from .enum_types import backend_type
69
70
@@ -1586,14 +1587,24 @@ cdef class WorkGroupMemory:
1586
1587
f" arguments, but {len(args)} were given" )
1587
1588
1588
1589
if len (args) == 1 :
1590
+ if not isinstance (args[0 ], numbers.Integral):
1591
+ raise TypeError (" WorkGroupMemory single argument constructor"
1592
+ " expects number of bytes as integer value" )
1589
1593
nbytes = < size_t> (args[0 ])
1590
1594
else :
1595
+ if not isinstance (args[0 ], str ) or not isinstance (args[1 ], numbers.Integral):
1596
+ raise TypeError (" WorkGroupMemory constructor expects type as"
1597
+ " string and number of bytes as integer value." )
1591
1598
dtype = < str > (args[0 ])
1592
1599
count = < size_t> (args[1 ])
1593
- ty = dtype[0 ]
1594
- if not ty in [" i" , " u" , " f" ]:
1600
+ if not dtype[0 ] in [" i" , " u" , " f" ]:
1595
1601
raise TypeError (f" Unrecognized type value: '{dtype}'" )
1596
- byte_size = < size_t> (int (dtype[1 :]))
1602
+ try :
1603
+ bit_width = int (dtype[1 :])
1604
+ except ValueError :
1605
+ raise TypeError (f" Unrecognized type value: '{dtype}'" )
1606
+
1607
+ byte_size = < size_t> bit_width
1597
1608
nbytes = count * byte_size
1598
1609
1599
1610
self ._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
0 commit comments