@@ -58,6 +58,7 @@ def __init__(
58
58
core_op : Op ,
59
59
signature : Optional [str ] = None ,
60
60
name : Optional [str ] = None ,
61
+ gufunc_spec : Optional [tuple [str , int , int ]] = None ,
61
62
** kwargs ,
62
63
):
63
64
"""
@@ -69,7 +70,12 @@ def __init__(
69
70
signature
70
71
Generalized universal function signature,
71
72
e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication
72
-
73
+ gufunc: tuple, Optional
74
+ Tuple containing:
75
+ 1. String import path for a numpy/scipy function (e.g., "numpy.matmul", "scipy.special.softmax")
76
+ that implements the blockwised operation of the scalar op.
77
+ 2 Number of inputs of the function
78
+ 3 Number of outputs of the function
73
79
"""
74
80
if isinstance (core_op , Blockwise ):
75
81
raise TypeError ("Core Op is already a Blockwise" )
@@ -85,6 +91,7 @@ def __init__(
85
91
self .signature = signature
86
92
self .name = name
87
93
self .inputs_sig , self .outputs_sig = _parse_gufunc_signature (signature )
94
+ self .gufunc_spec = gufunc_spec
88
95
self ._gufunc = None
89
96
super ().__init__ (** kwargs )
90
97
@@ -297,10 +304,14 @@ def L_op(self, inputs, outs, ograds):
297
304
return rval
298
305
299
306
def _create_gufunc (self , node ):
300
- if hasattr (self .core_op , "gufunc_spec" ):
301
- self ._gufunc = import_func_from_string (self .core_op .gufunc_spec [0 ])
307
+ gufunc_spec = self .gufunc_spec or getattr (self .core_op , "gufunc_spec" , None )
308
+
309
+ if gufunc_spec is not None :
310
+ self ._gufunc = import_func_from_string (gufunc_spec [0 ])
302
311
if self ._gufunc :
303
312
return self ._gufunc
313
+ else :
314
+ raise ValueError (f"Could not import gufunc { gufunc_spec [0 ]} for { self } " )
304
315
305
316
n_outs = len (self .outputs_sig )
306
317
core_node = self ._create_dummy_core_node (node .inputs )
0 commit comments