Skip to content

Commit c49aeb2

Browse files
committed
Faster perform method for matmul
1 parent 6b38e79 commit c49aeb2

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

pytensor/tensor/blockwise.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
core_op: Op,
5959
signature: Optional[str] = None,
6060
name: Optional[str] = None,
61+
gufunc_spec: Optional[tuple[str, int, int]] = None,
6162
**kwargs,
6263
):
6364
"""
@@ -69,7 +70,12 @@ def __init__(
6970
signature
7071
Generalized universal function signature,
7172
e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication
72-
73+
gufunc: tuple, Optional
74+
Tuple containing:
75+
1. String import path for a numpy/scipy function (e.g., "numpy.matmul", "scipy.special.softmax")
76+
that implements the blockwised operation of the scalar op.
77+
2 Number of inputs of the function
78+
3 Number of outputs of the function
7379
"""
7480
if isinstance(core_op, Blockwise):
7581
raise TypeError("Core Op is already a Blockwise")
@@ -85,6 +91,7 @@ def __init__(
8591
self.signature = signature
8692
self.name = name
8793
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
94+
self.gufunc_spec = gufunc_spec
8895
self._gufunc = None
8996
super().__init__(**kwargs)
9097

@@ -297,10 +304,14 @@ def L_op(self, inputs, outs, ograds):
297304
return rval
298305

299306
def _create_gufunc(self, node):
300-
if hasattr(self.core_op, "gufunc_spec"):
301-
self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0])
307+
gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
308+
309+
if gufunc_spec is not None:
310+
self._gufunc = import_func_from_string(gufunc_spec[0])
302311
if self._gufunc:
303312
return self._gufunc
313+
else:
314+
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
304315

305316
n_outs = len(self.outputs_sig)
306317
core_node = self._create_dummy_core_node(node.inputs)

pytensor/tensor/math.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2874,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False):
28742874
return log(sum(exp(x), axis=axis, keepdims=keepdims))
28752875

28762876

2877-
_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
2877+
_matrix_matrix_matmul = Blockwise(
2878+
_dot,
2879+
signature="(m,k),(k,n)->(m,n)",
2880+
gufunc_spec=("numpy.matmul", 2, 1),
2881+
)
28782882

28792883

28802884
def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):

0 commit comments

Comments
 (0)